
Run PyTorch code on TPU Pod slices

Set up a TPU VM Pod running PyTorch and run a calculation

PyTorch/XLA requires all TPU VMs in the Pod slice to be able to access the model code and data. A startup script downloads the software needed to distribute the model code and data to all TPU VMs.

  1. In the Cloud Shell, run the following command to make sure you are running the most current version of gcloud:

    $ gcloud components update
    

    If you need to install gcloud, use the following command:

    $ sudo apt install -y google-cloud-sdk
  2. Export the following environment variables, replacing project-id, tpu-name, and zone with your own values:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export RUNTIME_VERSION=tpu-vm-pt-1.13
    
  3. Create a TPU VM Pod slice:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version ${RUNTIME_VERSION} --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git'
    

    The startup script generates the /usr/share/pytorch/xla directory on the TPU VM and downloads the model code to this directory. The download is quick; however, file processing can take several minutes. You can verify that file processing has completed by running ls -al in /usr/share/pytorch to see if the xla directory has been created. A command for running this check on every worker at once is sketched at the end of this procedure.

  4. Once file processing completes, run the following ssh command to generate SSH keys for connecting between the VM workers on the Pod, and then start the training.

    To ssh into other TPU VMs associated with the TPU Pod, append --worker ${WORKER_NUMBER} to the ssh command, where WORKER_NUMBER is a 0-based index. For more details, see the TPU VM user guide.

    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
     --zone ${ZONE} \
     --project ${PROJECT_ID}
    

    Update the project ssh metadata:

    (vm)$ gcloud compute config-ssh
    
    (vm)$ export TPU_NAME=tpu-name
    
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
     --tpu=${TPU_NAME} -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
     --fake_data --model=resnet50 --num_epochs=1
    

    The training takes about 3 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

    You can find the server-side logs in /tmp/xrt_server_log on each worker.

    (vm)$ ls /tmp/xrt_server_log/
    
    server_20210401-031010.log
    

    If you want to restart the XRT server (for example, if it is in an unhealthy state), you can pass --restart-tpuvm-pod-server when running xla_dist. New XRT server settings, such as environment variables like LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4, do not take effect until you restart the server; an example invocation is sketched below.
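
    For example, to apply the note above, a minimal sketch of the same training command with the restart flag and an LD_PRELOAD setting (run from the TPU VM, with TPU_NAME still exported, as in the training command earlier in this step) might look like the following:

    (vm)$ python3 -m torch_xla.distributed.xla_dist \
     --tpu=${TPU_NAME} \
     --restart-tpuvm-pod-server \
     --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
     --fake_data --model=resnet50 --num_epochs=1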
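
    As mentioned in step 3, you can also check that the startup script has finished on every worker without logging in to each one. This is an optional sketch, not part of the steps above; it runs from the Cloud Shell, assumes the environment variables exported earlier are still set, and uses the --worker=all and --command flags of gcloud compute tpus tpu-vm ssh:

    $ # Optional check: list /usr/share/pytorch on every worker in the Pod slice.
    $ # If the xla directory appears in each listing, file processing is done.
    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone ${ZONE} \
      --project ${PROJECT_ID} \
      --worker=all \
      --command="ls -al /usr/share/pytorch"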

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Compute Engine instance:

    (vm)$ exit
    
  2. Delete your Cloud TPU.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone ${ZONE}
    
  3. Verify that the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes; an optional polling sketch follows the command.

    $ gcloud compute tpus tpu-vm list \
      --zone ${ZONE}
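
    If you want to wait until the deletion finishes, a minimal sketch (not part of the official steps, and assuming ${TPU_NAME} and ${ZONE} are still exported in the Cloud Shell) is to poll the list until the TPU no longer appears:

    $ # Optional check: poll every 30 seconds until the TPU is no longer listed.
    $ while gcloud compute tpus tpu-vm list --zone ${ZONE} | grep -q ${TPU_NAME}; do
        echo "Waiting for ${TPU_NAME} to be deleted..."
        sleep 30
      done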
    

Pods with a remote coordinator

We recommend using a remote coordinator to allow your Pod to auto-recover when a TPU maintenance event occurs. A remote coordinator is a standard Compute Engine VM, not a TPU VM. It issues commands to the TPU VMs in your Pod.
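
Because the remote coordinator is a standard Compute Engine VM rather than a TPU VM, it is not affected by TPU maintenance events. As an illustration only (not part of the steps below), you could check the Pod slice's state with the gcloud CLI from the coordinator or the Cloud Shell; replace tpu-name and zone with your own values:

    $ # Optional sketch: print the TPU Pod slice's current state (for example, READY).
    $ gcloud compute tpus tpu-vm describe tpu-name \
      --zone zone \
      --format="value(state)"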

  1. Export the following environment variables:

    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export PROJECT_ID=project-id
    
  2. Create a TPU VM Pod slice:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
     --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
     --version tpu-vm-pt-1.13 --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git'
    
  3. Export the following environment variables needed to create the remote coordinator (rc) VM:

    $ export RC_NAME=rc-name
    $ export ZONE=zone
    
  4. From the Cloud Shell, create a remote coordinator VM by running:

    $ gcloud compute --project=${PROJECT_ID} instances create ${RC_NAME} \
     --zone=${ZONE}  \
     --machine-type=n1-standard-1  \
     --image-family=torch-xla \
     --image-project=ml-images  \
     --boot-disk-size=200GB \
     --scopes=https://www.googleapis.com/auth/cloud-platform
    

  5. ssh into the remote coordinator instance:

    $ gcloud compute ssh ${RC_NAME} --zone=${ZONE}

    When the gcloud compute ssh command finishes executing, verify that your shell prompt has changed from username@projectname to username@vm-name. This change shows that you are now logged into the remote coordinator VM.
    
  6. Activate the torch-xla-1.13 environment and run your training from there.

    (vm)$ gcloud compute config-ssh
    
    (vm)$ conda activate torch-xla-1.13
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
     --tpu=tpu-name \
     --restart-tpuvm-pod-server \
     --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
     --fake_data \
     --model=resnet50 \
     --num_epochs=1
    

    The training takes about 3 minutes to run and generates a message similar to:

    Epoch 1 test end 23:19:53, Accuracy=100.00
    Max Accuracy: 100.00%
    

  7. After the ResNet training ends, exit the remote coordinator VM, and then delete both the remote coordinator VM and the TPU VM. An optional check for verifying the deletions is sketched after these commands.

    (vm)$ exit
    
    $ gcloud compute instances delete ${RC_NAME}  \
      --zone=${ZONE}
    
    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone ${ZONE}
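
    To verify that both resources are gone, you can list the remaining Compute Engine instances and TPUs. This optional check is a sketch that assumes the environment variables exported earlier are still set in the Cloud Shell; neither ${RC_NAME} nor ${TPU_NAME} should appear in the output:

    $ # Optional check: neither resource should be listed once deletion completes.
    $ gcloud compute instances list --zones=${ZONE}
    $ gcloud compute tpus tpu-vm list --zone ${ZONE}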