Run PyTorch code on TPU Pod slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.

After you have your PyTorch code running on a single TPU VM, you can scale up your code by running it on a TPU Pod slice. TPU Pod slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running PyTorch code on TPU Pod slices.

Create a Cloud TPU Pod slice

  1. Define some environment variables to make the commands easier to use.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Environment variable descriptions

    PROJECT_ID
    Your Google Cloud project ID.
    ACCELERATOR_TYPE
    The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    ZONE
    The zone where you plan to create your Cloud TPU.
    RUNTIME_VERSION
    The Cloud TPU software version.
    TPU_NAME
    The user-assigned name for your Cloud TPU.
  2. Create your TPU Pod slice by running the following command:

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --zone=$ZONE \
    --project=$PROJECT_ID \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION
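Creation can take a few minutes for larger slices. As an optional check (assuming gcloud is authenticated against the same project), you can describe the Pod slice to confirm it exists and is ready:

```shell
# Optional check: print the state of the new Pod slice (expect READY).
gcloud compute tpus tpu-vm describe ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --format="value(state)"
```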

Install PyTorch/XLA on your Pod slice

After creating the TPU Pod slice, you must install PyTorch/XLA on all hosts in the TPU Pod slice. You can do this with the gcloud compute tpus tpu-vm ssh command, using the --worker=all and --command parameters to run the same command on every host.

  1. Install PyTorch/XLA on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Clone XLA on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone https://github.com/pytorch/xla.git"
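Before launching training, you can optionally confirm that the installation succeeded on every worker by importing the packages on each host; the version printed on each worker should match the version you installed:

```shell
# Optional check: import torch and torch_xla on every worker.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 -c 'import torch, torch_xla; print(torch.__version__)'"
```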

Run a training script on your TPU Pod slice

Run the training script on all workers:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"
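Every worker runs the identical script; parallelism comes from each process claiming a disjoint shard of the input data and synchronizing gradients over the interconnect. The sharding idea can be sketched without a TPU; shard_batches below is a hypothetical illustration, not a torch_xla API:

```python
def shard_batches(num_batches, num_workers, worker_index):
    # Round-robin sharding: worker i takes batches i, i+N, i+2N, ...
    # so the workers jointly cover every batch exactly once.
    return list(range(worker_index, num_batches, num_workers))

# With 8 batches and 4 workers, worker 1 processes batches 1 and 5.
print(shard_batches(8, 4, 1))  # → [1, 5]
```

In practice, PyTorch's DistributedSampler performs this kind of rank-based sharding automatically, so the training script itself stays identical on every worker.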

The training takes about 5 minutes. When it completes, you should see a message similar to the following:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU and Compute Engine resources.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Verify the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}