Cloud TPU PyTorch/XLA user guide

Run ML Workloads With PyTorch/XLA

Before starting the procedures in this guide, set up a TPU VM and ssh into it as described in Prepare a GCP Project. This will set up the resources needed to run the commands in this guide.

PyTorch 1.8.1 and PyTorch/XLA 1.8.1 are preinstalled on the TPU VM.

Basic setup

Set the XRT TPU device configuration:

   (vm)$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
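
If you start Python through a wrapper that does not inherit this shell variable, you can set the same configuration from within Python using os.environ before the first XLA device is created. This is a minimal sketch assuming the same local service address as the export above:

import os

# Equivalent to the shell export above; must run before the first
# XLA device is created in this process.
os.environ.setdefault("XRT_TPU_CONFIG", "localservice;0;localhost:51011")

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # for example: xla:1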

For models that have sizeable, frequent allocations, tcmalloc improves performance significantly compared to the default malloc, so tcmalloc is the default allocator on the TPU VM. However, depending on your workload (for example, DLRM, which has very large allocations for its embedding tables), tcmalloc might cause a slowdown. In that case, you can unset the following variable to fall back to the standard malloc instead:

   (vm)$ unset LD_PRELOAD
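
If you are not sure which allocator a given Python process ended up with, one Linux-only way to check is to look for libtcmalloc among the libraries mapped into the process. This is a quick sketch, not an official tool:

# Linux-only: report whether libtcmalloc is mapped into this Python process.
with open("/proc/self/maps") as maps:
    tcmalloc_loaded = any("libtcmalloc" in line for line in maps)
print("tcmalloc loaded:", tcmalloc_loaded)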

Perform a simple calculation

Start the Python 3 interpreter:

(vm)$ python3
   Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
   [GCC 8.4.0] on linux
   Type "help", "copyright", "credits" or "license" for more information.

From the Python 3 interpreter, import the following PyTorch packages:

import torch
import torch_xla.core.xla_model as xm

Perform the following calculations:

dev = xm.xla_device()
The output of this command will be similar to the following:
2021-04-01 23:20:23.268115: E   55362 tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2021-04-01 23:20:23.269345: E   55362 tensorflow/core/framework/op_kernel.cc:1693] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)

You will see the following output from the command:

tensor([[-0.2121,  1.5589, -0.6951],
        [-0.7886, -0.2022,  0.9242],
        [ 0.8555, -1.8698,  1.4333]], device='xla:1')
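
XLA tensors are evaluated lazily: operations are recorded in a graph and executed only when a result is needed (for example, when a tensor is printed or copied to the CPU) or when the graph is explicitly cut. Continuing the interpreter session above, a small illustration:

t3 = t1 + t2           # recorded in the XLA graph, not yet executed
t3_cpu = t3.cpu()      # copying to the host forces the pending graph to run
print(t3_cpu.device)   # prints: cpu

xm.mark_step()         # explicitly cuts and executes any remaining pending graph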

Running ResNet on a single-device TPU

At this point, you can run any PyTorch/XLA code you like. For instance, you can run a ResNet-50 with fake data using:

   (vm)$ git clone --recursive https://github.com/pytorch/xla.git
   (vm)$ python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

The ResNet sample trains for 1 epoch and takes about 7 minutes. It returns output similar to the following:

Epoch 1 test end 20:57:52, Accuracy=100.00 Max Accuracy: 100.00%
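
The sample script is built on the standard PyTorch/XLA multi-core pattern: spawn one process per TPU core with xmp.spawn, wrap the data loader in MpDeviceLoader, and step the optimizer with xm.optimizer_step, which all-reduces gradients across cores before updating the weights. A stripped-down sketch of that pattern, with a placeholder model and random data rather than the sample's actual code, looks like this:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    model = torch.nn.Linear(128, 10).to(device)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Placeholder random dataset standing in for the fake ImageNet data.
    dataset = torch.utils.data.TensorDataset(
        torch.randn(512, 128), torch.randint(0, 10, (512,)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=64)
    device_loader = pl.MpDeviceLoader(loader, device)

    for inputs, targets in device_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduce gradients, then step


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())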

After the ResNet training ends, delete the TPU VM.

   (vm)$ exit
   $ gcloud alpha compute tpus tpu-vm delete tpu-name \
   --zone=zone

Verify the resources have been deleted by running gcloud alpha compute tpus list. The deletion might take several minutes. A response like the one below indicates your instances have been successfully deleted.

   $ gcloud alpha compute tpus list --zone=zone
   

   Listed 0 items.
   

Advanced setup

In the above examples (the simple calculation and ResNet-50), your PyTorch/XLA program starts the local XRT server in the same process as the training script. You can also choose to start the XRT local service in a separate process:

(vm)$ python3 -m torch_xla.core.xrt_run_server --port 51011 --restart

The advantage of this approach is that the compilation cache persists across training runs. In this case, you can find the server-side logs under /tmp/xrt_server_log.

(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log

TPU VM performance profiling

For more information about profiling your models on TPU VM, see PyTorch XLA performance profiling.

Pods

PyTorch/XLA requires all TPU VMs in the Pod to be able to access the model code and data. One easy way to achieve this is to use the following startup script when creating the TPU VM Pod; it performs the code download on all TPU VMs. (A sketch of how a training script typically shards its input data across Pod workers appears after these steps.)

  1. In the Cloud shell, run the following command to make sure you are running the most current version of gcloud:

    $ sudo /opt/google-cloud-sdk/bin/gcloud components update
    
  2. Export the following environment variables needed to create the Pod TPU VM:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export RUNTIME_VERSION=v2-alpha
    
  3. Create the TPU VM

    $ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version ${RUNTIME_VERSION} --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git'
    

    As you continue these instructions, run each command that begins with (vm)$ in your VM session window.

  4. SSH to any TPU worker, for example worker 0, check whether the download started by the startup script has finished (the git clones can take a few minutes), and start the training after generating the SSH keys used to ssh between the VM workers on the Pod.

    $ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID}
    

    Update the project ssh metadata:

    (vm)$ gcloud compute config-ssh
    
    (vm)$ export TPU_NAME=tpu-name
    
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
       --tpu=${TPU_NAME} -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
       --fake_data --model=resnet50 --num_epochs=1
    

    The training takes about 3 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

    You can find the server-side logs under /tmp/xrt_server_log on each worker.

       (vm)$ ls /tmp/xrt_server_log/
    
    server_20210401-031010.log
    

    If you want to restart the XRT server (for example, if it was left in an unhealthy state), you can pass --restart-tpuvm-pod-server when running xla_dist. Note that new XRT server settings, such as environment variables like LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4, are not picked up until you restart the server.

  5. After the ResNet training ends, delete the TPU VM.

    (vm)$ exit
    
    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
    --zone=zone
    
  6. Verify the resources have been deleted by running gcloud alpha compute tpus list. The deletion might take several minutes. A response like the one below indicates your instances have been successfully deleted.

    $ gcloud alpha compute tpus list --zone=zone
    
    Listed 0 items.
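
The xla_dist launcher runs the same training script on every VM in the Pod, and the ResNet sample already shards its fake data across workers. If you adapt this flow to your own dataset, the usual pattern is a DistributedSampler keyed on each process's global ordinal, so that the shards do not overlap. The following is a minimal sketch under those assumptions; the TensorDataset is a stand-in for real data:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()

# Placeholder dataset; substitute your own map-style dataset here.
dataset = torch.utils.data.TensorDataset(
    torch.randn(1024, 128), torch.randint(0, 10, (1024,)))

# Give each process in the Pod its own slice of the data.
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),  # total number of processes in the Pod
    rank=xm.get_ordinal(),             # this process's global index
    shuffle=True)

loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)
device_loader = pl.MpDeviceLoader(loader, device)

for inputs, targets in device_loader:
    pass  # forward/backward as usual, then xm.optimizer_step(optimizer)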
    

Pods with remote coordinator

Best practice is to set up a remote coordinator and use it to start your Pod training. The remote coordinator is a standard Compute Engine VM, not a TPU VM; it issues commands to the TPU VM Pod. The advantage of using a remote coordinator is that if a TPU maintenance event happens, training might halt and you could lose access to the TPU VM Pod, but the remote coordinator can enable the training to recover automatically.

  1. Export the following environment variables needed to create the TPU VM:

    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export PROJECT_ID=project-id
    
  2. Create a TPU VM Pod slice

    You create a TPU Pod slice using the gcloud command. For example, to create a v3-32 Pod slice, use the following command:

    $ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version v2-alpha --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git'
    
  3. Export the following environment variables needed to create the remote coordinator VM:

    $ export VM_NAME=vm-name
    $ export ZONE=zone
    
  4. Create a remote coordinator VM by running:

    $ gcloud compute --project=project-id instances create vm-name \
      --zone=zone  \
      --machine-type=n1-standard-1  \
      --image-family=torch-xla \
      --image-project=ml-images  \
      --boot-disk-size=200GB \
      --scopes=https://www.googleapis.com/auth/cloud-platform
    

  5. SSH into the remote coordinator instance:

    $ gcloud compute ssh vm-name --zone=zone
    

    When the gcloud compute ssh command has finished executing, verify that your shell prompt has changed from username@projectname to username@vm-name. This change shows that you are now logged into the remote coordinator VM.

  6. Activate the torch-xla-1.8.1 environment and run your training from there.

    (vm)$ gcloud compute config-ssh
    
    (vm)$ conda activate torch-xla-1.8.1
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
       --tpu=tpu-name --restart-tpuvm-pod-server \
       --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 \
       -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
       --fake_data --model=resnet50 --num_epochs=1
    

    The training takes about 3 minutes to run and generates a message similar to:

      Epoch 1 test end 23:19:53, Accuracy=100.00
      Max Accuracy: 100.00%
      

  7. After the ResNet training ends, exit the TPU VM and delete the remote coordinator VM and the TPU VM.

    (vm)$ exit
    
    $ gcloud compute instances delete vm-name  \
      --zone=zone
    
    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone zone