Run PyTorch code on TPU Pod slices
PyTorch/XLA requires all TPU VMs in the Pod slice to be able to access the model code and data. One way to meet this requirement is to use a startup script that downloads and installs the needed software and data on every TPU VM, as sketched below.
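For example, a startup script can be passed as instance metadata when the TPU VM is created (the create command itself appears later in this guide, and uses the environment variables defined below); the script then runs on every worker at boot. The script body here is only a sketch: the pip package and Cloud Storage path are placeholder assumptions, not part of this tutorial.

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--metadata startup-script='#! /bin/bash
# Runs on each TPU VM worker at boot.
pip install torchvision                                # placeholder dependency
gsutil -m cp -r gs://your-bucket/dataset /tmp/dataset  # placeholder data path
'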
In the Cloud Shell, run the following command to make sure you are running the current version of gcloud:

$ gcloud components update

If you need to install gcloud, use the following command:

$ sudo apt install -y google-cloud-sdk
Create some environment variables:
$ export PROJECT_ID=project-id
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export RUNTIME_VERSION=tpu-vm-v4-pt-2.0
$ export ACCELERATOR_TYPE=v4-32
Create the TPU VM
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION}
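Optionally, you can confirm the Pod slice is ready before connecting; this check is not part of the original flow, but the describe output includes a state field that should read READY:

$ gcloud compute tpus tpu-vm describe ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID}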
Configure and run the training script
Add your SSH key to your SSH agent:

$ ssh-add ~/.ssh/google_compute_engine
Clone PyTorch/XLA on all TPU VM workers:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--worker=all \
--command="git clone -b r2.0 https://github.com/pytorch/xla.git"
Run the training script on all workers:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--worker=all \
--command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py \
--fake_data \
--model=resnet50 \
--num_epochs=1 2>&1 | tee ~/logs.txt"
The training takes about 5 minutes. When it completes, you should see a message similar to the following:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
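Because the training command tees its output to ~/logs.txt on each worker, you can copy a worker's log back for local inspection. This is a minimal sketch; the local destination filename is an arbitrary choice:

$ gcloud compute tpus tpu-vm scp ${TPU_NAME}:logs.txt ./logs-worker0.txt \
--zone=${ZONE} \
--worker=0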
Clean up
When you are done with your TPU VM, follow these steps to clean up your resources.

Disconnect from the Compute Engine instance:

(vm)$ exit
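Delete your TPU VM:

$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID}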
Verify that the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

$ gcloud compute tpus tpu-vm list \
--zone=${ZONE}
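If the deletion succeeded, the list command reports that no TPUs were found, with output similar to the following:

Listed 0 items.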