Cloud TPU PyTorch/XLA user guide

Run ML Workloads With PyTorch/XLA

Before starting the procedures in this guide, set up a TPU VM and connect to it with SSH as described in Prepare a GCP Project. That procedure creates the resources needed to run the commands in this guide.

PyTorch 1.8.1 and PyTorch/XLA 1.8.1 are preinstalled on the TPU VM.

Basic setup

Set the XRT TPU device configuration:

   (vm)$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"

For models with sizeable, frequent allocations, tcmalloc performs significantly better than the standard C/C++ runtime malloc, and it is the default malloc on TPU VMs. To force the TPU VM runtime to use the standard malloc instead, unset the LD_PRELOAD environment variable:

   (vm)$ unset LD_PRELOAD
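
If you later want to switch back to tcmalloc, point LD_PRELOAD at the tcmalloc shared library again. The path below is the one used later in this guide for the xla_dist examples; it can differ on other images:

   (vm)$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4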

Changing PyTorch version

If you don't want to use the PyTorch version preinstalled on TPU VMs, install the version you want. For example, to switch to PyTorch 1.9.0 with the matching torchvision and PyTorch/XLA wheels:

(vm)$ sudo bash /var/scripts/docker-login.sh
(vm)$ sudo docker rm libtpu || true
(vm)$ sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:pytorch-1.9 "/bin/bash"
(vm)$ sudo docker cp libtpu:libtpu.so /lib
(vm)$ sudo pip3 uninstall --yes torch torch_xla torchvision
(vm)$ sudo pip3 install torch==1.9.0
(vm)$ sudo pip3 install torchvision==0.10.0
(vm)$ sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.9-cp38-cp38-linux_x86_64.whl
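
To confirm that the swap took effect, you can print the installed versions from Python. This is just a quick sanity check, not part of the official procedure:

(vm)$ python3 -c "import torch, torch_xla; print(torch.__version__, torch_xla.__version__)"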

Perform a simple calculation

  1. Start the Python interpreter on the TPU VM:

    (vm)$ python3
    
  2. Import the following PyTorch and PyTorch/XLA packages:

    import torch
    import torch_xla.core.xla_model as xm
    
  3. Enter the following script:

    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    

     Output similar to the following is displayed (the values are random):

     tensor([[-0.2121,  1.5589, -0.6951],
             [-0.7886, -0.2022,  0.9242],
             [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    
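     As a slightly longer sketch, still in the same Python session, you can build a tensor on the CPU, move it to the XLA device, compute on it, and copy the result back to the host. This uses only the xm.xla_device() call shown above plus standard PyTorch tensor methods:

    # build a tensor on the host CPU
    cpu_t = torch.randn(3, 3)
    # copy it to the XLA (TPU) device obtained above
    xla_t = cpu_t.to(dev)
    # compute on the device, then bring the result back to the CPU
    result = (xla_t + xla_t).cpu()
    print(result)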

Running ResNet on a single-device TPU

At this point, you can run any PyTorch/XLA code you like. For example, you can train a ResNet-50 model on fake data:

(vm)$ git clone --recursive https://github.com/pytorch/xla.git
(vm)$ python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

The ResNet sample trains for 1 epoch and takes about 7 minutes. It returns output similar to the following:

Epoch 1 test end 20:57:52, Accuracy=100.00 Max Accuracy: 100.00%
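
test_train_mp_imagenet.py wraps a complete training loop for you. For orientation, a minimal single-device training step in PyTorch/XLA looks roughly like the sketch below; the model, batch, and learning rate are placeholders, not the script's actual configuration:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm

    dev = xm.xla_device()
    model = nn.Linear(10, 2).to(dev)                 # placeholder model
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    data = torch.randn(8, 10, device=dev)            # fake batch
    target = torch.randint(0, 2, (8,), device=dev)   # fake labels

    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    # on a single device, xm.optimizer_step(optimizer, barrier=True) applies
    # the update and executes the pending XLA graph
    xm.optimizer_step(optimizer, barrier=True)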

After the ResNet training ends, delete the TPU VM.

(vm)$ exit
$ gcloud alpha compute tpus tpu-vm delete tpu-name \
--zone=zone

The deletion might take several minutes. Verify the resources have been deleted by running gcloud alpha compute tpus list --zone=${ZONE}.

Advanced setup

In the above examples (the simple calculation and ResNet50), your PyTorch/XLA program will start the local XRT server in the same process as the Python interpreter. You can also choose to start the XRT local service in a separate process:

(vm)$ python3 -m torch_xla.core.xrt_run_server --port 51011 --restart

The advantage of this approach is that the compilation cache persists across training runs. When the XRT server runs in a separate process, its server-side logs are written to /tmp/xrt_server_log.

(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log

TPU VM performance profiling

For more information about profiling your models on TPU VM, see PyTorch XLA performance profiling.

Pods

PyTorch/XLA requires all TPU VMs in the Pod to be able to access the model code and data. One way to achieve this is to use the following startup script when you create the TPU VM Pod; it runs on every TPU VM and downloads the code and data there.

  1. In Cloud Shell, run the following command to make sure you are running the most current version of gcloud:

    $ sudo /opt/google-cloud-sdk/bin/gcloud components update
    
  2. Export the following environment variables:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export RUNTIME_VERSION=v2-alpha
    
  3. Create the TPU VM:

    $ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version ${RUNTIME_VERSION} --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git
    EOF'
    

    As you continue these instructions, run each command that begins with (vm)$ in your TPU VM shell.

  4. ssh to any TPU worker, for example worker 0, and check whether the data and model downloads have finished. The download itself is usually quick; however, processing the files can take several minutes. You can verify that file processing has completed by running ls -al in the download directory. Files are downloaded in the order specified in the startup script. For example:

    download file1
    download file2
    download file3
    

    If file3 was the last file in the startup script, and it is displayed by the ls command, file processing has completed.

    Once file processing completes, generate ssh-keys to ssh between VM workers on the Pod. Then, start the training.

    (vm)$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
     --zone ${ZONE} \
     --project ${PROJECT_ID}
    

    Update the project ssh metadata:

    (vm)$ gcloud compute config-ssh
    
    (vm)$ export TPU_NAME=tpu-name
    
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
     --tpu=${TPU_NAME} -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
     --fake_data --model=resnet50 --num_epochs=1
    

    The training takes about 3 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

    You can find the server-side logs under /tmp/xrt_server_log on each worker.

    (vm)$ ls /tmp/xrt_server_log/
    
    server_20210401-031010.log
    

    If you need to restart the XRT server (for example, if it is in an unhealthy state), pass --restart-tpuvm-pod-server when running xla_dist. Note that new XRT server settings, such as environment variables like LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4, do not take effect until you restart the server. A sketch of how a multi-core training script such as test_train_mp_imagenet.py is structured appears after this list.

  5. After the ResNet training ends, delete the TPU VM.

    (vm)$ exit
    
    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
     --zone=zone
    
  6. Verify the resources have been deleted by running gcloud alpha compute tpus list --zone=${ZONE}. The deletion might take several minutes.
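
The test_train_mp_imagenet.py script that xla_dist launches runs multi-core training on each VM in the Pod, one process per TPU core. If you adapt your own script to run the same way, it typically follows the torch_xla multiprocessing pattern sketched below; the _mp_fn body and the model are placeholders for illustration, not the test script's actual code:

    import torch
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        # each spawned process drives one TPU core through its own XLA device
        dev = xm.xla_device()
        model = nn.Linear(10, 2).to(dev)   # placeholder model
        # ... build the data loader and run the training loop here, calling
        # xm.optimizer_step(optimizer) so gradients are all-reduced across
        # cores before each update ...
        xm.master_print('process {} is using {}'.format(index, dev))

    if __name__ == '__main__':
        # nprocs=8 matches the 8 TPU cores on each v3 host; xla_dist runs
        # this same entry point on every VM in the Pod
        xmp.spawn(_mp_fn, args=(), nprocs=8)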

Pods with remote coordinator

We recommend using a remote coordinator to allow your Pod to auto-recover when a TPU maintenance event occurs. A remote coordinator is a standard Compute Engine VM, not a TPU VM. It issues commands to the TPU VMs in your Pod.

  1. Export the following environment variables:

    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export PROJECT_ID=project-id
    
  2. Create a TPU VM Pod slice:

    $ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
     --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
     --version v2-alpha --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git
    EOF'
    
  3. Export the following environment variables needed to create the remote coordinator VM:

    $ export VM_NAME=vm-name
    $ export ZONE=zone
    
  4. Create a remote coordinator VM by running:

    $ gcloud compute --project=${PROJECT_ID} instances create ${VM_NAME} \
     --zone=${ZONE}  \
     --machine-type=n1-standard-1  \
     --image-family=torch-xla \
     --image-project=ml-images  \
     --boot-disk-size=200GB \
     --scopes=https://www.googleapis.com/auth/cloud-platform
    

  5. ssh into the remote coordinator instance:

    $ gcloud compute ssh vm-name --zone=zone

    When the gcloud compute ssh command has finished executing, verify that your shell prompt has changed from username@projectname to username@vm-name. This change shows that you are now logged into the remote coordinator VM.

  6. Activate the torch-xla-1.8.1 environment and run your training from there.

    (vm)$ gcloud compute config-ssh
    
    (vm)$ conda activate torch-xla-1.8.1
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
     --tpu=tpu-name \
     --restart-tpuvm-pod \
     --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
     --fake_data \
     --model=resnet50 \
     --num_epochs=1
    

    The training takes about 3 minutes to run and generates a message similar to:

    Epoch 1 test end 23:19:53, Accuracy=100.00
    Max Accuracy: 100.00%
    

  7. After the ResNet training ends, exit the remote coordinator VM, then delete both the remote coordinator VM and the TPU VM.

    (vm)$ exit
    
    $ gcloud compute instances delete vm-name  \
      --zone=zone
    
    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone zone
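
    As in the earlier sections, the deletions can take several minutes. You can verify that both the TPU VM and the remote coordinator VM are gone by listing the remaining resources:

    $ gcloud alpha compute tpus list --zone=zone
    $ gcloud compute instances list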