Introdução ao Trillium (v6e)

O v6e é usado para se referir ao Trillium nesta documentação, na API TPU e nos registros. O v6e representa a 6ª geração de TPU do Google.

Com 256 chips por pod, a v6e compartilha muitas semelhanças com a v5e. Esse sistema é otimizado para ser o produto de maior valor para treinamento, ajuste fino e veiculação de transformadores, conversão de texto em imagem e redes neurais convolucionais (CNNs).

Arquitetura do sistema v6e

Para informações sobre a configuração do Cloud TPU, consulte a documentação da v6e.

Este documento se concentra no processo de configuração do treinamento do modelo usando os frameworks JAX, PyTorch ou TensorFlow. Com cada framework, é possível provisionar TPUs usando recursos em fila ou o Google Kubernetes Engine (GKE). A configuração do GKE pode ser feita usando comandos XPK ou GKE.

Preparar um projeto do Google Cloud

  1. Faça login na sua Conta do Google. Se ainda não tiver uma, crie uma nova conta.
  2. No console do Google Cloud, selecione ou crie um projeto do Cloud na página do seletor de projetos.
  3. Ative o faturamento para seu projeto do Google Cloud. O faturamento é obrigatório para todo o uso do Google Cloud.
  4. Instale os componentes da gcloud alfa.
  5. Execute o comando a seguir para instalar a versão mais recente dos componentes gcloud.

    gcloud components update
    
  6. Ative a API TPU usando o comando gcloud a seguir no Cloud Shell. Também é possível ativá-la no console do Google Cloud.

    gcloud services enable tpu.googleapis.com
    
  7. Ativar permissões com a conta de serviço do TPU para a API Compute Engine

    As contas de serviço permitem que o serviço do Cloud TPU acesse outros serviços do Google Cloud. Uma conta de serviço gerenciado pelo usuário é uma prática recomendada do Google Cloud. Siga estes guias para criar e conceder funções. Os seguintes papéis são necessários:

    • Administrador da TPU
    • Administrador de armazenamento
    • Gravador de registros
    • Gravador de métricas do Monitoring

    a. Configure as permissões XPK com sua conta de usuário para o GKE: XPK.

  8. Crie variáveis de ambiente para o ID e a zona do projeto.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Crie uma identidade de serviço para a VM da TPU.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Capacidade segura

Entre em contato com o suporte de vendas/conta do Cloud TPU para solicitar a cota da TPU e tirar dúvidas sobre a capacidade.

Provisionar o ambiente do Cloud TPU

As TPUs v6e podem ser provisionadas e gerenciadas com o GKE, com o GKE e o XPK (uma ferramenta de wrapper da CLI sobre o GKE) ou como recursos em fila.

Pré-requisitos

  • Verifique se o projeto tem cota suficiente de TPUS_PER_TPU_FAMILY, que especifica o número máximo de chips que você pode acessar no projeto do Google Cloud.
  • O v6e foi testado com a seguinte configuração:
    • Python 3.10 ou mais recente
    • Versões noturnas do software:
      • JAX noturno 0.4.32.dev20240912
      • LibTPU noturno 0.1.dev20240912+nightly
    • Versões estáveis do software:
      • JAX + JAX Lib da v0.4.35
  • Verifique se o projeto tem cota suficiente de TPU para:
    • Cota de VM de TPU
    • Quota de endereços IP
    • Cota do Hyperdisk equilibrado
  • Permissões do projeto do usuário

Variáveis de ambiente

No Cloud Shell, crie as seguintes variáveis de ambiente:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descrições de sinalizações de comando

Variável Descrição
NODE_ID O ID atribuído pelo usuário do TPU, que é criado quando a solicitação de recurso em fila é alocada.
PROJECT_ID Nome do projeto do Google Cloud. Use um projeto atual ou crie um novo em
ZONA Consulte o documento Regiões e zonas de TPU para conferir as zonas compatíveis.
ACCELERATOR_TYPE Consulte Tipos de acelerador.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esse é o endereço de e-mail da sua conta de serviço, que pode ser encontrado em Console do Google Cloud -> IAM -> Contas de serviço

Por exemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES O número de fatias a serem criadas (necessário apenas para fatias múltiplas)
QUEUED_RESOURCE_ID O ID de texto atribuído pelo usuário da solicitação de recurso em fila.
VALID_DURATION O período em que a solicitação de recurso em fila é válida.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Otimizações de desempenho de rede

