Training PyTorch models on Cloud TPU Pods

This tutorial shows how to scale up training your model from a single Cloud TPU (v2-8 or v3-8) to a Cloud TPU Pod. The TPU accelerators in a TPU Pod are connected by very high bandwidth interconnects, making them well suited to scaling up training jobs. For more information about Cloud TPU Pod offerings, refer to the Cloud TPU product page or to this Cloud TPU presentation.

The following diagram provides an overview of the distributed cluster setup and shows how training happens. In the single-device setup, a single VM (client worker, or CW) feeds a single TPU worker (service worker, or SW). In distributed training, a cluster of VMs (CWs) feeds a corresponding TPU Pod slice (a cluster of SWs), with each CW feeding a single SW. The input pipeline runs on the CWs, and all of the model training happens on the SWs.



Objectives

  • Set up a Compute Engine Instance Group and Cloud TPU Pod for training with PyTorch/XLA
  • Run PyTorch/XLA training on a Cloud TPU Pod

Before you begin

Before starting distributed training on Cloud TPU Pods, verify that your model trains on a single v2-8 or v3-8 Cloud TPU device. If your model has significant performance problems on a single device, refer to the best practices and troubleshooting guides.

Once your single TPU device is training successfully, perform the following steps to set up and train on a Cloud TPU Pod:

  1. [Optional] Capture a VM disk image into a VM image.

  2. Create an Instance template from a VM image.

  3. Create an Instance Group from your Instance template.

  4. SSH into your Compute Engine VM.

  5. Verify firewall rules to allow inter-VM communication.

  6. Create a Cloud TPU Pod.

  7. Run distributed training on the Pod.

  8. Clean up.

[Optional] Capture a VM disk image into a VM image

You can use the disk image from the VM you used to train the single TPU (that already has the dataset, packages installed, etc.) to create a VM Image that can be used for Pod training. You can skip this step if you do not require any additional packages or datasets for training and can just use a standard PyTorch/XLA image.
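If you do capture an image, the gcloud CLI can do it from the command line. The sketch below uses hypothetical disk, zone, and image names, and only prints the command so you can review it before running it; stop the source VM first (or pass --force) so the captured disk is consistent.

```shell
# Hypothetical names; substitute your own VM disk, zone, and image name.
SOURCE_DISK="pytorch-single-tpu-vm"       # disk of the VM that already trains on a v2-8/v3-8
SOURCE_ZONE="europe-west4-a"              # zone where that disk lives
IMAGE_NAME="pytorch-pod-training-image"   # name for the captured image

# Build the command, then print it for review.
CMD="gcloud compute images create $IMAGE_NAME --source-disk=$SOURCE_DISK --source-disk-zone=$SOURCE_ZONE"
echo "$CMD"
```

When the printed command looks right, run it with `eval "$CMD"`.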

Create an Instance template from a VM image

Create a default instance template. When creating an instance template, you can use the VM image you created in the previous step, or you can use the public PyTorch/XLA image Google provides, by following these instructions:

  1. On the "Create an instance template" page, click the "Change" button under the "Boot disk" section.
  2. If you captured a VM disk image, click the "Custom images" tab and select the image you captured. If you did not capture a VM disk image, select the public PyTorch/XLA image from the "OS images" pull-down menu.
  3. Make sure that:
    1. Under "Machine type", you select n1-standard-16 for this example, which uses ResNet-50 training. For your own model, choose whatever VM size you used to train on a v2-8/v3-8.
    2. Under "Identity and API access" → "Access scopes", you select "Allow full access to all Cloud APIs".
  4. Click "Select" at the bottom to create your Instance template.
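The same template can also be created from the command line. This is a sketch with hypothetical names; it assumes you captured a custom image in the previous step, and --scopes=cloud-platform corresponds to "Allow full access to all Cloud APIs" in the console form. The command is printed rather than executed so you can inspect it first.

```shell
# Hypothetical names; adjust for your project and image.
TEMPLATE_NAME="pytorch-pod-template"
IMAGE_NAME="pytorch-pod-training-image"   # the custom image from the previous step

CMD="gcloud compute instance-templates create $TEMPLATE_NAME \
 --machine-type=n1-standard-16 \
 --image=$IMAGE_NAME \
 --scopes=cloud-platform"
echo "$CMD"    # review, then run with: eval "$CMD"
```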

Create an Instance Group from your Instance template

From the Google Cloud Console, select Compute Engine > Instance groups to access the "Create an instance group" form. When filling out the form, specify the following parameters for the configuration:

  1. Under "Name", specify a name for your Instance Group.
  2. Under the "Location" section, select "Single zone".
  3. Under "Region" and "Zone", select the same zone that the TPU Pod will be created in.
  4. Under "Instance template", select the template you created in the previous step.
  5. Under "Autoscaling mode", select "Off" (or "Don't configure autoscaling", if the form shows that option instead).
  6. Under "Number of instances", specify N instances, where N is the total number of TPU cores you are using divided by 8, since each VM instance (CW) feeds 8 TPU cores. In this example, which uses a v3-32 Pod slice, N = 32 / 8 = 4, so you need 4 instances.
  7. Under "Health check", select "No health check".
  8. Click "Create" at the bottom to create your Instance Group.
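The instance-count rule from step 6 can be written as a one-liner that derives N from the accelerator type string (v3-32 is this tutorial's example; substitute your own slice type):

```shell
# One client VM (CW) per 8 TPU cores in the Pod slice.
TPU_TYPE="v3-32"               # example slice; the core count is the numeric suffix
CORES="${TPU_TYPE##*-}"        # strip everything up to the last '-', leaving 32
NUM_INSTANCES=$((CORES / 8))   # 32 / 8 = 4 client workers
echo "$TPU_TYPE -> $NUM_INSTANCES VM instances"
```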

SSH into your Compute Engine VM

After creating your Instance Group, SSH into your Compute Engine VM to continue these instructions.

  1. From the Google Cloud Console select Compute Engine > instance groups and click on the Instance group you just created. This opens a window that displays all of the Instances in your Instance Group.

  2. At the end of the description line for one of your Instances, click the SSH button. This brings up a VM session window. Run each command that begins with (vm)$ in your VM session window.

Verify firewall rules to allow inter-VM communication

Verify that your Compute Engine VMs can communicate with each other on port 8477, either by checking the firewall rules for your project or by running the nmap command from your Compute Engine VM.

  1. Run the nmap command. Replace instance-ID with the name of one of the instances from your Instance Group, for example instance-group-1-g0h2.

    (vm)$ nmap -Pn -p 8477 instance-ID
    Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
    Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal
    Host is up (0.00034s latency).

    PORT     STATE  SERVICE
    8477/tcp closed unknown

As long as the STATE field does not say "filtered", the firewall rules are set up correctly.
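If nmap is not installed, a rough equivalent check can be done with bash's built-in /dev/tcp pseudo-device. The sketch below distinguishes a filtered port (connection attempt times out) from an open or merely closed one; both of the latter mean the firewall is passing traffic, which is all this step requires. The instance name on the last line is hypothetical.

```shell
# Probe host/port with bash's /dev/tcp. Exit code 124 from `timeout`
# means the connection attempt hung, i.e. traffic is likely filtered.
check_port() {
    timeout 2 bash -c "echo > /dev/tcp/$1/$2" 2>/dev/null
    case $? in
        0)   echo open ;;       # something is listening
        124) echo filtered ;;   # no reply at all: firewall likely dropping packets
        *)   echo closed ;;     # connection refused: reachable, nothing listening
    esac
}

check_port instance-group-1-g0h2 8477   # hypothetical instance name
```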

Create a Cloud TPU Pod

Go to the Google Cloud Console and select Compute Engine > TPUs. This brings up the page where you can create a TPU node and specify the following parameters:

  1. Under "Name", specify a name for your TPU Pod.
  2. Under "Zone", specify the zone to use for your Cloud TPU. Make sure it is the same zone as your Instance Group.
  3. Under "TPU type", select the Cloud TPU type (this example uses a v3-32).
  4. Under "TPU software version", select the latest stable release (pytorch-0.X), for example pytorch-0.5.
  5. Use the default network.
  6. Set the IP address range.
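The equivalent gcloud command looks roughly like the sketch below. The name, zone, and CIDR range are hypothetical; match the zone to your Instance Group and pick an unused range sized for your slice. The command is printed rather than executed so you can review it first.

```shell
# Hypothetical values; the console form above collects the same fields.
CMD="gcloud compute tpus create tpu-pod-name \
 --zone=europe-west4-a \
 --network=default \
 --accelerator-type=v3-32 \
 --version=pytorch-0.5 \
 --range=10.240.0.0/29"
echo "$CMD"    # review, then run with: eval "$CMD"
```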

Run distributed training on the Pod

  1. From your VM session window, export the Cloud TPU name and activate the conda environment.

    (vm)$ export TPU_NAME=tpu-pod-name
    (vm)$ conda activate torch-xla-0.5
  2. Run the training script:

    (torch-xla-0.5)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-0.5 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- \
          python /usr/share/torch-xla-0.5/pytorch/xla/test/ \

Once you run the above command, you should see output similar to the following (note that this run uses --fake_data). Training takes about half an hour on a v3-32 TPU Pod.

2019-09-12 18:24:07  [] Command to distribute: "python" "/usr/share/torch-xla-0.5/pytorch/xla/test/" "--fake_data"
2019-09-12 18:24:07  [] Cluster configuration: {client_workers: [{, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-05lj}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-5v4b}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-8r6z}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-bhzk}], service_workers: [{, 8470, v3-32, europe-west4-a, pytorch-0.5}, {, 8470, v3-32, europe-west4-a, pytorch-0.5}, {, 8470, v3-32, europe-west4-a, pytorch-0.5}, {, 8470, v3-32, europe-west4-a, pytorch-0.5}]}
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194350: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) CPU:0 -> /job:c_tpu_worker/replica:0/task:0/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194423: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:1 -> /job:c_tpu_worker/replica:0/task:1/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194431: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:2 -> /job:c_tpu_worker/replica:0/task:2/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194437: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:3 -> /job:c_tpu_worker/replica:0/task:3/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194443: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) TPU:0 -> /job:c_tpu_worker/replica:0/task:0/device:TPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194448: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) TPU:1 -> /job:c_tpu_worker/replica:0/task:0/device:TPU:1
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194454: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) TPU:10 -> /job:c_tpu_worker/replica:0/task:1/device:TPU:2
2019-09-12 18:25:14 [0] mesh_shape: 4
2019-09-12 18:25:14 [0] mesh_shape: 2
2019-09-12 18:25:14 [0] num_tasks: 4
2019-09-12 18:25:14 [0] num_tpu_devices_per_task: 8
2019-09-12 18:25:14 [0] device_coordinates: 2
2019-09-12 18:25:14 [0] device_coordinates: 3
2019-09-12 18:25:14 [0] device_coordinates: 0
2019-09-12 18:25:14 [0] device_coordinates: 2
2019-09-12 18:31:36 [2] [xla:2](0) Loss=0.00000 Rate=142.92
2019-09-12 18:31:36 [2] [xla:1](0) Loss=0.00000 Rate=116.86
2019-09-12 18:31:36 [2] [xla:6](0) Loss=0.00000 Rate=114.17
2019-09-12 18:31:36 [2] [xla:4](0) Loss=0.00000 Rate=112.40
2019-09-12 18:31:36 [2] [xla:5](0) Loss=0.00000 Rate=101.27
2019-09-12 18:31:36 [2] [xla:3](0) Loss=0.00000 Rate=97.01
2019-09-12 18:31:36 [2] [xla:7](0) Loss=0.00000 Rate=99.72
2019-09-12 18:31:36 [2] [xla:8](0) Loss=0.00000 Rate=98.05
2019-09-12 18:31:36 [2] [xla:4](20) Loss=0.00000 Rate=314.58
2019-09-12 18:31:36 [2] [xla:3](20) Loss=0.00000 Rate=316.00
2019-09-12 18:31:36 [2] [xla:7](20) Loss=0.00000 Rate=317.12
2019-09-12 18:31:36 [2] [xla:2](20) Loss=0.00000 Rate=314.21
2019-09-12 18:31:36 [2] [xla:6](20) Loss=0.00000 Rate=314.27
2019-09-12 18:31:36 [2] [xla:1](20) Loss=0.00000 Rate=311.75
2019-09-12 18:31:36 [2] [xla:5](20) Loss=0.00000 Rate=314.76
2019-09-12 18:31:36 [2] [xla:8](20) Loss=0.00000 Rate=316.58
2019-09-12 18:31:36 [2] [xla:6](40) Loss=0.00000 Rate=423.79
2019-09-12 18:31:36 [2] [xla:8](40) Loss=0.00000 Rate=425.21
2019-09-12 18:31:36 [2] [xla:5](40) Loss=0.00000 Rate=424.41
2019-09-12 18:31:36 [2] [xla:3](40) Loss=0.00000 Rate=423.38
2019-09-12 18:31:36 [2] [xla:2](40) Loss=0.00000 Rate=423.19
2019-09-12 18:31:36 [2] [xla:7](40) Loss=0.00000 Rate=424.01
2019-09-12 18:31:36 [2] [xla:4](40) Loss=0.00000 Rate=422.16
2019-09-12 18:31:36 [2] [xla:1](40) Loss=0.00000 Rate=422.09
2019-09-12 18:31:36 [2] [xla:1](60) Loss=0.00000 Rate=472.79
2019-09-12 18:31:36 [2] [xla:3](60) Loss=0.00000 Rate=472.50
2019-09-12 18:31:36 [2] [xla:6](60) Loss=0.00000 Rate=471.26
2019-09-12 18:31:36 [2] [xla:8](60) Loss=0.00000 Rate=472.04
2019-09-12 18:31:36 [2] [xla:4](60) Loss=0.00000 Rate=471.01
2019-09-12 18:31:36 [2] [xla:5](60) Loss=0.00000 Rate=471.77
2019-09-12 18:31:36 [2] [xla:2](60) Loss=0.00000 Rate=471.16
2019-09-12 18:31:36 [2] [xla:7](60) Loss=0.00000 Rate=471.25
2019-09-12 18:31:36 [2] [xla:1](80) Loss=0.00000 Rate=496.79
2019-09-12 18:31:36 [2] [xla:4](80) Loss=0.00000 Rate=496.45
2019-09-12 18:31:36 [2] [xla:6](80) Loss=0.00000 Rate=496.14
2019-09-12 18:31:36 [2] [xla:2](80) Loss=0.00000 Rate=496.45
2019-09-12 18:31:36 [2] [xla:3](80) Loss=0.00000 Rate=495.66
2019-09-12 18:31:36 [2] [xla:8](80) Loss=0.00000 Rate=496.43
2019-09-12 18:31:36 [2] [xla:5](80) Loss=0.00000 Rate=496.22
2019-09-12 18:31:36 [2] [xla:7](80) Loss=0.00000 Rate=496.37
2019-09-12 18:31:36 [2] [xla:5](100) Loss=0.00000 Rate=503.71
2019-09-12 18:31:36 [2] [xla:2](100) Loss=0.00000 Rate=503.13
2019-09-12 18:31:36 [2] [xla:3](100) Loss=0.00000 Rate=502.50
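The Rate= values in lines like those above are per-core examples/sec, so aggregate throughput is roughly their sum across cores at the same step. A small sketch that sums the rates in a log snippet (the file name and sample lines are illustrative, taken from the step-100 output above):

```shell
# Two sample lines in the format printed above (step 100 for two cores).
cat > /tmp/xla_train.log <<'EOF'
2019-09-12 18:31:36 [2] [xla:1](100) Loss=0.00000 Rate=503.71
2019-09-12 18:31:36 [2] [xla:2](100) Loss=0.00000 Rate=503.13
EOF

# Split each line on "Rate=" and sum the numeric tails.
awk -F'Rate=' '/Rate=/ { sum += $2 } END { printf "aggregate: %.2f img/s\n", sum }' /tmp/xla_train.log
```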

Cleaning up

To avoid incurring charges to your Google Cloud Platform account for the resources used in this tutorial, exit the Compute Engine VM and delete:

  1. The Instance Group you created under Compute Engine > Instance Groups.

  2. The TPU Pod under Compute Engine > TPUs.

What's next

Try the PyTorch colabs.