Ejecuta el código de PyTorch en porciones de pod de TPU

PyTorch/XLA requiere que todas las VM de la TPU puedan acceder al código y a los datos del modelo. Puedes usar una secuencia de comandos de inicio para descargar el software necesario a fin de distribuir los datos del modelo a todas las VM de TPU.

Si conectas tus VM de TPU a una nube privada virtual (VPC), debes agregar una regla de firewall en tu proyecto para permitir la entrada de los puertos 8470 a 8479. Para obtener más información sobre cómo agregar reglas de firewall, consulta Usa reglas de firewall.

Configura tu entorno

  1. En Cloud Shell, ejecuta el siguiente comando para asegurarte de estar ejecutando la versión actual de gcloud:

    $ gcloud components update
    

    Si necesitas instalar gcloud, usa el siguiente comando:

    $ sudo apt install -y google-cloud-sdk
  2. Crea algunas variables de entorno:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Crea la VM de TPU

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configura y ejecuta la secuencia de comandos de entrenamiento

  1. Agrega el certificado SSH al proyecto:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Instala PyTorch/XLA en todos los trabajadores de VM de TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Clona XLA en todos los trabajadores de VM de TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
    
  4. Ejecuta la secuencia de comandos de entrenamiento en todos los trabajadores

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    El entrenamiento tarda alrededor de 5 minutos. Cuando se complete, deberías ver un mensaje similar al siguiente:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Limpia

Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.

  1. Desconéctate de Compute Engine:

    (vm)$ exit
    
  2. Ejecuta el siguiente comando para verificar que se hayan borrado los recursos. Asegúrate de que tu TPU ya no aparezca en la lista. La eliminación puede tardar varios minutos.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a