Para ter o melhor desempenho,use uma rede com 8.896 MTU (unidade máxima de transmissão).

Por padrão, uma nuvem privada virtual (VPC) fornece apenas uma MTU de 1.460 bytes,o que vai gerar um desempenho de rede subótimo. É possível definir a MTU de uma rede VPC como qualquer valor entre 1.300 e 8.896 bytes (inclusivo). Os tamanhos comuns de MTU personalizados são 1.500 bytes (Ethernet padrão) ou 8.896 bytes (o máximo possível). Para mais informações, consulte Tamanhos válidos de MTU da rede VPC.

Para mais informações sobre como mudar a configuração de MTU de uma rede existente ou padrão, consulte Alterar a configuração de MTU de uma rede VPC.

O exemplo a seguir cria uma rede com 8.896 MTU

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

Como usar a multi-NIC (opção para multislice)

As variáveis de ambiente a seguir são necessárias para uma sub-rede secundária quando você está usando um ambiente de várias fatias.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Use os comandos a seguir para criar um roteamento IP personalizado para a rede e a sub-rede.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

Depois que uma fatia de várias redes for criada, você poderá validar se as duas NICs estão sendo usadas executando --command ifconfig como parte da carga de trabalho do XPK. Em seguida, analise a saída impressa dessa carga de trabalho XPK nos registros do console do Cloud e verifique se eth0 e eth1 têm mtu=8896.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Verifique se eth0 e eth1 têm mtu=8,896. Uma maneira de verificar se a multi-NIC está em execução é executar o comando --command "ifconfig" como parte da carga de trabalho do XPK. Em seguida, analise a saída impressa dessa carga de trabalho xpk nos registros do console do Cloud e verifique se eth0 e eth1 têm mtu=8896.

Melhorias nas configurações do TCP

Para TPUs criadas usando a interface de recursos em fila, execute o comando abaixo para melhorar o desempenho da rede, alterando as configurações padrão do TCP para rto_min e quickack.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

Provisionamento com recursos em fila (API Cloud TPU)

A capacidade pode ser provisionada usando o comando create de recursos em fila.

  1. Crie uma solicitação de recurso em fila da TPU.

    A flag --reserved é necessária apenas para recursos reservados, não para recursos sob demanda.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    Se a solicitação de recurso em fila for criada, o estado no campo "response" será "WAITING_FOR_RESOURCES" ou "FAILED". Se a solicitação de recurso enfileirado estiver no estado "WAITING_FOR_RESOURCES", o recurso enfileirado foi enfileirado e será provisionado quando houver capacidade suficiente de TPU. Se a solicitação de recurso em fila estiver no estado "FAILED", o motivo da falha vai aparecer na saída. A solicitação de recurso em fila vai expirar se um v6e não for provisionado dentro da duração especificada e o estado se tornar "FAILED". Consulte a documentação pública Recursos em fila para mais informações.

    Quando a solicitação de recurso enfileirado está no estado "ATIVO", é possível se conectar às VMs de TPU usando o SSH. Use os comandos list ou describe para consultar o status do recurso em fila.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Quando o recurso na fila está no estado "ATIVO", a saída é semelhante a esta:

      state:
       state: ACTIVE
    
  2. Gerenciar VMs TPU. Para opções de gerenciamento de VMs de TPU, consulte Gerenciar VMs de TPU.

  3. Conectar-se às VMs do TPU usando SSH

    É possível instalar binários em cada VM de TPU na fatia de TPU e executar o código. Consulte a seção Tipos de VM para determinar quantas VMs seu slice terá.

    Para instalar os binários ou executar o código, use o SSH para se conectar a uma VM usando o comando tpu-vm ssh.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Para usar o SSH para se conectar a uma VM específica, use a flag --worker que segue um índice baseado em 0:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Se você tiver formas de fatia maiores que 8 chips, terá várias VMs em uma fatia. Nesse caso, use os parâmetros --worker=all e --command no comando gcloud alpha compute tpus tpu-vm ssh para executar um comando em todas as VMs simultaneamente. Exemplo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Excluir um recurso na fila

    Exclua um recurso na fila no fim da sessão ou remova solicitações de recursos na fila que estão no estado "FAILED". Para excluir um recurso na fila, exclua a fatia e, em seguida, a solicitação de recurso na fila em duas etapas:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

Como usar o GKE com o v6e

