Training PyTorch models on Cloud TPU Pods

This tutorial shows how to scale up training a model from a single Cloud TPU (v2-8 or v3-8) to a Cloud TPU Pod. The TPU accelerators in a TPU Pod are connected by very high-bandwidth interconnects, which makes them well suited to scaling up training jobs.

For more information about Cloud TPU Pod offerings, refer to the Cloud TPU product page or to this Cloud TPU presentation.

The following diagram provides an overview of the distributed cluster setup. In a single-device setup, a single VM (also called a client worker or CW) feeds a single TPU worker (also called a service worker or SW).

In distributed training, a cluster of CWs feeds a corresponding TPU Pod slice (a cluster of SWs). The input pipeline runs on the CWs and all of the model training happens on the SWs.

Objectives

  • Set up a Compute Engine Instance Group and Cloud TPU Pod for training with PyTorch/XLA
  • Run PyTorch/XLA training on a Cloud TPU Pod

Before you begin

Before starting distributed training on Cloud TPU Pods, verify that your model trains on a single v2-8 or v3-8 Cloud TPU device. If your model has significant performance problems on a single device, refer to the best practices and troubleshooting guides.

Once your model is training successfully on a single TPU device, perform the following steps to set up and train on a Cloud TPU Pod:

  1. [Optional] Capture a VM disk image into a VM image.

  2. Create an Instance template from a VM image.

  3. Create an Instance Group from your Instance template.

  4. SSH into your Compute Engine VM.

  5. Verify firewall rules to allow inter-VM communication.

  6. Create a Cloud TPU Pod.

  7. Run distributed training on the Pod.

  8. Clean up.

[Optional] Capture a VM disk image into a VM image

You can use the disk image from the VM you used to train the single TPU (that already has the dataset, packages installed, etc.) to create a VM Image that can be used for Pod training. You can skip this step if you do not require any additional packages or datasets for training and can use a standard PyTorch/XLA image.
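If you prefer the gcloud CLI to the console for this step, capturing the image can be sketched as follows. The VM, zone, and image names here are hypothetical placeholders; substitute your own.

```shell
# Hypothetical names; replace with your single-TPU VM, its zone, and
# the image name you want. Stop the VM first so the disk is consistent.
gcloud compute instances stop my-single-tpu-vm --zone=europe-west4-a

# Create a reusable VM image from the VM's boot disk.
gcloud compute images create my-pytorch-xla-image \
    --source-disk=my-single-tpu-vm \
    --source-disk-zone=europe-west4-a
```

The resulting image then appears under the Custom images tab when you create the instance template.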

Create an Instance template from a VM image

Create a default instance template. When creating an instance template, you can use the VM image you created in the step above, or you can use the public PyTorch/XLA image that Google provides. To create an instance template, follow these instructions:

  1. From the Google Cloud Console page, select Compute Engine > Instance templates.
  2. On the Instance templates page, click Create Instance Template.
  3. On the Create an instance template page, click the Change button under the Boot Disk section.
  4. If you captured a VM disk image, click on the Custom images tab and select the image you captured. If you did not capture a VM disk image, select Deep Learning on Linux from the Operating system pull-down menu. Select Debian GNU/Linux 9 Stretch + PyTorch/XLA from the Version pull-down menu.
  5. Make sure that:
    1. Under Machine type, select n1-standard-16 for this example, which trains ResNet-50. For your own model, choose the VM size you used when training on a v2-8/v3-8.
    2. Under Identity and API access → Access Scopes, select Allow full access to all Cloud APIs.
  6. Click Create at the bottom to create your Instance template.
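For reference, the same template can be sketched with the gcloud CLI. The template name is a hypothetical placeholder, and the image flags assume the public PyTorch/XLA Deep Learning VM image family; verify the family and project for your release, or substitute your custom image.

```shell
# Hypothetical template name; the machine type matches the
# ResNet-50 example above, and full API access mirrors the
# "Allow full access to all Cloud APIs" console setting.
gcloud compute instance-templates create my-xla-template \
    --machine-type=n1-standard-16 \
    --image-family=torch-xla \
    --image-project=ml-images \
    --scopes=https://www.googleapis.com/auth/cloud-platform
```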

Create an Instance Group from your Instance template

  1. From the Google Cloud Console, select Instance groups and click Create Instance Group.
  2. Under Name, specify a name for your Instance Group.
  3. Under Location section make sure to select Single Zone.
  4. Under Region and Zone sections make sure to select the same zone that the TPU Pod will be created in.
  5. Under Instance template select the template you created in the previous step.
  6. Under Number of Instances, you need N instances, where N equals the total number of Cloud TPU cores you are using divided by 8. This example uses a v2-32 Pod, so 32 cores divided by 8 gives 4: you need 4 instances.
  7. Under Autoscaling mode, select Off. If the console shows an Autoscaling mode dropdown instead, select Don't autoscale.
  8. Under Health check select No health check.
  9. Click Create at the bottom to create your instance group.
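The instance-count rule above is simple arithmetic: divide the total core count of the Pod slice by 8, because each client VM feeds one 8-core TPU host. A small shell sketch (the Pod size is an assumption matching the v2-32 example):

```shell
# Cores in the target Pod slice (v2-32 in this example).
TPU_CORES=32

# Each client VM feeds one 8-core TPU host.
NUM_INSTANCES=$(( TPU_CORES / 8 ))
echo "$NUM_INSTANCES"
```

The group itself could then be created from the CLI with something like: gcloud compute instance-groups managed create my-instance-group --zone=europe-west4-a --template=my-xla-template --size=$NUM_INSTANCES (group and template names hypothetical; the zone must match where you will create the TPU Pod).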

SSH into your Compute Engine VM

After creating your Instance Group, SSH into your Compute Engine VM to continue these instructions.

  1. From the Google Cloud Console select Compute Engine > Instance Groups and click on the instance group you just created. This opens a window that displays all of the instances in your instance group.

  2. At the end of the description line for one of your instances, click the SSH button. This brings up a VM session window. Run each command that begins with (vm)$ in your VM session window.
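If you prefer the command line to the console's SSH button, the same session can be opened with gcloud. The instance name below is an example like the ones the managed group generates; substitute one of your own instances and your zone.

```shell
# Hypothetical instance name from the managed group and example zone.
gcloud compute ssh instance-group-1-g0h2 --zone=europe-west4-a
```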

Verify firewall rules to allow inter-VM communication

Verify that your Compute VMs can communicate with each other on port 8477 by checking the firewall rules for your project OR by running the nmap command from your Compute Engine VM.

  1. Run the nmap command. Replace instance-ID with the name of one of the instances in your instance group. For example, instance-group-1-g0h2.

    (vm)$ nmap -Pn -p 8477 instance-ID
       Starting Nmap 7.40 ( ) at 2019-10-02 21:35 UTC
       Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (
       Host is up (0.00034s latency).
       8477/tcp closed unknown

As long as the STATE field does not say filtered, the firewall rules are set up correctly.
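If the port does show as filtered, you can inspect and add rules from the CLI. This sketch assumes the default network and its standard internal address range; adjust both for your VPC, and the rule name is a hypothetical placeholder.

```shell
# List existing rules on the default network for reference.
gcloud compute firewall-rules list --filter="network=default"

# Allow the mesh port used by xla_dist between VMs on internal addresses.
gcloud compute firewall-rules create allow-xla-mesh \
    --network=default \
    --allow=tcp:8477 \
    --source-ranges=10.128.0.0/9
```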

Create a Cloud TPU Pod

  1. Go to the Google Cloud Console and select Compute Engine > TPUs.
  2. Click Create TPU Node.
  3. Under Name, specify a name for your TPU Pod.
  4. Under Zone specify the zone to use for your Cloud TPU. Make sure it is in the same zone as your instance group.
  5. Under TPU type, select the Cloud TPU type. The Cloud TPU type you select must have eight times as many cores as there are instances in your instance group.
  6. Under TPU software version select the latest stable release, for example pytorch-1.5.
  7. Use the default network.
  8. Set the IP address range.
  9. Click Create to create the TPU Pod.
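The console steps above map to a single gcloud command. The Pod name matches the TPU_NAME exported later in this tutorial, but the CIDR range is a hypothetical placeholder; the zone must match your instance group, and the accelerator type follows the eight-cores-per-instance rule for a 4-VM group.

```shell
# Hypothetical address range; zone must match the instance group.
gcloud compute tpus create tpu-pod-name \
    --zone=europe-west4-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=pytorch-1.5 \
    --range=10.240.0.0/29
```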

Run distributed training on the Pod

  1. From your VM session window, export the Cloud TPU name and activate the conda environment.

    (vm)$ export TPU_NAME=tpu-pod-name
    (vm)$ conda activate torch-xla-1.5
  2. Run the training script:

    (torch-xla-1.5)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.5 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- \
          python /usr/share/torch-xla-1.5/pytorch/xla/test/ --fake_data

After you run the above command, you should see output similar to the following (note that this run uses --fake_data). Training takes about half an hour on a v3-32 TPU Pod.

2019-09-12 18:24:07  [] Command to distribute: "python" "/usr/share/torch-xla-1.5/pytorch/xla/test/" "--fake_data"
2019-09-12 18:24:07  [] Cluster configuration: {client_workers: [{, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-05lj}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-5v4b}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-8r6z}, {, n1-standard-64, europe-west4-a, tutorial-test-pytorch-pods-bhzk}], service_workers: [{, 8470, v3-32, europe-west4-a, pytorch-1.5}, {, 8470, v3-32, europe-west4-a, pytorch-1.5}, {, 8470, v3-32, europe-west4-a, pytorch-1.5}, {, 8470, v3-32, europe-west4-a, pytorch-1.5}]}
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194350: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) CPU:0 -> /job:c_tpu_worker/replica:0/task:0/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194423: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:1 -> /job:c_tpu_worker/replica:0/task:1/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194431: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:2 -> /job:c_tpu_worker/replica:0/task:2/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194437: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) CPU:3 -> /job:c_tpu_worker/replica:0/task:3/device:XLA_CPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194443: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) TPU:0 -> /job:c_tpu_worker/replica:0/task:0/device:TPU:0
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194448: I tensorflow/compiler/xla/xla_client/] XRT device (LOCAL) TPU:1 -> /job:c_tpu_worker/replica:0/task:0/device:TPU:1
2019-09-12 18:24:29 [0] 2019-09-12 18:24:29.194454: I tensorflow/compiler/xla/xla_client/] XRT device (REMOTE) TPU:10 -> /job:c_tpu_worker/replica:0/task:1/device:TPU:2
2019-09-12 18:25:14 [0] mesh_shape: 4
2019-09-12 18:25:14 [0] mesh_shape: 2
2019-09-12 18:25:14 [0] num_tasks: 4
2019-09-12 18:25:14 [0] num_tpu_devices_per_task: 8
2019-09-12 18:25:14 [0] device_coordinates: 2
2019-09-12 18:25:14 [0] device_coordinates: 3
2019-09-12 18:25:14 [0] device_coordinates: 0
2019-09-12 18:25:14 [0] device_coordinates: 2
2019-09-12 18:31:36 [2] [xla:2](0) Loss=0.00000 Rate=142.92
2019-09-12 18:31:36 [2] [xla:1](0) Loss=0.00000 Rate=116.86
2019-09-12 18:31:36 [2] [xla:6](0) Loss=0.00000 Rate=114.17
2019-09-12 18:31:36 [2] [xla:4](0) Loss=0.00000 Rate=112.40
2019-09-12 18:31:36 [2] [xla:5](0) Loss=0.00000 Rate=101.27
2019-09-12 18:31:36 [2] [xla:3](0) Loss=0.00000 Rate=97.01
2019-09-12 18:31:36 [2] [xla:7](0) Loss=0.00000 Rate=99.72
2019-09-12 18:31:36 [2] [xla:8](0) Loss=0.00000 Rate=98.05
2019-09-12 18:31:36 [2] [xla:4](20) Loss=0.00000 Rate=314.58
2019-09-12 18:31:36 [2] [xla:3](20) Loss=0.00000 Rate=316.00
2019-09-12 18:31:36 [2] [xla:7](20) Loss=0.00000 Rate=317.12
2019-09-12 18:31:36 [2] [xla:2](20) Loss=0.00000 Rate=314.21
2019-09-12 18:31:36 [2] [xla:6](20) Loss=0.00000 Rate=314.27
2019-09-12 18:31:36 [2] [xla:1](20) Loss=0.00000 Rate=311.75
2019-09-12 18:31:36 [2] [xla:5](20) Loss=0.00000 Rate=314.76
2019-09-12 18:31:36 [2] [xla:8](20) Loss=0.00000 Rate=316.58
2019-09-12 18:31:36 [2] [xla:6](40) Loss=0.00000 Rate=423.79
2019-09-12 18:31:36 [2] [xla:8](40) Loss=0.00000 Rate=425.21
2019-09-12 18:31:36 [2] [xla:5](40) Loss=0.00000 Rate=424.41
2019-09-12 18:31:36 [2] [xla:3](40) Loss=0.00000 Rate=423.38
2019-09-12 18:31:36 [2] [xla:2](40) Loss=0.00000 Rate=423.19
2019-09-12 18:31:36 [2] [xla:7](40) Loss=0.00000 Rate=424.01
2019-09-12 18:31:36 [2] [xla:4](40) Loss=0.00000 Rate=422.16
2019-09-12 18:31:36 [2] [xla:1](40) Loss=0.00000 Rate=422.09
2019-09-12 18:31:36 [2] [xla:1](60) Loss=0.00000 Rate=472.79
2019-09-12 18:31:36 [2] [xla:3](60) Loss=0.00000 Rate=472.50
2019-09-12 18:31:36 [2] [xla:6](60) Loss=0.00000 Rate=471.26
2019-09-12 18:31:36 [2] [xla:8](60) Loss=0.00000 Rate=472.04
2019-09-12 18:31:36 [2] [xla:4](60) Loss=0.00000 Rate=471.01
2019-09-12 18:31:36 [2] [xla:5](60) Loss=0.00000 Rate=471.77
2019-09-12 18:31:36 [2] [xla:2](60) Loss=0.00000 Rate=471.16
2019-09-12 18:31:36 [2] [xla:7](60) Loss=0.00000 Rate=471.25
2019-09-12 18:31:36 [2] [xla:1](80) Loss=0.00000 Rate=496.79
2019-09-12 18:31:36 [2] [xla:4](80) Loss=0.00000 Rate=496.45
2019-09-12 18:31:36 [2] [xla:6](80) Loss=0.00000 Rate=496.14
2019-09-12 18:31:36 [2] [xla:2](80) Loss=0.00000 Rate=496.45
2019-09-12 18:31:36 [2] [xla:3](80) Loss=0.00000 Rate=495.66
2019-09-12 18:31:36 [2] [xla:8](80) Loss=0.00000 Rate=496.43
2019-09-12 18:31:36 [2] [xla:5](80) Loss=0.00000 Rate=496.22
2019-09-12 18:31:36 [2] [xla:7](80) Loss=0.00000 Rate=496.37
2019-09-12 18:31:36 [2] [xla:5](100) Loss=0.00000 Rate=503.71
2019-09-12 18:31:36 [2] [xla:2](100) Loss=0.00000 Rate=503.13
2019-09-12 18:31:36 [2] [xla:3](100) Loss=0.00000 Rate=502.50

Cleaning up

To avoid incurring charges to your Google Cloud Platform account for the resources used in this tutorial, exit from the Compute Engine VM and delete the following resources:

  1. The Instance Group you created under Compute Engine > Instance Groups.

  2. The TPU Pod under Compute Engine > TPUs.
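Equivalently, both resources can be deleted from the CLI; the resource names below are hypothetical placeholders matching the earlier examples.

```shell
# Delete the managed instance group (this removes its VMs as well).
gcloud compute instance-groups managed delete my-instance-group \
    --zone=europe-west4-a

# Delete the TPU Pod.
gcloud compute tpus delete tpu-pod-name --zone=europe-west4-a
```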

What's next

Try the PyTorch Colab notebooks.