This tutorial shows how to scale training of your model from a single Cloud TPU (v2-8 or v3-8) to a Cloud TPU Pod using the TPU Node configuration. The Cloud TPU accelerators in a TPU Pod are connected by high-bandwidth interconnects, which makes them efficient at scaling up training jobs.
For more information about Cloud TPU Pod offerings, refer to the Cloud TPU product page or to this Cloud TPU presentation.
The following diagram provides an overview of the distributed cluster setup. An instance group of VMs is connected to a TPU Pod. One VM is needed for each group of 8 TPU cores. The VMs feed data to the TPU cores, and all training takes place on the TPU Pod.
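To make the data flow concrete: each VM runs the same training script, which spawns one process per local TPU core, and gradients are all-reduced across every core in the Pod at each optimizer step. The following is a minimal, self-contained sketch of that per-VM pattern (the linear model and fake data are placeholders, not the tutorial's ImageNet script):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process is bound to one TPU core.
    device = xm.xla_device()
    model = nn.Linear(32, 10).to(device)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    for step in range(10):
        # Fake data, in the spirit of the tutorial's --fake_data flag.
        data = torch.randn(16, 32, device=device)
        target = torch.randint(0, 10, (16,), device=device)
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        # All-reduces gradients across all Pod cores, then steps.
        xm.optimizer_step(optimizer)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8)  # one process per local group of 8 TPU cores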
Objectives
- Set up a Compute Engine Instance Group and Cloud TPU Pod for training with PyTorch/XLA
- Run PyTorch/XLA training on a Cloud TPU Pod
Before you begin
Before starting distributed training on Cloud TPU Pods, verify that your model trains on a single v2-8 or v3-8 Cloud TPU device. If your model has significant performance problems on a single device, refer to the best practices and troubleshooting guides.
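If you want a quick smoke test of the single-device environment before scaling out, a minimal sketch such as the following (a toy linear model, not your actual workload) confirms that PyTorch/XLA can compile and step on the TPU:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # the single TPU device
model = nn.Linear(32, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for step in range(5):
    data = torch.randn(16, 32, device=device)
    target = torch.randint(0, 10, (16,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    # barrier=True forces execution of the pending XLA graph each step.
    xm.optimizer_step(optimizer, barrier=True)
    print(f'step {step}: loss {loss.item():.4f}')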
Once your single TPU device is successfully training, perform the following steps to set up and train on a Cloud TPU Pod:
Configure the gcloud command
Configure your Google Cloud project with gcloud:
Create a variable for your project's ID.
export PROJECT_ID=project-id
Set your project ID as the default project in gcloud:
gcloud config set project ${PROJECT_ID}
The first time you run this command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make API calls with your credentials.
Configure the default zone with gcloud:
gcloud config set compute/zone europe-west4-a
[Optional] Capture a VM disk image
You can use the disk image from the VM you used to train on the single TPU device (it already has the dataset, installed packages, and so on). Before creating an image, stop the VM using the gcloud command:
gcloud compute instances stop vm-name
Next, create a VM image using the gcloud command:
gcloud compute images create image-name \
  --source-disk instance-name \
  --source-disk-zone europe-west4-a \
  --family=torch-xla \
  --storage-location europe-west4
Create an Instance template from a VM image
Create a default instance template. When creating the instance template, you can use the VM image you created in the previous step, or you can use the public PyTorch/XLA image that Google provides. To create an instance template, use the gcloud command:
gcloud compute instance-templates create instance-template-name \
  --machine-type n1-standard-16 \
  --image-project=${PROJECT_ID} \
  --image=image-name \
  --scopes=https://www.googleapis.com/auth/cloud-platform
Create an Instance Group from your Instance template
gcloud compute instance-groups managed create instance-group-name \
  --size 4 \
  --template instance-template-name \
  --zone europe-west4-a
SSH into your Compute Engine VM
After creating your Instance Group, SSH into one of the instances (VMs) in your instance group. Use the following gcloud command to list all instances in your instance group:
gcloud compute instance-groups list-instances instance-group-name
SSH into one of the instances listed by the list-instances command:
gcloud compute ssh instance-name --zone=europe-west4-a
Verify the VMs in your instance group can communicate with each other
Use the nmap command to verify that the VMs in your instance group can communicate with each other. Run the nmap command from the VM to which you are connected, replacing instance-name with the instance name of another VM in your instance group.
(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown
As long as the STATE field does not say filtered, the firewall rules are set up correctly.
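If nmap is not available on a VM, an equivalent check can be scripted; the sketch below (a hypothetical helper, not part of the tutorial) probes the same port with a plain TCP socket. A refused connection means the peer is reachable and the traffic is not filtered, while a timeout suggests filtering.

import socket

def probe(host, port=8477, timeout=3.0):
    # Attempt a TCP connection, mirroring what nmap -Pn -p does.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(timeout)
    try:
        sock.connect((host, port))
        return 'open'
    except ConnectionRefusedError:
        return 'closed (host reachable, not filtered)'
    except socket.timeout:
        return 'no response (possibly filtered)'
    finally:
        sock.close()

# Replace instance-name with the name of another VM in the group.
print(probe('instance-name'))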
Create a Cloud TPU Pod
gcloud compute tpus create tpu-name \
  --zone=europe-west4-a \
  --network=default \
  --accelerator-type=v2-32 \
  --version=pytorch-1.13
Run distributed training on the Pod
From your VM session window, export the Cloud TPU name and activate the conda environment.
(vm)$ export TPU_NAME=tpu-name
(vm)$ conda activate torch-xla-1.13
Run the training script:
(torch-xla-1.13)$ python -m torch_xla.distributed.xla_dist \
  --tpu=$TPU_NAME \
  --conda-env=torch-xla-1.13 \
  --env XLA_USE_BF16=1 \
  --env ANY_OTHER=ENV_VAR \
  -- python /usr/share/torch-xla-1.13/pytorch/xla/test/test_train_mp_imagenet.py \
  --fake_data
Once you run the above command, you should see output similar to the following (note that this run uses --fake_data). Training takes about 30 minutes on a v3-32 TPU Pod.
2020-08-06 02:38:29 [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29 [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
Disconnect from the Compute Engine VM:
(vm)$ exit
Delete your Instance Group:
gcloud compute instance-groups managed delete instance-group-name
Delete your TPU Pod:
gcloud compute tpus delete ${TPU_NAME} --zone=europe-west4-a
Delete your Instance Group Template:
gcloud compute instance-templates delete instance-template-name
[Optional] Delete your VM disk image:
gcloud compute images delete image-name
What's next
Try the PyTorch colabs:
- Getting Started with PyTorch on Cloud TPUs
- Training MNIST on TPUs
- Training ResNet18 on TPUs with Cifar10 dataset
- Inference with Pretrained ResNet50 Model
- Fast Neural Style Transfer
- MultiCore Training AlexNet on Fashion MNIST
- Single Core Training AlexNet on Fashion MNIST