Se você estiver usando comandos do GKE com o v6e, poderá usar comandos do Kubernetes ou o XPK para provisionar TPUs e treinar ou oferecer modelos. Consulte Planejar TPUs no GKE para saber como usar o GKE com TPUs e v6e.

Configuração do framework

Esta seção descreve o processo de configuração geral para treinamento de modelo de ML usando os frameworks JAX, PyTorch ou TensorFlow. É possível provisionar TPUs usando recursos em fila ou o GKE. A configuração do GKE pode ser feita usando comandos do XPK ou do Kubernetes.

Configurar o JAX usando recursos na fila

Instale o JAX em todas as VMs de TPU no seu slice ou em vários slices simultaneamente usando gcloud alpha compute tpus tpu-vm ssh. Para "Várias fatias", adicione --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Você pode executar o código Python abaixo para verificar quantos núcleos de TPU estão disponíveis na sua fatia e testar se tudo está instalado corretamente. As saídas mostradas aqui foram produzidas com uma fatia v6e-16:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

O resultado será assim:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() mostra o número total de chips na fatia especificada. jax.local_device_count() indica a contagem de chips acessíveis por uma única VM nessa fatia.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

Solução de problemas de configurações do JAX

Uma dica geral é ativar o registro detalhado no manifesto de carga de trabalho do GKE. Em seguida, envie os registros para o suporte do GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Mensagens de erro

no endpoints available for service 'jobset-webhook-service'

Esse erro significa que o jobset não foi instalado corretamente. Verifique se os pods do Kubernetes da implantação jobset-controller-manager estão em execução. Para mais informações, consulte a documentação de solução de problemas do JobSet.

TPU initialization failed: Failed to connect

Verifique se a versão do nó do GKE é 1.30.4-gke.1348000 ou mais recente. Não há suporte para o GKE 1.31.

Configuração para PyTorch

Esta seção descreve como começar a usar o PJRT na v6e com o PyTorch/XLA. O Python 3.10 é a versão recomendada.

Configurar o PyTorch usando o GKE com o XPK

Você pode usar o contêiner do Docker abaixo com o XPK, que já tem as dependências do PyTorch instaladas:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Para criar uma carga de trabalho XPK, use o seguinte comando:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

O uso de --base-docker-image cria uma nova imagem do Docker com o diretório de trabalho atual integrado ao novo Docker.

Configurar o PyTorch usando recursos em fila

Siga estas etapas para instalar o PyTorch usando recursos em fila e executar um pequeno script na v6e.

Instale dependências usando SSH para acessar as VMs.

Para o recurso Multislice, adicione --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Melhorar o desempenho de modelos com alocações frequentes e de tamanho considerável

Para modelos com alocações frequentes e grandes, observamos que o uso de tcmalloc melhora o desempenho significativamente em comparação com a implementação padrão de malloc. Portanto, o malloc padrão usado na VM de TPU é tcmalloc. No entanto, dependendo da carga de trabalho (por exemplo, com o DLRM, que tem alocações muito grandes para as tabelas de incorporação), tcmalloc pode causar uma lentidão. Nesse caso, tente redefinir a seguinte variável usando o malloc padrão:

unset LD_PRELOAD

Use um script Python para fazer um cálculo na VM v6e:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Isso gera um resultado semelhante ao seguinte:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configuração para o TensorFlow

Para a prévia pública da v6e, somente a versão do ambiente de execução tf-nightly é compatível.

É possível redefinir tpu-runtime com a versão do TensorFlow compatível com a v6e executando os seguintes comandos:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Use o SSH para acessar o worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Instale o TensorFlow no worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exporte a variável de ambiente TPU_NAME:

export TPU_NAME=v6e-16

Você pode executar o script Python abaixo para verificar quantos núcleos de TPU estão disponíveis na sua fatia e testar se tudo está instalado corretamente. As saídas mostradas foram produzidas com uma fatia v6e-16:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

O resultado será assim:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e com SkyPilot

É possível usar a TPU v6e com o SkyPilot. Siga as etapas abaixo para adicionar informações de local/preço relacionadas à v6e ao SkyPilot.

  1. Adicione o seguinte ao final de ~/.sky/catalogs/v5/gcp/vms.csv :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Especifique os seguintes recursos em um arquivo YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Inicie um cluster com a TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Conecte-se ao TPU v6e usando SSH: ssh tpu_v6

Tutoriais de inferência

