Run a calculation on a Cloud TPU VM using PyTorch
This document provides a brief introduction to working with PyTorch and Cloud TPU.
Before you begin
Before running the commands in this document, you must create a Google Cloud account,
install the Google Cloud CLI, and configure the gcloud command. For more
information, see Set up the Cloud TPU environment.
Create a Cloud TPU using gcloud
- Define some environment variables to make the commands easier to use.

  export PROJECT_ID=your-project-id
  export TPU_NAME=your-tpu-name
  export ZONE=us-east5-a
  export ACCELERATOR_TYPE=v5litepod-8
  export RUNTIME_VERSION=v2-alpha-tpuv5-lite

  Environment variable descriptions:

  - PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
  - TPU_NAME: The name of the TPU.
  - ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
  - ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
  - RUNTIME_VERSION: The Cloud TPU software version.
- Create your TPU VM by running the following command:

  $ gcloud compute tpus tpu-vm create $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE \
      --accelerator-type=$ACCELERATOR_TYPE \
      --version=$RUNTIME_VERSION
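If you script these steps, you can assemble the same create command from the environment variables defined above. The following is a minimal Python sketch, not part of the official tooling. The helper name `build_create_command` is hypothetical, and it uses only the flags shown in the command above. It returns an argument list that you can pass to `subprocess.run` without going through a shell.

```python
import os

# The environment variables defined in the previous step.
REQUIRED_VARS = ("PROJECT_ID", "TPU_NAME", "ZONE", "ACCELERATOR_TYPE", "RUNTIME_VERSION")


def build_create_command(env=None):
    """Return the argv for `gcloud compute tpus tpu-vm create` (hypothetical helper).

    Reads the variables from `env` (defaults to os.environ) and raises KeyError
    listing any that are missing or empty.
    """
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", env["TPU_NAME"],
        f"--project={env['PROJECT_ID']}",
        f"--zone={env['ZONE']}",
        f"--accelerator-type={env['ACCELERATOR_TYPE']}",
        f"--version={env['RUNTIME_VERSION']}",
    ]
```

For example, `subprocess.run(build_create_command(), check=True)` would run the create command and raise `CalledProcessError` if gcloud exits with an error. Passing a list instead of a shell string avoids quoting problems.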
Connect to your Cloud TPU VM
Connect to your TPU VM over SSH using the following command:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
If you can't connect to your TPU VM using SSH, the TPU VM might not have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.
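`gcloud compute tpus tpu-vm ssh` also accepts a `--command` flag. This flag runs a single command on the VM and then exits, which is convenient when you want to script a check instead of opening an interactive session. The following minimal Python sketch assembles that argument list from the same environment variables. The helper name `build_ssh_command` is hypothetical.

```python
import os


def build_ssh_command(remote_command=None, env=None):
    """Return the argv for `gcloud compute tpus tpu-vm ssh` (hypothetical helper).

    If remote_command is given, it is passed with --command so gcloud runs it on
    the TPU VM and exits instead of opening an interactive shell.
    """
    env = os.environ if env is None else env
    argv = [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", env["TPU_NAME"],
        f"--project={env['PROJECT_ID']}",
        f"--zone={env['ZONE']}",
    ]
    if remote_command:
        # A single argv element, so the remote command needs no extra shell quoting here.
        argv.append(f"--command={remote_command}")
    return argv
```

For example, `subprocess.run(build_ssh_command("python3 --version"), check=True)` would print the VM's Python version without leaving you in a remote shell.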
Install PyTorch/XLA on your TPU VM
Run the following commands on your TPU VM to install PyTorch/XLA and its dependencies:

(vm)$ sudo apt-get update
(vm)$ sudo apt-get install libopenblas-dev -y
(vm)$ pip install numpy
(vm)$ pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
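To confirm which versions pip installed, you can query the package metadata from Python without importing the packages themselves. This sketch uses the standard-library `importlib.metadata` module. The helper name `installed_versions` is illustrative. It reports `None` for any package that is not installed, so you can run it safely before or after the install step.

```python
from importlib import metadata

# Distribution names as passed to pip in the install step.
PACKAGES = ("numpy", "torch", "torch_xla")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version string, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

A missing `torch_xla` entry usually means the last `pip install` command failed or ran in a different Python environment.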
Verify PyTorch can access TPUs
Use the following command to verify that PyTorch can access your TPUs:
(vm)$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
The output from the command should look like the following:
['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']
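If you prefer to keep the check in a script, the following sketch performs the same call, `xm.get_xla_supported_devices("TPU")`, and prints a clear message instead of a bare list. It sets `PJRT_DEVICE` from inside Python, which assumes the variable is read when torch_xla initializes its runtime, so set it before the import. The function name `tpu_devices` is illustrative. The function returns `None` when torch_xla is not installed or reports no TPU devices.

```python
import importlib.util
import os


def tpu_devices():
    """Return the TPU device strings reported by torch_xla, or None if none are available."""
    if importlib.util.find_spec("torch_xla") is None:
        # torch_xla is not installed in this Python environment.
        return None
    # Equivalent to prefixing the command with PJRT_DEVICE=TPU; setdefault keeps
    # any value that is already set. Must happen before torch_xla is imported.
    os.environ.setdefault("PJRT_DEVICE", "TPU")
    import torch_xla.core.xla_model as xm

    # Returns a list such as ['xla:0', ...], or None if no TPU devices are found.
    return xm.get_xla_supported_devices("TPU")


devices = tpu_devices()
if devices:
    print(f"Found {len(devices)} TPU device(s): {devices}")
else:
    print("No TPU devices found; check PJRT_DEVICE and the PyTorch/XLA installation.")
```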
Perform a basic calculation
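- Create a file named tpu-test.py in the current directory and copy and paste the following script into it:

  import torch
  import torch_xla.core.xla_model as xm

  # Get the default XLA device (a TPU device when PJRT_DEVICE=TPU).
  dev = xm.xla_device()

  # Create two random 3x3 tensors directly on the TPU.
  t1 = torch.randn(3, 3, device=dev)
  t2 = torch.randn(3, 3, device=dev)

  # Printing forces the lazy XLA computation to run and copies the result to the host.
  print(t1 + t2)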
- Run the script:

  (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py

  The output from the script shows the result of the computation. Because the input tensors are random, your values will differ:

  tensor([[-0.2121,  1.5589, -0.6951],
          [-0.7886, -0.2022,  0.9242],
          [ 0.8555, -1.8698,  1.4333]], device='xla:1')
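Printing random values shows that the TPU ran something, but it does not show that the result is correct. One way to check correctness is to compute the same sum on the CPU and compare the two results. In the following sketch, the function name `tpu_sum_matches_cpu` and the fixed seed are illustrative choices. The sketch creates the inputs on the CPU, copies them to the XLA device with `.to(dev)`, and copies the sum back with `.cpu()`. It then compares that sum with the CPU result using `torch.allclose`. The function returns `None` when torch_xla is not installed.

```python
import importlib.util
import os


def tpu_sum_matches_cpu(size=3, seed=0):
    """Add two random tensors on the XLA device and compare with the CPU result.

    Returns True/False from torch.allclose, or None if torch_xla is not installed.
    """
    if importlib.util.find_spec("torch_xla") is None:
        return None
    os.environ.setdefault("PJRT_DEVICE", "TPU")
    import torch
    import torch_xla.core.xla_model as xm

    # Create the inputs on the CPU with a fixed seed so runs are reproducible.
    torch.manual_seed(seed)
    a = torch.randn(size, size)
    b = torch.randn(size, size)

    dev = xm.xla_device()
    # .to(dev) copies the inputs to the TPU; .cpu() runs the computation and copies it back.
    on_tpu = (a.to(dev) + b.to(dev)).cpu()
    return torch.allclose(on_tpu, a + b)


print(tpu_sum_matches_cpu())
```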
Clean up
To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.
- Disconnect from the Cloud TPU instance, if you have not already done so:

  (vm)$ exit

  Your prompt should now be username@projectname, showing you are in the Cloud Shell.
- Delete your Cloud TPU:

  $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE
- Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

  $ gcloud compute tpus tpu-vm list \
      --zone=$ZONE
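If you automate cleanup, keep two differences from the interactive steps in mind. First, the delete command asks for confirmation, which the standard gcloud `--quiet` flag suppresses. Second, you need a way to tell whether your TPU still appears in the list. The following Python sketch covers both. The helper names are hypothetical. The list command adds the standard gcloud `--format=value(name)` option so that each TPU prints on its own line. The parser accepts either a short name or a full resource path ending in `/<name>`.

```python
import os


def build_delete_command(env=None):
    """Return argv for `gcloud compute tpus tpu-vm delete` (hypothetical helper)."""
    env = os.environ if env is None else env
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "delete", env["TPU_NAME"],
        f"--project={env['PROJECT_ID']}",
        f"--zone={env['ZONE']}",
        # Global gcloud flag: skip the interactive confirmation prompt.
        "--quiet",
    ]


def build_list_command(env=None):
    """Return argv for `gcloud compute tpus tpu-vm list`, one TPU name per output line."""
    env = os.environ if env is None else env
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "list",
        f"--zone={env['ZONE']}",
        "--format=value(name)",
    ]


def tpu_still_listed(list_output, tpu_name):
    """Return True if tpu_name appears in the list command's output."""
    for line in list_output.splitlines():
        # Compare only the last path component, so full resource paths also match.
        if line.strip().rsplit("/", 1)[-1] == tpu_name:
            return True
    return False
```

For example, you could run `build_delete_command()` with `subprocess.run(..., check=True)`. You could then rerun the list command with `capture_output=True, text=True` until `tpu_still_listed(result.stdout, name)` returns `False`, because deletion can take several minutes.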
What's next
Read more about Cloud TPU VMs: