Run PyTorch code on TPU Pod slices
Set up a TPU VM Pod running PyTorch and run a calculation
PyTorch/XLA requires all TPU VMs to be able to access the model code and data. A startup script downloads the software needed to distribute the model data to all TPU VMs.
In the Cloud Shell, run the following command to make sure you are running the most current version of gcloud:
$ gcloud components update
If you need to install gcloud, use the following command:
$ sudo apt install -y google-cloud-sdk
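You can confirm which version is installed with, for example:
$ gcloud version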
Configure the PyTorch/XLA runtime:
There are two PyTorch/XLA runtimes: PJRT and XRT. We recommend you use PJRT unless you have a reason to use XRT. To learn more about the different runtime configurations for PyTorch/XLA, see the PJRT runtime documentation.
PJRT
$ export PROJECT_ID=project-id
$ export TPU_NAME=tpu-name
$ export ZONE=zone
$ export RUNTIME_VERSION=tpu-vm-pt-2.0
XRT (legacy)
$ export PROJECT_ID=project-id
$ export TPU_NAME=tpu-name
$ export ZONE=zone
$ export RUNTIME_VERSION=tpu-vm-base
Create the TPU VM
PJRT
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID} \
  --accelerator-type v3-32 \
  --version ${RUNTIME_VERSION} \
  --metadata startup-script='#! /bin/bash
cd /usr/share/
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git'
The startup script creates the /usr/share/pytorch/xla directory on the TPU VM and downloads the model code to this directory. The download is quick; however, file processing can take several minutes. You can verify that file processing has completed by running ls -al in /usr/share/pytorch and checking whether the xla directory has been created.
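For example, you can check all workers at once without opening an interactive session by reusing the same gcloud ssh command that appears later in this guide (a sketch; adjust the flags to your setup):
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} --project ${PROJECT_ID} --worker=all \
  --command="ls -al /usr/share/pytorch"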
XRT (legacy)
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID} \
  --accelerator-type v3-32 \
  --version ${RUNTIME_VERSION} \
  --metadata startup-script='#! /bin/bash
# Install torch, torchvision, and torch_xla wheels for TPU Pod with XRT
pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch-2.0-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torchvision-2.0-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
# Download test scripts
cd /usr/share/
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git'
The startup script first installs our custom wheels for torch, torchvision, and torch_xla, which are required for running our legacy runtime, XRT, on TPU Pods. The script then creates the /usr/share/pytorch/xla directory on the TPU VM and downloads the model code to this directory. The download is quick; however, file processing can take several minutes. You can verify that file processing has completed by running ls -al in /usr/share/pytorch and checking whether the xla directory has been created.
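To confirm that the wheels were installed on every worker, you can also run a quick import check over ssh. This is only a sketch; torch.__version__ and torch_xla.__version__ are the standard version attributes of those packages:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} --project ${PROJECT_ID} --worker=all \
  --command="python3 -c 'import torch, torch_xla; print(torch.__version__, torch_xla.__version__)'"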
Once file processing completes, run the following ssh command to generate ssh-keys so you can ssh between VM workers on the Pod, and then start the training. To ssh into other TPU VMs associated with the TPU Pod, append --worker ${WORKER_NUMBER} to the ssh command, where WORKER_NUMBER is a 0-based index. For more details, see the TPU VM user guide.
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID}
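For example, to open a session on the second worker rather than worker 0, you could run something like the following (the worker index 1 is only an illustration):
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID} \
  --worker 1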
Update the project ssh metadata:
(vm)$ gcloud compute config-ssh
(vm)$ export TPU_NAME=tpu-name
PJRT
(vm)$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="PJRT_DEVICE=TPU python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
The training takes about 3 minutes. When it completes, you should see a message similar to the following:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
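If the training does not start because no TPU devices are found, you can run a quick device query on every worker to confirm that the PJRT runtime can see the TPU. This is only a sketch; xm.xla_device() is part of the torch_xla.core.xla_model API, and the exact device string printed depends on your configuration:
(vm)$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} --project=${PROJECT_ID} --worker=all \
  --command="PJRT_DEVICE=TPU python3 -c 'import torch_xla.core.xla_model as xm; print(xm.xla_device())'"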
XRT (legacy)
(vm)$ python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} \
  -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
  --fake_data --model=resnet50 --num_epochs=1
The training takes about 3 minutes. When it completes, you should see a message similar to the following:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
You can find the server-side logging in /tmp/xrt_server_log on each worker.
(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log
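For example, to follow the server log on the worker you are connected to, you could run something like:
(vm)$ tail -f /tmp/xrt_server_log/server_*.log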
If you want to restart the XRT_SERVER (for example, if the server is in an unhealthy state), you can pass --restart-tpuvm-pod-server when running xla_dist. New XRT server settings, such as environment variables like LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4, do not take effect until you restart the server.
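For example, to relaunch the training from the previous step while restarting the XRT server on each worker, you could add the flag to the same xla_dist command (a sketch based on the command shown above):
(vm)$ python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server \
  -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
  --fake_data --model=resnet50 --num_epochs=1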
Clean up
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the Compute Engine instance:
(vm)$ exit
Delete your Cloud TPU.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
  --zone ${ZONE}
Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.
$ gcloud compute tpus tpu-vm list \
  --zone ${ZONE}
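If the TPU still appears in the list, wait a few minutes and run the command again. You can also inspect the state of a specific TPU with, for example:
$ gcloud compute tpus tpu-vm describe ${TPU_NAME} \
  --zone ${ZONE}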