As seções a seguir fornecem tutoriais para veicular modelos MaxText e PyTorch usando o JetStream, além de veicular modelos MaxDiffusion na TPU v6e.

MaxText no JetStream

Este tutorial mostra como usar o JetStream para disponibilizar modelos MaxText (JAX) na TPU v6e. O JetStream é um mecanismo otimizado para capacidade de processamento e memória para inferência de modelos de linguagem grandes (LLMs) em dispositivos XLA (TPUs). Neste tutorial, você vai executar o benchmark de inferência para o modelo Llama2-7B.

Antes de começar

  1. Crie uma TPU v6e com quatro chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Executar o tutorial

Para configurar o JetStream e o MaxText, converter os pontos de verificação do modelo e executar o benchmark de inferência, siga as instruções no repositório do GitHub.

Limpar

Exclua a TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM no TPU do PyTorch

Confira abaixo um tutorial simples mostrando como começar a usar o vLLM na VM do TPU. Para o exemplo de práticas recomendadas de implantação do vLLM no Trillium na produção, vamos publicar um guia do usuário do GKE nos próximos dias. Fique ligado!

Antes de começar

  1. Crie uma TPU v6e com quatro chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    Descrições de sinalizações de comando

    Variável Descrição
    NODE_ID O ID atribuído pelo usuário da TPU, que é criado quando a solicitação de recurso em fila é alocada.
    PROJECT_ID Nome do projeto do Google Cloud. Use um projeto atual ou crie um novo em
    ZONA Consulte o documento Regiões e zonas de TPU para conferir as zonas compatíveis.
    ACCELERATOR_TYPE Consulte Tipos de acelerador.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Esse é o endereço de e-mail da sua conta de serviço, que pode ser encontrado em Console do Google Cloud -> IAM -> Contas de serviço

    Por exemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

Configurar o vLLM na TPU

  1. Clone o repositório vLLM e navegue até o diretório vLLM:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Limpe os pacotes torch e torch-xla:

    pip uninstall torch torch-xla -y
    
  3. Instale o PyTorch e o PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Instale o JAX e o Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Instale outras dependências do build:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Receber acesso ao modelo

É necessário assinar o contrato de consentimento para usar a família de modelos Llama3 no repositório do Hugging Face.

Gere um novo token do Huggin' Face, caso ainda não tenha um:

  1. Clique em Seu perfil > Configurações > Tokens de acesso.
  2. Selecione Novo token.
  3. Especifique um Nome de sua escolha e um Papel de pelo menos Read.
  4. Selecione Gerar um token.
  5. Copie o token gerado para a área de transferência, defina-o como uma variável de ambiente e faça a autenticação com o huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Fazer o download dos dados de comparativo de mercado

  1. Crie um diretório /data e faça o download do conjunto de dados ShareGPT do Hugging Face.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

Iniciar o servidor vLLM

O comando a seguir faz o download dos pesos do modelo do Hugging Face Model Hub para o diretório /tmp da VM do TPU, pré-compila uma variedade de formas de entrada e grava a compilação do modelo em ~/.cache/vllm/xla_cache.

Para mais detalhes, consulte os documentos de vLLM.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

Executar comparativos do vLLM

Execute o script de comparação de vLLMs:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Limpar

Exclua a TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch no JetStream

Este tutorial mostra como usar o JetStream para disponibilizar modelos do PyTorch no TPU v6e. O JetStream é um mecanismo otimizado para capacidade de processamento e memória para inferência de modelos de linguagem grandes (LLMs) em dispositivos XLA (TPUs). Neste tutorial, você vai executar o benchmark de inferência para o modelo Llama2-7B.

Antes de começar

  1. Crie uma TPU v6e com quatro chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Executar o tutorial

Para configurar o JetStream-PyTorch, converter os checkpoints do modelo e executar o benchmark de inferência, siga as instruções no repositório do GitHub.

Limpar

Exclua a TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

Inferência do MaxDiffusion

Este tutorial mostra como disponibilizar modelos MaxDiffusion na TPU v6e. Neste tutorial, você vai gerar imagens usando o modelo Stable Diffusion XL.

Antes de começar

  1. Crie uma TPU v6e com quatro chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Criar um ambiente da Conda

  1. Crie um diretório para o Miniconda:

    mkdir -p ~/miniconda3
  2. Faça o download do script de instalação do Miniconda:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Instale o Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Remova o script de instalação do Miniconda:

    rm -rf ~/miniconda3/miniconda.sh
  5. Adicione o Miniconda à variável PATH:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Atualize ~/.bashrc para aplicar as mudanças à variável PATH:

    source ~/.bashrc
  7. Crie um novo ambiente do Conda:

    conda create -n tpu python=3.10
  8. Ative o ambiente Conda:

    source activate tpu

