Introdução ao Trillium (v6e)
O v6e é usado para se referir ao Trillium nesta documentação, na API TPU e nos registros. O v6e representa a 6ª geração de TPU do Google.
Com 256 chips por pod, a arquitetura v6e compartilha muitas semelhanças com a v5e. Esse sistema é otimizado para treinamento, ajuste fino e veiculação de transformadores, redes neurais convolucionais (CNNs) e conversão de texto em imagem.
Consulte o documento v6e para informações sobre a arquitetura e as configurações do sistema v6e.
Este documento introdutório se concentra nos processos de treinamento e fornecimento de modelos usando os frameworks JAX, PyTorch ou TensorFlow. Com cada framework, é possível provisionar TPUs usando recursos em fila ou o Google Kubernetes Engine (GKE). A configuração do GKE pode ser feita usando comandos do XPK ou do GKE.
Procedimento geral para treinar ou exibir um modelo usando a v6e
- Preparar um Google Cloud projeto
- Capacidade segura
- Configurar o ambiente da TPU
- Provisionar o ambiente do Cloud TPU
- Executar uma carga de trabalho de treinamento ou inferência de modelo
- Limpeza
Preparar um Google Cloud projeto
- Faça login na sua Conta do Google. Se ainda não tiver uma, crie uma nova conta.
- No console do Google Cloud, selecione ou crie um projeto do Cloud na página do seletor de projetos.
- Ative o faturamento para seu projeto do Google Cloud. O faturamento é obrigatório para todo o uso do Google Cloud.
- Instale os componentes da gcloud alfa.
Execute o comando a seguir para instalar a versão mais recente dos componentes
gcloud
.gcloud components update
Ative a API TPU usando o comando
gcloud
a seguir no Cloud Shell. Também é possível ativá-la no console do Google Cloud.gcloud services enable tpu.googleapis.com
Ativar permissões com a conta de serviço do TPU para a API Compute Engine
As contas de serviço permitem que o serviço do Cloud TPU acesse outros serviços do Google Cloud. Uma conta de serviço gerenciado pelo usuário é uma prática recomendada do Google Cloud. Siga estes guias para criar e conceder funções. Os seguintes papéis são necessários:
- Administrador da TPU
- Administrador de armazenamento
- Gravador de registros
- Gravador de métricas do Monitoring
a. Configure as permissões do XPK com sua conta de usuário para o GKE: XPK.
Autentique com sua conta do Google e defina o ID e a zona padrão do projeto.
auth login
autoriza o gcloud a acessar Google Cloud com as credenciais de usuário do Google.
PROJECT_ID
é o nome do Google Cloud projeto.
ZONE
é a zona em que você quer criar o TPU.gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Crie uma identidade de serviço para a VM da TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Capacidade segura
Entre em contato com o suporte de vendas/conta do Cloud TPU para solicitar a cota da TPU e tirar dúvidas sobre a capacidade.
Provisionar o ambiente do Cloud TPU
As TPUs v6e podem ser provisionadas e gerenciadas com o GKE, com o GKE e o XPK (uma ferramenta de CLI de wrapper sobre o GKE) ou como recursos em fila.
Pré-requisitos
- Verifique se o projeto tem cota de
TPUS_PER_TPU_FAMILY
suficiente, que especifica o número máximo de chips que você pode acessar no projetoGoogle Cloud . - O v6e foi testado com a seguinte configuração:
- Python
3.10
ou mais recente - Versões noturnas do software:
0.4.32.dev20240912
noturno JAX- LibTPU noturno
0.1.dev20240912+nightly
- Versões estáveis do software:
- JAX + JAX Lib da v0.4.37
- Python
Verifique se o projeto tem cota de TPU suficiente para:
- Cota da VM de TPU
- Quota de endereço IP
Cota do Hyperdisk equilibrado
Permissões de projeto do usuário
- Se você estiver usando o GKE com XPK, consulte Permissões do console do Google Cloud na conta de usuário ou de serviço para conferir as permissões necessárias para executar o XPK.
Variáveis de ambiente
No Cloud Shell, crie as seguintes variáveis de ambiente:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descrições de sinalizações de comando
Variável | Descrição |
NODE_ID | O ID atribuído pelo usuário do TPU, que é criado quando a solicitação de recurso em fila é alocada. |
PROJECT_ID | Google Cloud Nome do projeto. Use um projeto atual ou crie um novo em |
ZONA | Consulte o documento Regiões e zonas de TPU para conferir as zonas compatíveis. |
ACCELERATOR_TYPE | Consulte Tipos de acelerador. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Esse é o endereço de e-mail da sua conta de serviço, que pode ser encontrado em
Console do Google Cloud -> IAM -> Contas de serviço
Por exemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | O número de fatias a serem criadas (somente para fatias múltiplas). |
QUEUED_RESOURCE_ID | O ID de texto atribuído pelo usuário da solicitação de recurso em fila. |
VALID_DURATION | O período em que a solicitação de recurso em fila é válida. |
NETWORK_NAME | O nome de uma rede secundária a ser usada. |
NETWORK_FW_NAME | O nome de um firewall de rede secundário a ser usado. |
Otimizações de desempenho da rede
Para ter o melhor desempenho,use uma rede com 8.896 MTU (unidade máxima de transmissão).
Por padrão, uma nuvem privada virtual (VPC) fornece apenas uma MTU de 1.460 bytes,o que vai gerar um desempenho de rede subótimo. É possível definir a MTU de uma rede VPC como qualquer valor entre 1.300 e 8.896 bytes (inclusivo). Os tamanhos comuns de MTU personalizados são 1.500 bytes (Ethernet padrão) ou 8.896 bytes (o máximo possível). Para mais informações, consulte Tamanhos válidos de MTU da rede VPC.
Para mais informações sobre como mudar a configuração de MTU de uma rede existente ou padrão, consulte Alterar a configuração de MTU de uma rede VPC.
O exemplo a seguir cria uma rede com 8.896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
Como usar a multi-NIC (opção para multislice)
As variáveis de ambiente a seguir são necessárias para uma sub-rede secundária quando você está usando um ambiente de várias fatias.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Use os comandos a seguir para criar um roteamento IP personalizado para a rede e a sub-rede.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT_ID}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT_ID}" \
--enable-logging
Depois que uma fatia de várias redes for criada, você poderá validar se
as duas NICs estão sendo usadas configurando um cluster XPK e executando --command ifconfig
como parte
da carga de trabalho XPK.
Use o comando xpk workload
a seguir para mostrar a saída do comando ifconfig
nos registros do console do Google Cloud e verificar se eth0 e eth1 têm mtu=8896.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Verifique se eth0 e eth1 têm mtu=8,896. Uma maneira de verificar se a multi-NIC está em execução é executar o comando --command "ifconfig" como parte da carga de trabalho do XPK. Em seguida, analise a saída impressa dessa carga de trabalho xpk nos registros do console do Cloud e verifique se eth0 e eth1 têm mtu=8896.
Configurações TCP aprimoradas
Para TPUs criadas usando a interface de recursos em fila, execute o comando a seguir para melhorar o desempenho da rede aumentando os limites do buffer de recebimento do TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" \ --zone "$ZONE" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Provisionamento com recursos na fila
A capacidade alocada pode ser provisionada usando o comando create
de recursos em fila.
Crie uma solicitação de recurso em fila da TPU.
A flag
--reserved
é necessária apenas para recursos reservados, não para recursos sob demanda.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved] # The following flags are only needed if you are using Multislice. --node-count node-count # Number of slices in a Multislice \ --node-prefix node-prefix # An optional user-defined node prefix; the default is QUEUED_RESOURCE_ID.
Se a solicitação de recurso em fila for criada, o estado no campo "response" será "WAITING_FOR_RESOURCES" ou "FAILED". Se a solicitação de recurso enfileirado estiver no estado "WAITING_FOR_RESOURCES", o recurso foi adicionado à fila e será provisionado quando houver capacidade de TPU alocada suficiente. Se a solicitação de recurso em fila estiver no estado "FAILED", o motivo da falha vai aparecer na saída. A solicitação de recurso em fila vai expirar se um v6e não for provisionado dentro da duração especificada e o estado se tornar "FAILED". Consulte a documentação pública Recursos em fila para mais informações.
Quando a solicitação de recurso enfileirado está no estado "ATIVO", é possível se conectar às VMs de TPU usando o SSH. Use os comandos
list
oudescribe
para consultar o status do recurso em fila.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Quando o recurso na fila está no estado "ATIVO", a saída é semelhante a esta:
state: state: ACTIVE
Gerenciar VMs TPU. Para opções de gerenciamento de VMs de TPU, consulte Gerenciar VMs de TPU.
Conectar-se às VMs do TPU usando SSH
É possível instalar binários em cada VM de TPU na fatia de TPU e executar o código. Consulte a seção Tipos de VM para determinar quantas VMs seu slice terá.
Para instalar os binários ou executar o código, use o SSH para se conectar a uma VM usando o comando
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Para usar o SSH para se conectar a uma VM específica, use a flag
--worker
que segue um índice baseado em 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Se você tiver formas de fatia maiores que 8 chips, terá várias VMs em uma fatia. Nesse caso, use os parâmetros
--worker=all
e--command
no comandogcloud alpha compute tpus tpu-vm ssh
para executar um comando em todas as VMs simultaneamente. Exemplo:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Excluir um recurso na fila
Exclua um recurso na fila no fim da sessão ou remova solicitações de recursos na fila que estejam no estado "FAILED". Para excluir um recurso na fila, exclua a fatia e, em seguida, a solicitação de recurso na fila em duas etapas:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Provisionamento de TPUs v6e com o GKE ou XPK
Se você estiver usando comandos do GKE com o v6e, poderá usar comandos do Kubernetes ou o XPK para provisionar TPUs e treinar ou oferecer modelos. Consulte Planejar TPUs no GKE para saber como planejar as configurações de TPU nos clusters do GKE. As seções a seguir fornecem comandos para criar um cluster XPK com suporte a uma NIC e suporte a várias NICs.
Comandos para criar um cluster XPK com suporte a uma NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network ${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster $CLUSTER_NAME \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=$NUM_SLICES \ --tpu-type=$TPU_TYPE \ --zone=$ZONE \ --project=$PROJECT \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descrições de sinalizações de comando
Variável | Descrição |
CLUSTER_NAME | O nome atribuído pelo usuário ao cluster XPK. |
PROJECT_ID | Google Cloud Nome do projeto. Use um projeto atual ou crie um novo em |
ZONA | Consulte o documento Regiões e zonas de TPU para conferir as zonas compatíveis. |
TPU_TYPE | Consulte Tipos de acelerador. |
NUM_SLICES | O número de fatias que você quer criar |
CLUSTER_ARGUMENTS | A rede e a sub-rede a serem usadas.
Por exemplo: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" |
NUM_SLICES | O número de fatias a serem criadas. |
NETWORK_NAME | O nome de uma rede secundária a ser usada. |
NETWORK_FW_NAME | O nome de um firewall de rede secundário a ser usado. |
Comandos para criar um cluster XPK com suporte a várias NICs
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create "${NETWORK_NAME_1}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_1}" \ --network="${NETWORK_NAME_1}" \ --range=10.11.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_1}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_1}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create "${NETWORK_NAME_2}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \ --network="${NETWORK_NAME_2}" \ --range=10.10.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_2}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_2}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Descrições de sinalizações de comando
Variável | Descrição |
CLUSTER_NAME | O nome atribuído pelo usuário ao cluster XPK. |
PROJECT_ID | Google Cloud Nome do projeto. Use um projeto atual ou crie um novo em |
ZONA | Consulte o documento Regiões e zonas de TPU para conferir as zonas compatíveis. |
TPU_TYPE | Consulte Tipos de acelerador. |
NUM_SLICES | O número de fatias que você quer criar |
CLUSTER_ARGUMENTS | A rede e a sub-rede a serem usadas.
Por exemplo: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}" |
NODE_POOL_ARGUMENTS | Rede de nó adicional a ser usada.
Por exemplo: "--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}" |
NUM_SLICES | O número de fatias a serem criadas (somente para fatias múltiplas). |
NETWORK_NAME | O nome de uma rede secundária a ser usada. |
NETWORK_FW_NAME | O nome de um firewall de rede secundário a ser usado. |
Configuração do framework
Esta seção descreve o processo de configuração geral para treinamento de modelo de ML usando os frameworks JAX, PyTorch ou TensorFlow. É possível provisionar TPUs usando recursos em fila ou o GKE. A configuração do GKE pode ser feita usando comandos do XPK ou do Kubernetes.
Configuração para o JAX
Esta seção fornece exemplos de execução de cargas de trabalho do JAX no GKE, com ou sem XPK, além de usar recursos em fila.
Configurar o JAX usando o GKE
O exemplo a seguir configura um host único 2x2 usando um arquivo YAML do Kubernetes.
Uma fatia em um host
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:
Total TPU chips: 4
Fração única em vários hosts
O exemplo a seguir configura um pool de nós de vários hosts 4x4 usando um arquivo YAML do Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:
Total TPU chips: 16
Multislice em vários hosts
O exemplo a seguir configura dois pools de nós multi-host 4x4 usando um arquivo YAML do Kubernetes.
Como pré-requisito, é necessário instalar o JobSet v0.2.3 ou mais recente.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:
Total TPU chips: 32
Para mais informações, consulte Executar uma carga de trabalho com vários setores na documentação do GKE.
Para melhorar o desempenho, ative a hostNetwork.
Várias NICs
Para aproveitar o recurso de várias NICs no GKE, o manifesto do pod do Kubernetes precisa ter outras anotações. Confira a seguir um exemplo de manifesto de carga de trabalho multi-NIC sem TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Se você exec
no pod do Kubernetes, a NIC adicional vai aparecer
usando o código abaixo.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configurar o JAX usando o GKE com XPK
Confira um exemplo no README do xpk.
Para configurar e executar o XPK com o MaxText, consulte: Como executar o MaxText.
Configurar o JAX usando recursos na fila
Instale o JAX em todas as VMs de TPU no seu slice ou em vários slices simultaneamente usando
gcloud alpha compute tpus tpu-vm ssh
. Para "Várias fatias", adicione --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Você pode executar o código Python abaixo para verificar quantos núcleos de TPU estão disponíveis na sua fatia e testar se tudo está instalado corretamente. As saídas mostradas aqui foram produzidas com uma fatia v6e-16:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
O resultado será assim:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() mostra o número total de chips na fatia especificada. jax.local_device_count() indica a contagem de chips acessíveis por uma única VM nessa fatia.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Solução de problemas de configurações do JAX
Uma dica geral é ativar o registro detalhado no manifesto de carga de trabalho do GKE. Em seguida, envie os registros para o suporte do GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Mensagens de erro
no endpoints available for service 'jobset-webhook-service'
Esse erro significa que o jobset não foi instalado corretamente. Verifique se os pods do Kubernetes da implantação jobset-controller-manager estão em execução. Para mais informações, consulte a documentação de solução de problemas do JobSet.
TPU initialization failed: Failed to connect
Verifique se a versão do nó do GKE é 1.30.4-gke.1348000 ou mais recente. Não há suporte para o GKE 1.31.
Configuração para PyTorch
Esta seção descreve como começar a usar o PJRT na v6e com o PyTorch/XLA. O Python 3.10 é a versão recomendada.
Configurar o PyTorch usando o GKE com o XPK
Você pode usar o contêiner do Docker abaixo com o XPK, que já tem as dependências do PyTorch instaladas:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Para criar uma carga de trabalho XPK, use o seguinte comando:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
O uso de --base-docker-image
cria uma nova imagem do Docker com o diretório de trabalho
atual integrado ao novo Docker.
Configurar o PyTorch usando recursos enfileirados
Siga estas etapas para instalar o PyTorch usando recursos em fila e executar um pequeno script na v6e.
Instalar dependências usando SSH para acessar as VMs
Para o recurso Multislice, adicione --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Melhorar o desempenho de modelos com alocações frequentes e de grande porte
Para modelos com alocações frequentes e grandes, observamos que o uso de
tcmalloc
melhora o desempenho significativamente em comparação com a
implementação padrão de malloc
.
Portanto, o malloc
padrão usado na VM de TPU é tcmalloc
. No entanto, dependendo da
carga de trabalho (por exemplo, com o DLRM, que tem alocações muito grandes para as
tabelas de incorporação), tcmalloc
pode causar uma lentidão. Nesse caso, tente
redefinir a seguinte variável usando o malloc
padrão:
unset LD_PRELOAD
Use um script Python para fazer um cálculo na VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Isso gera um resultado semelhante ao seguinte:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuração para o TensorFlow
Para a prévia pública da v6e, somente a versão do ambiente de execução tf-nightly é compatível.
É possível redefinir tpu-runtime
com a versão do TensorFlow compatível com a v6e
executando os seguintes comandos:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Use o SSH para acessar o worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Instale o TensorFlow no worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exporte a variável de ambiente TPU_NAME
:
export TPU_NAME=v6e-16
Você pode executar o script Python abaixo para verificar quantos núcleos de TPU estão disponíveis na sua fatia e para testar se tudo está instalado corretamente. As saídas mostradas foram produzidas com uma fatia v6e-16:
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
O resultado será assim:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e com SkyPilot
É possível usar a TPU v6e com o SkyPilot. Siga as etapas abaixo para adicionar informações de local/preço relacionadas à v6e ao SkyPilot.
Adicione o seguinte ao final de
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Especifique os seguintes recursos em um arquivo YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Inicie um cluster com a TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Conecte-se ao TPU v6e usando SSH:
ssh tpu_v6
Tutoriais de inferência
Os tutoriais a seguir mostram como executar a inferência na TPU v6e:
Exemplos de treinamento
As seções a seguir fornecem exemplos de treinamento de modelos MaxText, MaxDiffusion e PyTorch no TPU v6e.
Treinamento do MaxText e do MaxDiffusion na VM do Cloud TPU v6e
As seções a seguir abrangem o ciclo de vida de treinamento dos modelos MaxText e MaxDiffusion.
Em geral, as etapas gerais são:
- Crie a imagem de base da carga de trabalho.
- Execute a carga de trabalho usando o XPK.
- Crie o comando de treinamento para a carga de trabalho.
- Implante a carga de trabalho.
- Acompanhe a carga de trabalho e confira as métricas.
- Exclua a carga de trabalho do XPK se ela não for necessária.
- Exclua o cluster XPK quando ele não for mais necessário.
Criar imagem de base
Instale o MaxText ou o MaxDiffusion e crie a imagem do Docker:
Clone o repositório que você quer usar e mude para o diretório do repositório:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configure o Docker para usar a Google Cloud CLI:
gcloud auth configure-docker
Crie a imagem do Docker usando o comando a seguir ou a pilha estável do JAX. Para mais informações sobre a pilha estável do JAX, consulte Criar uma imagem do Docker com a pilha estável do JAX.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Se você estiver iniciando a carga de trabalho em uma máquina que não tem a imagem criada localmente, faça o upload da imagem:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Criar uma imagem do Docker com a pilha estável do JAX
É possível criar as imagens do Docker MaxText e MaxDiffusion usando a imagem base da pilha Stable do JAX.
O JAX Stable Stack fornece um ambiente consistente para MaxText e MaxDiffusion
agrupando o JAX com pacotes principais, como orbax
, flax
e optax
, com
um libtpu.so bem qualificado que direciona utilitários de programa TPU e outras ferramentas essenciais. Essas bibliotecas são testadas para garantir a compatibilidade e fornecer uma base estável para criar e executar o MaxText e o MaxDiffusion. Isso elimina possíveis conflitos devido a versões de pacotes incompatíveis.
O JAX Stable Stack inclui um libtpu.so totalmente lançado e qualificado, a biblioteca principal que orienta a compilação, a execução e a configuração de rede ICI do programa TPU. A versão libtpu substitui o build noturno usado anteriormente pelo JAX e garante a funcionalidade consistente das computações XLA na TPU com testes de qualificação no nível do PJRT em HLO/StableHLO IRs.
Para criar a imagem do Docker MaxText e MaxDiffusion com a pilha estável do JAX, ao
executar o script docker_build_dependency_image.sh
, defina a variável MODE
como stable_stack
e a variável BASEIMAGE
como a imagem de base que você quer
usar.
O exemplo a seguir especifica
us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
como
a imagem base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Para uma lista de imagens de base do JAX Stable Stack disponíveis, consulte Imagens do JAX Stable Stack no Artifact Registry.
Executar a carga de trabalho usando o XPK
Defina as seguintes variáveis de ambiente se você não estiver usando os valores padrão definidos por MaxText ou MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Crie o script do modelo. Esse script será copiado como um comando de treinamento em uma etapa posterior.
Não execute o script do modelo ainda.
MaxText
O MaxText é um LLM de código aberto de alto desempenho e altamente escalonável escrito em Python e JAX puros e direcionado a Google Cloud TPUs e GPUs para treinamento e inferência.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
O Gemma é uma família de LLMs de pesos abertos desenvolvidos pelo Google DeepMind, com base na pesquisa e tecnologia do Gemini.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
O Mixtral é um modelo de IA de última geração desenvolvido pela Mistral AI, que utiliza uma arquitetura de mistura de especialistas (MoE, na sigla em inglês) esparsa.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
O Llama é uma família de LLMs de peso aberto desenvolvidos pela Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
O MaxDiffusion é uma coleção de implementações de referência de vários modelos de difusão latente escritos em Python puro e JAX que são executados em dispositivos XLA, incluindo TPUs e GPUs do Cloud. O Stable Diffusion é um modelo latente de texto para imagem que gera imagens fotorrealistas a partir de qualquer entrada de texto.
É necessário instalar uma ramificação específica do Git para executar o MaxDiffusion, conforme mostrado no comando
git checkout
a seguir.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Script de treinamento:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Execute o modelo usando o script criado na etapa anterior. É necessário especificar a flag
--base-docker-image
para usar a imagem base MaxText ou especificar a flag--docker-image
e a imagem que você quer usar.Opcional: é possível ativar o registro de depuração incluindo a flag
--enable-debug-logs
. Para mais informações, consulte Depurar o JAX no MaxText.Opcional: é possível criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI incluindo a flag
--use-vertex-tensorboard
. Para mais informações, consulte Monitorar JAX no MaxText usando a Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=$ACCELERATOR_TYPE \ --num-slices=$NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command $YOUR-MODEL-SCRIPT
Exporte as seguintes variáveis:
export ClUSTER_NAME=CLUSTER_NAME: o nome do cluster XPK. export ACCELERATOR_TYPEACCELERATOR_TYPE: a versão e o tamanho do TPU. Por exemplo,
v6e-256
. export NUM_SLICES=NUM_SLICES: o número de fatias de TPU. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: o script do modelo a ser executado como um comando de treinamento.A saída inclui um link para acompanhar sua carga de trabalho, semelhante a este:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Abra o link e clique na guia Registros para acompanhar sua carga de trabalho em tempo real.
Depurar o JAX no MaxText
Use comandos XPK complementares para diagnosticar por que o cluster ou a carga de trabalho não está em execução.
- Lista de cargas de trabalho do XPK
- Inspetor XPK
- Ative a geração de registros detalhada nos seus registros de carga de trabalho usando a flag
--enable-debug-logs
ao criar a carga de trabalho XPK.
Monitorar JAX no MaxText usando a Vertex AI
Acesse dados escalares e de perfil pelo TensorBoard gerenciado da Vertex AI.
- Aumente as solicitações de gerenciamento de recursos (CRUD) da zona que você está usando de 600 para 5.000. Isso pode não ser um problema para cargas de trabalho pequenas que usam menos de 16 VMs.
Instale dependências, como
cloud-accelerator-diagnostics
, para o Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crie o cluster XPK usando a flag
--create-vertex-tensorboard
, conforme documentado em Criar o TensorBoard da Vertex AI. Também é possível executar esse comando em clusters atuais.Crie seu experimento da Vertex AI ao executar a carga de trabalho do XPK usando a flag
--use-vertex-tensorboard
e a flag opcional--experiment-name
. Para conferir a lista completa de etapas, consulte Criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI.
Os registros incluem um link para um TensorBoard da Vertex AI, semelhante a este:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Também é possível encontrar o link do TensorBoard da Vertex AI no console do Google Cloud. Acesse Experimentos da Vertex AI no console do Google Cloud. Selecione a região adequada no menu suspenso.
O diretório do TensorBoard também é gravado no bucket do Cloud Storage especificado com ${BASE_OUTPUT_DIR}
.
Excluir cargas de trabalho do XPK
Use o comando xpk workload delete
para
excluir uma ou mais cargas de trabalho com base no prefixo ou status do job. Esse comando pode ser útil se você enviou cargas de trabalho XPK que não precisam mais ser executadas ou se tiver trabalhos presos na fila.
Excluir cluster XPK
Use o comando xpk cluster delete
para excluir um cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone $ZONE --project $PROJECT_ID
Treinamento de Llama e PyTorch/XLA na VM v6e da Cloud TPU
Este tutorial descreve como treinar modelos Llama usando PyTorch/XLA na TPU v6e usando o conjunto de dados WikiText.
Receber acesso ao Hugging Face e ao modelo Llama 3
Você precisa de um token de acesso do usuário do Hugging Face para executar este tutorial. Para informações sobre como criar e usar tokens de acesso do usuário, consulte a documentação do Hugging Face sobre tokens de acesso do usuário.
Você também precisa de permissão para acessar o modelo Llama 3 8B no Hugging Face. Para ter acesso, acesse o modelo Meta-Llama-3-8B no Hugging Face e solicite acesso.
Criar uma VM de TPU
Crie uma TPU v6e com 8 chips para executar o tutorial.
Configure as variáveis de ambiente:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT=YOUR_PROJECT export ZONE=YOUR_ZONE
Crie uma VM de TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \ --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT
Instalação
Instale o pytorch-tpu/transformers
fork dos
transformadores e dependências do Hugging Face. Este tutorial foi testado com as
seguintes versões de dependência usadas neste exemplo:
torch
: compatível com 2.5.0torch_xla[tpu]
: compatível com 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Configurar as configurações do modelo
O comando de treinamento na próxima seção, Executar o modelo, usa dois arquivos de configuração JSON para definir parâmetros de modelo e a configuração de dados paralelos totalmente fragmentados (FSDP, na sigla em inglês). O sharding do FSDP é usado para que os pesos do modelo se ajustem a um tamanho de lote maior durante o treinamento. Ao treinar com modelos menores, pode ser suficiente usar o paralelismo de dados e replicar os pesos em cada dispositivo. Para mais informações sobre como dividir tensores em dispositivos no PyTorch/XLA, consulte o Guia do usuário do SPMD do PyTorch/XLA.
Crie o arquivo de configuração do parâmetro do modelo. Confira a seguir a configuração do parâmetro do modelo para Llama3-8B. Para outros modelos, encontre a configuração no Hugging Face. Por exemplo, consulte a configuração Llama2-7B.
cat > llama-config.json <
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF Crie o arquivo de configuração do FSDP:
cat > fsdp-config.json <
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF Para mais informações sobre o FSDP, consulte FSDPv2.
Faça o upload dos arquivos de configuração para as VMs da TPU usando o seguinte comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Executar o modelo
Usando os arquivos de configuração criados na seção anterior, execute o script run_clm.py
para treinar o modelo Llama 3 8B no conjunto de dados WikiText. O script de treinamento
leva aproximadamente 10 minutos para ser executado em uma TPU v6e-8.
Faça login no Hugging Face na TPU usando o seguinte comando:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Execute o treinamento de modelo:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Solução de problemas do PyTorch/XLA
Se você definir as variáveis opcionais para depuração na seção anterior,
o perfil do modelo será armazenado no local especificado pela
variável PROFILE_LOGDIR
. É possível extrair o arquivo xplane.pb
armazenado
neste local e usar tensorboard
para conferir os perfis no
navegador usando as instruções do TensorBoard.
Se o PyTorch/XLA não estiver funcionando como esperado, consulte o guia de solução de problemas,
que tem sugestões para depurar, criar perfis e otimizar seu modelo.
Treinamento do DLRM DCN v2 na v6e
Neste tutorial, mostramos como treinar o modelo DLRM DCN v2 na TPU v6e. É necessário provisionar uma TPU v6e com 64, 128 ou 256 chips.
Se você estiver executando em vários hosts, redefina tpu-runtime
com a versão
apropriada do TensorFlow executando
o comando a seguir. Se você estiver executando em um único host, não será necessário
executar os dois comandos a seguir.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
SSH no worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Definir o nome da TPU
export TPU_NAME=${TPU_NAME}
Executar o DLRM v2
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Execute script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
As flags a seguir são necessárias para executar cargas de trabalho de recomendação (DLRM DCN):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Resultados da comparação
A seção a seguir contém resultados de comparação de mercado para o DLRM DCN v2 e MaxDiffusion na v6e.
DLRM DCN v2
O script de treinamento do DLRM DCN v2 foi executado em diferentes escalas. Confira as taxas de transferência na tabela a seguir.
v6e-64 | v6e-128 | v6e-256 | |
Etapas de treinamento | 7.000 | 7.000 | 7.000 |
Tamanho global do lote | 131072 | 262144 | 524288 |
Capacidade (exemplos/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Executamos o script de treinamento para MaxDiffusion em um v6e-4, um v6e-16 e um 2xv6e-16. Confira as taxas de transferência na tabela a seguir.
v6e-4 | v6e-16 | Duas v6e-16 | |
Etapas de treinamento | 0,069 | 0,073 | 0,13 |
Tamanho global do lote | 8 | 32 | 64 |
Capacidade (exemplos/s) | 115,9 | 438,4 | 492.3 |
Programação da coleta
O Trillium (v6e) inclui um novo recurso chamado "programação de coleta". Esse recurso oferece uma maneira de gerenciar várias frações de TPU que executam uma carga de trabalho de inferência de host único no GKE e na API Cloud TPU. Agrupar essas fatias em uma coleção facilita o ajuste do número de réplicas para atender à demanda. As atualizações de software são controladas com cuidado para garantir que uma parte das fatias na coleção esteja sempre disponível para processar o tráfego de entrada.
Consulte a documentação do GKE para mais informações sobre como usar a programação de coleta com o GKE.
O recurso de programação de coleta só se aplica à v6e.
Usar a programação de coleta da API Cloud TPU
Uma coleção de um único host na API Cloud TPU é um recurso em fila em
que uma flag especial (--workload-type = availability-optimized
) é definida para
indicar à infraestrutura subjacente que ela deve ser usada para
servir cargas de trabalho.
O comando a seguir provisiona uma coleção de host único usando a API Cloud TPU:
gcloud alpha compute tpus queued-resources create my-collection \ --project=$PROJECT_ID \ --zone=${ZONE} \ --accelerator-type $ACCELERATOR_TYPE \ --node-count ${NODE_COUNT} \ --workload-type=availability-optimized
Monitorar e criar perfil
A Cloud TPU v6e oferece suporte ao monitoramento e ao perfil usando os mesmos métodos das gerações anteriores da Cloud TPU. Para mais informações sobre o monitoramento, consulte Monitorar VMs do TPU.