Executar o código PyTorch em frações do Pod de TPU

O PyTorch/XLA exige que todas as VMs de TPU tenham acesso ao código e aos dados do modelo. É possível usar um script de inicialização para fazer o download do software necessário para distribuir os dados do modelo para todas as VMs da TPU.

Se você estiver conectando suas VMs de TPU a uma nuvem privada virtual (VPC), será preciso adicionar uma regra de firewall em seu projeto para permitir a entrada de portas 8470 a 8479. Para mais informações sobre como adicionar regras de firewall, consulte Como usar regras de firewall.

Configurar o ambiente

  1. No Cloud Shell, execute o seguinte comando para verificar se você está executando a versão atual de gcloud:

    $ gcloud components update
    

    Se você precisar instalar o gcloud, use o seguinte comando:

    $ sudo apt install -y google-cloud-sdk
  2. Crie algumas variáveis de ambiente:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Criar a VM de TPU

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configure e execute o script de treinamento

  1. Adicione o certificado SSH ao seu projeto:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Instalar o PyTorch/XLA em todos os workers de VMs da TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Clonar XLA em todos os workers da VM da TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
    
  4. Executar o script de treinamento em todos os workers

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    O treinamento leva cerca de 5 minutos. Quando ele for concluída, você verá uma mensagem semelhante a esta:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Limpar

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se do Compute Engine:

    (vm)$ exit
    
  2. Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a