Configurar o MaxDiffusion

  1. Clone o repositório MaxDiffusion e navegue até o diretório MaxDiffusion:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Alterne para a ramificação mlperf-4.1:

    git checkout mlperf4.1
  3. Instale o MaxDiffusion:

    pip install -e .
  4. Instale as dependências:

    pip install -r requirements.txt
  5. Instale o JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Gerar imagens

  1. Defina variáveis de ambiente para configurar o ambiente de execução da TPU:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Gerar imagens usando o comando e as configurações definidas em src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

Limpar

Exclua a TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Tutoriais de treinamento

As seções a seguir fornecem tutoriais para treinar o MaxText.

Modelos MaxDiffusion e PyTorch na TPU v6e.

MaxText e MaxDiffusion

As seções a seguir abrangem o ciclo de vida de treinamento dos modelos MaxText e MaxDiffusion.

Em geral, as etapas de alto nível são:

  1. Crie a imagem de base da carga de trabalho.
  2. Execute a carga de trabalho usando o XPK.
    1. Crie o comando de treinamento para a carga de trabalho.
    2. Implante a carga de trabalho.
  3. Acompanhe a carga de trabalho e confira as métricas.
  4. Exclua a carga de trabalho do XPK se ela não for necessária.
  5. Exclua o cluster de XPK quando ele não for mais necessário.

Criar a imagem de base

Instale o MaxText ou o MaxDiffusion e crie a imagem do Docker:

  1. Clone o repositório que você quer usar e mude para o diretório do repositório:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configure o Docker para usar a Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crie a imagem do Docker usando o comando a seguir ou o JAX Stable Stack. Para mais informações sobre a pilha estável do JAX, consulte Criar uma imagem do Docker com a pilha estável do JAX.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Se você estiver iniciando a carga de trabalho em uma máquina que não tem a imagem criada localmente, faça o upload da imagem:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Criar uma imagem do Docker com a pilha estável do JAX

É possível criar as imagens do Docker MaxText e MaxDiffusion usando a imagem base da pilha Stable do JAX.

A pilha estável do JAX oferece um ambiente consistente para MaxText e MaxDiffusion agrupando o JAX com pacotes principais, como orbax, flax e optax, com um libtpu.so bem qualificado que direciona utilitários de programa de TPU e outras ferramentas essenciais. Essas bibliotecas são testadas para garantir compatibilidade, fornecendo uma base estável para criar e executar o MaxText e o MaxDiffusion e eliminando possíveis conflitos devido a versões de pacote incompatíveis.

O JAX Stable Stack inclui um libtpu.so totalmente lançado e qualificado, a biblioteca principal que orienta a compilação, a execução e a configuração de rede ICI do programa TPU. A versão do libtpu substitui o build noturno usado anteriormente pelo JAX e garante a funcionalidade consistente das computações XLA na TPU com testes de qualificação no nível do PJRT em HLO/StableHLO IRs.

Para criar a imagem do Docker MaxText e MaxDiffusion com a pilha estável do JAX, ao executar o script docker_build_dependency_image.sh, defina a variável MODE como stable_stack e a variável BASEIMAGE como a imagem de base que você quer usar.

O exemplo a seguir especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 como a imagem base:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Para conferir uma lista de imagens de base do JAX Stable Stack disponíveis, consulte Imagens do JAX Stable Stack no Artifact Registry.

