
Run PyTorch code on TPU Pod slices

Set up a TPU VM Pod running PyTorch and run a calculation

PyTorch/XLA requires all TPU VMs to be able to access the model code and data. A startup script downloads the software needed to distribute the model data to all TPU VMs.

  1. In the Cloud Shell, run the following command to make sure you are running the most current version of gcloud:

    $ gcloud components update
    

    If you need to install gcloud, use the following command:

    $ sudo apt install -y google-cloud-sdk
  2. Configure the PyTorch/XLA runtime:

    There are two PyTorch/XLA runtimes: PJRT and XRT. We recommend you use PJRT unless you have a reason to use XRT. To learn more about the different runtime configurations for PyTorch/XLA, see the PJRT runtime documentation.

    PJRT

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export RUNTIME_VERSION=tpu-vm-pt-2.0
    

    XRT (legacy)

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=zone
    $ export RUNTIME_VERSION=tpu-vm-base
    
  3. Create the TPU VM

    PJRT

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version ${RUNTIME_VERSION} --metadata startup-script='#! /bin/bash
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git
    '
    

    The startup script creates the /usr/share/pytorch/xla directory on each TPU VM and downloads the model code into this directory. The download is quick; however, file processing can take several minutes. You can verify that file processing has completed by running ls -al in /usr/share/pytorch and checking that the xla directory has been created.
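
    For example, a minimal sketch of checking every worker at once, assuming the PROJECT_ID, TPU_NAME, and ZONE variables from the previous step are still set in your shell:

    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone ${ZONE} --project ${PROJECT_ID} --worker=all \
      --command="ls -al /usr/share/pytorch"

    When the startup script has finished on a worker, the xla directory appears in that worker's listing.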

    XRT (legacy)

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} --accelerator-type v3-32 \
    --version ${RUNTIME_VERSION} --metadata startup-script='#! /bin/bash
    # Install torch, torchvision, and torch_xla wheels for TPU Pod with XRT
    pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch-2.0-cp38-cp38-linux_x86_64.whl
    pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torchvision-2.0-cp38-cp38-linux_x86_64.whl
    pip install https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
    
    # Download test scripts
    cd /usr/share/
    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
    git clone --recursive https://github.com/pytorch/xla.git
    '
    

    The startup script first installs the custom torch, torchvision, and torch_xla wheels that are required to run the legacy XRT runtime on TPU Pods.

    The script then creates the /usr/share/pytorch/xla directory on each TPU VM and downloads the model code into this directory. The download is quick; however, file processing can take several minutes. You can verify that file processing has completed by running ls -al in /usr/share/pytorch and checking that the xla directory has been created.

  4. When file processing completes, run the following ssh command to connect to the TPU VM and generate the SSH keys used to ssh between the VM workers on the Pod. Then, start the training.

    To ssh into other TPU VMs associated with the TPU Pod, append --worker ${WORKER_NUMBER} to the ssh command, where WORKER_NUMBER is a 0-based index (see the example after the following command). For more details, see the TPU VM user guide.

    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
     --zone ${ZONE} \
     --project ${PROJECT_ID}
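
    For example, a hedged sketch of connecting to a specific worker; the worker index 1 is illustrative, and worker 0 is the default:

    $ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone ${ZONE} --project ${PROJECT_ID} --worker=1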
    

    Update the project ssh metadata:

    (vm)$ gcloud compute config-ssh
    
    (vm)$ export TPU_NAME=tpu-name
    

    PJRT

    (vm)$ gcloud alpha compute tpus tpu-vm ssh \
    ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all \
    --command="PJRT_DEVICE=TPU python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
    

    The training takes about 3 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

    XRT (legacy)

    (vm)$ python3 -m torch_xla.distributed.xla_dist \
    --tpu=${TPU_NAME} -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
    --fake_data --model=resnet50 --num_epochs=1
    

    The training takes about 3 minutes. When it completes, you should see a message similar to the following:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

    You can find the server-side logging in /tmp/xrt_server_log on each worker.

    (vm)$ ls /tmp/xrt_server_log/
    
    server_20210401-031010.log
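
    To inspect a log, a minimal sketch using the file name from the listing above:

    (vm)$ tail /tmp/xrt_server_log/server_20210401-031010.log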
    

    If you want to restart the XRT server (for example, if the server is in an unhealthy state), pass --restart-tpuvm-pod-server when running xla_dist. New XRT server settings, such as environment variables like LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4, do not take effect until you restart the server.
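
    For example, a hedged sketch of restarting the server while also setting LD_PRELOAD for the new server processes. The --restart-tpuvm-pod-server flag comes from the note above; the --env flag used here to forward the variable to the workers is an assumption about the xla_dist CLI, so verify it before relying on it:

    # --env is an assumption; verify with: python3 -m torch_xla.distributed.xla_dist --help
    (vm)$ python3 -m torch_xla.distributed.xla_dist \
    --tpu=${TPU_NAME} --restart-tpuvm-pod-server \
    --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 \
    -- python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py \
    --fake_data --model=resnet50 --num_epochs=1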

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Compute Engine instance:

    (vm)$ exit
    
  2. Delete your Cloud TPU.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone ${ZONE} --project ${PROJECT_ID}
    
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone ${ZONE} --project ${PROJECT_ID}