Run PyTorch code on TPU Pod slices

PyTorch/XLA requires that all TPU VMs in the Pod slice can access the model code and data. You can use a startup script to download the required software and to distribute the model data to all TPU VMs.

If you are connecting your TPU VMs to a Virtual Private Cloud (VPC), you must add a firewall rule in your project that allows ingress on ports 8470-8479. For more information about adding firewall rules, see Using firewall rules.
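
As a sketch, such a rule might look like the following. The rule name `allow-tpu-ingress` and the `default` network are illustrative; substitute the network your TPU VMs actually use.

```shell
$ gcloud compute firewall-rules create allow-tpu-ingress \
  --network=default \
  --allow=tcp:8470-8479
```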

Set up your environment

  1. In the Cloud Shell, run the following command to make sure you are running the current version of gcloud:

    $ gcloud components update
    

    If you need to install gcloud, use the following command:

    $ sudo apt install -y google-cloud-sdk
  2. Create some environment variables:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Create the TPU VM

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --accelerator-type=${ACCELERATOR_TYPE} \
  --version=${RUNTIME_VERSION}
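
To confirm the TPU VM was created before continuing, you can describe it (an optional sanity check, not part of the original steps):

```shell
$ gcloud compute tpus tpu-vm describe ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID}
```

When provisioning has finished, the output should report the TPU in a ready state.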

Configure and run the training script

  1. Add your SSH key to the ssh-agent:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Install PyTorch/XLA on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html" 
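
To check that the installation succeeded on every worker, you can print the installed torch_xla version (an optional verification step, not part of the original procedure):

```shell
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all --command="python3 -c 'import torch_xla; print(torch_xla.__version__)'"
```

Each worker should report a 2.2.x version matching the packages installed above.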
    
  3. Clone the PyTorch/XLA repository on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
    
  4. Run the training script on all workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    The training takes about 5 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
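
Because the training command tees its output to ~/logs.txt on each worker, you can pull the final accuracy line back out afterwards (an optional follow-up):

```shell
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all --command="grep 'Max Accuracy' ~/logs.txt"
```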
    

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the TPU VM, if you have not already done so:

    (vm)$ exit
    
  2. Delete your TPU VM:

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
    
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone=${ZONE}