Executar a carga de trabalho usando o XPK

  1. Defina as seguintes variáveis de ambiente se você não estiver usando os valores padrão definidos por MaxText ou MaxDiffusion:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Crie o script do modelo para ser copiado como um comando de treinamento na próxima etapa. Não execute o script do modelo ainda.

    MaxText

    O MaxText é um LLM de código aberto de alto desempenho e altamente escalonável escrito em Python puro e JAX, que tem como alvo TPUs e GPUs do Google Cloud para treinamento e inferência.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    O Gemma é uma família de modelos de linguagem grandes (LLMs) de pesos abertos desenvolvidos pelo Google DeepMind com base na pesquisa e tecnologia do Gemini.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    O Mixtral é um modelo de IA de última geração desenvolvido pela Mistral AI, que utiliza uma arquitetura de mistura de especialistas esparsos (MoE, na sigla em inglês).

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama é uma família de modelos de linguagem grandes (LLMs) com pesos abertos desenvolvidos pela Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    O MaxDiffusion é uma coleção de implementações de referência de vários modelos de difusão latente escritos em Python puro e JAX que são executados em dispositivos XLA, incluindo TPUs do Cloud e GPUs. O Stable Diffusion é um modelo latente de texto para imagem que gera imagens fotorrealistas a partir de qualquer entrada de texto.

    Você precisa instalar uma ramificação específica para executar o MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Script de treinamento:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. Execute o modelo usando o script criado na etapa anterior. É necessário especificar a flag --base-docker-image para usar a imagem de base MaxText ou especificar a flag --docker-image e a imagem que você quer usar.

    Opcional: é possível ativar o registro de depuração incluindo a flag --enable-debug-logs. Para mais informações, consulte Depurar JAX no MaxText.

    Opcional: é possível criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI incluindo a flag --use-vertex-tensorboard. Para mais informações, consulte Monitorar JAX no MaxText usando a Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    Substitua as seguintes variáveis:

    • CLUSTER_NAME: o nome do cluster XPK.
    • ACCELERATOR_TYPE: a versão e o tamanho da TPU. Por exemplo, v6e-256.
    • NUM_SLICES: o número de fatias de TPU.
    • YOUR_MODEL_SCRIPT: o script do modelo a ser executado como um comando de treinamento.

    A saída inclui um link para acompanhar sua carga de trabalho, semelhante a este:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Abra o link e clique na guia Logs para acompanhar sua carga de trabalho em tempo real.

Depurar o JAX no MaxText

Use comandos XPK complementares para diagnosticar por que o cluster ou a carga de trabalho não está em execução:

Monitorar JAX no MaxText usando a Vertex AI

Acesse dados escalares e de perfil pelo TensorBoard gerenciado da Vertex AI.

  1. Aumente as solicitações de gerenciamento de recursos (CRUD) da zona que você está usando de 600 para 5.000. Isso pode não ser um problema para cargas de trabalho pequenas que usam menos de 16 VMs.
  2. Instale dependências, como cloud-accelerator-diagnostics, para o Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crie o cluster XPK usando a flag --create-vertex-tensorboard, conforme documentado em Criar o TensorBoard da Vertex AI. Também é possível executar esse comando em clusters atuais.

  4. Crie seu experimento da Vertex AI ao executar a carga de trabalho do XPK usando a flag --use-vertex-tensorboard e a flag opcional --experiment-name. Para conferir a lista completa de etapas, consulte Criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI.

Os registros incluem um link para um TensorBoard da Vertex AI, semelhante a este:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Também é possível encontrar o link do TensorBoard da Vertex AI no console do Google Cloud. Acesse Experimentos da Vertex AI no console do Google Cloud. Selecione a região adequada no menu suspenso.

O diretório do TensorBoard também é gravado no bucket do Cloud Storage especificado com ${BASE_OUTPUT_DIR}.

Excluir cargas de trabalho do XPK

Use o comando xpk workload delete para excluir uma ou mais cargas de trabalho com base no prefixo ou status do job. Esse comando pode ser útil se você enviou cargas de trabalho XPK que não precisam mais ser executadas ou se você tem jobs presos na fila.

Excluir cluster XPK

Use o comando xpk cluster delete para excluir um cluster:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama e PyTorch

Este tutorial descreve como treinar modelos Llama usando PyTorch/XLA na TPU v6e usando o conjunto de dados WikiText. Além disso, os usuários podem acessar as recripes do modelo de TPU do PyTorch como imagens do Docker aqui.

Instalação

Instale o pytorch-tpu/transformers fork dos transformadores do Hugging Face e as dependências em um ambiente virtual:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Configurar as configurações do modelo

O comando de treinamento na próxima seção, Criar o script do modelo, usa dois arquivos de configuração JSON para definir parâmetros de modelo e a configuração de dados paralelos totalmente fragmentados (FSDP, na sigla em inglês). O sharding de FSDP é usado para que os pesos do modelo se ajustem a um tamanho de lote maior durante o treinamento. Ao treinar com modelos menores, pode ser suficiente usar apenas o paralelismo de dados e replicar os pesos em cada dispositivo. Consulte o Guia do usuário do PyTorch/XLA SPMD para mais detalhes sobre como dividir tensores em dispositivos no PyTorch/XLA.

  1. Crie o arquivo de configuração do parâmetro do modelo. Confira a seguir a configuração do parâmetro do modelo para Llama3-8B. Para outros modelos, encontre a configuração no Hugging Face. Por exemplo, consulte a configuração Llama2-7B.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Crie o arquivo de configuração do FSDP:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Consulte FSDPv2 para mais detalhes sobre o FSDP.

  3. Faça o upload dos arquivos de configuração para as VMs da TPU usando o seguinte comando:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    Também é possível criar os arquivos de configuração no diretório de trabalho atual e usar a flag --base-docker-image no XPK.

Criar o script do modelo

Crie o script do modelo, especificando o arquivo de configuração do parâmetro do modelo usando a flag --config_name e o arquivo de configuração do FSDP usando a flag --fsdp_config. Você vai executar esse script na TPU na próxima seção, Executar o modelo. Não execute o script do modelo ainda.

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Executar o modelo

Execute o modelo usando o script criado na etapa anterior, Crie o script do modelo.

Se você estiver usando uma VM de TPU de host único (como v6e-4), execute o comando de treinamento diretamente na VM de TPU. Se você estiver usando uma VM de TPU com vários hosts, use o comando abaixo para executar o script simultaneamente em todos os hosts:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

Solução de problemas do PyTorch/XLA

Se você definir as variáveis opcionais para depuração na seção anterior, o perfil do modelo será armazenado no local especificado pela variável PROFILE_LOGDIR. É possível extrair o arquivo xplane.pb armazenado neste local e usar tensorboard para conferir os perfis no navegador usando as instruções do TensorBoard. Se o PyTorch/XLA não estiver funcionando como esperado, consulte o guia de solução de problemas, que tem sugestões para depuração, criação de perfil e otimização de modelos.

Tutorial do DLRM DCN v2

Neste tutorial, mostramos como treinar o modelo DLRM DCN v2 na TPU v6e.

Se você estiver executando em vários hosts, redefina tpu-runtime com a versão apropriada do TensorFlow executando o comando a seguir. Se você estiver executando em um único host, não será necessário executar os dois comandos a seguir.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH no worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Definir o nome da TPU

export TPU_NAME=${TPU_NAME}

Executar o DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Execute script.sh:

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

As flags a seguir são necessárias para executar cargas de trabalho de recomendação (DLRM DCN):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Resultados da comparação

A seção a seguir contém resultados de comparação de mercado para o DLRM DCN v2 e MaxDiffusion na v6e.

DLRM DCN v2

O script de treinamento do DLRM DCN v2 foi executado em diferentes escalas. Confira as taxas de transferência na tabela a seguir.

v6e-64 v6e-128 v6e-256
Etapas de treinamento 7.000 7.000 7.000
Tamanho global do lote 131072 262144 524288
Capacidade (exemplos/s) 2975334 5111808 10066329

MaxDiffusion

Executamos o script de treinamento para MaxDiffusion em um v6e-4, um v6e-16 e um 2xv6e-16. Confira as taxas de transferência na tabela a seguir.

v6e-4 v6e-16 Duas v6e-16
Etapas de treinamento 0,069 0,073 0,13
Tamanho global do lote 8 32 64
Capacidade (exemplos/s) 115,9 438,4 492.3

Coleções

A v6e apresenta um novo recurso chamado "coleções" para beneficiar os usuários que executam cargas de trabalho de exibição. O recurso de coleções só se aplica à v6e.

Com as coleções, você pode indicar ao Google Cloud quais dos seus nós de TPU fazem parte de uma carga de trabalho de veiculação. Isso permite que a infraestrutura do Google Cloud limite e simplifique as interrupções que podem ser aplicadas a cargas de trabalho de treinamento no curso normal de operações.

Usar coleções da API Cloud TPU

Uma coleção de um único host na API Cloud TPU é um recurso em fila em que uma flag especial (--workload-type = availability-optimized) é definida para indicar à infraestrutura subjacente que ela deve ser usada para servir cargas de trabalho.

O comando a seguir provisiona uma coleção de host único usando a API Cloud TPU:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

Monitorar e criar perfil

A Cloud TPU v6e oferece suporte ao monitoramento e ao perfil usando os mesmos métodos das gerações anteriores da Cloud TPU. Para mais informações sobre o monitoramento, consulte Monitorar VMs do TPU.