Cloud TPU v4 user's guide
This guide describes how to set up and use a Google Cloud Platform project to use the Cloud TPU v4 release and the TPU VM architecture. If you are new to Cloud TPUs, you can learn about them from the TPUs, System architecture, and Cloud TPU quickstart documentation.
Cloud TPU VMs run on a TPU host machine (the machine connected to the Cloud TPU device) and offer significantly better performance and usability than TPU Nodes when working with TPUs. The architectural differences between TPU VMs and TPU Nodes are described in the System architecture document.
This guide describes the commands used to set up and run Cloud TPU v4 applications using TPU VMs with TensorFlow, PyTorch, and JAX. It also describes solutions to common issues you might encounter when starting to use Cloud TPU v4.
This document uses gcloud commands to perform many of the tasks needed to use Cloud TPU v4. For more information on the gcloud API, see the gcloud reference.
Concepts and terminology
Quota
For v4 there is only one quota type for all TPU slice sizes. This is different from how quota works for v2 and v3 TPU types where there is a quota for a single-device TPU and a different quota for Pod TPUs. For Cloud TPU v4 slices, the default quota is 0. You must contact sales to have a quota allocated.
Accelerator type
TPU v4 has a wider range of accelerator types than previous generations. See the AcceleratorConfig section for more information.
Cores, devices, and chips
- Each TPU host has 4 chips and 8 TensorCores (2 TensorCores per chip).
- XLA combines the resources of each chip into a single virtual core; therefore, when training, you will see 4 virtual cores per host rather than 8 physical TensorCores (see the sketch below).
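As a quick check of how those counts appear in practice, the following minimal sketch (an illustration, assuming JAX with TPU support is installed on a single-host v4-8) prints the device counts; it reports 4 devices on the host, one virtual core per chip, rather than the 8 physical TensorCores.

# Minimal sketch: inspect how TPU v4 chips appear to JAX on a single v4-8 host.
# Assumes JAX with TPU support is installed on the TPU VM.
import jax

# On a v4-8 host this prints 4: one virtual (merged) core per chip,
# even though each chip physically contains 2 TensorCores.
print('devices attached to this host:', jax.local_device_count())
print('devices visible in total:', jax.device_count())
for d in jax.local_devices():
    print(d.platform, d.id, d.device_kind)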
Improved CPU performance on v4 TPUs
The TPU v4 runtime is not natively NUMA-aware. You can take advantage of NUMA-locality benefits by binding your training script to NUMA Node 0.
Non Uniform Memory Access (NUMA) is a computer memory architecture for machines that have multiple CPUs. Each CPU has direct access to a block of high-speed memory. A CPU and its memory together form a node. Nodes are connected directly to adjacent nodes. A CPU in one node can access memory in another node, but this access is slower than accessing memory within its own node.
Software running on a multi-CPU machine can place data needed by a CPU within its Node, increasing memory throughput. For more information about NUMA, see Non Uniform Memory Access on Wikipedia.
To enable NUMA Node binding:

Install the numactl command line tool.

$ sudo apt-get update
$ sudo apt-get install numactl

Use numactl --cpunodebind=0 when launching your training script. This binds your script to NUMA Node 0.

$ numactl --cpunodebind=0 python3 your-training-script
When should I use this?
- If your workload depends heavily on the host CPU (for example, image classification or recommendation workloads), regardless of framework.
- If you are using a TPU runtime version without a -pod suffix (for example, tpu-vm-tf-2.10.0-v4).
Known Issues
- A subnet must be created in us-central2 before using a TPU v4. See Custom Network Resources for more information.
- By default, projects have a low IN_USE_ADDRESSES limit, so this quota needs to be increased before creating large slices. Contact your sales representative to increase this quota. Small increases (for example, 8 to 16 for us-central2-b) should be approved automatically.
Set up and prepare a Google Cloud project
To use v4 Cloud TPUs, you must first prepare a Google Cloud project. Use the following steps to do that.
Set up a Google Cloud Project
- Sign in to your Google Account. If you don't already have a Google account, sign up for a new account.
- In the Cloud console, select or create a Cloud project from the project selector page.
- Make sure billing is enabled for your project.
Set your project ID using the gcloud CLI. The project ID is the name of your project shown on the Cloud console.

$ gcloud config set project project-ID
Enable TPU API
Enable the TPU API using the following gcloud command in Cloud Shell. (You can also enable it from the Google Cloud console.)

$ gcloud services enable tpu.googleapis.com
Request allowlist
Please use this form to request adding your account to the allowlist.
Create a TPU service account using the following command:
$ gcloud compute tpus tpu-vm service-identity create --zone=us-central2-b
Create a subnet for TPUs
$ gcloud compute networks subnets create tpusubnet \
  --network=default \
  --range=10.12.0.0/20 \
  --region=us-central2 \
  --enable-private-ip-google-access
TPU setup
After setting up your project, create a Cloud TPU using the following steps.
Set some Cloud TPU variables
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export RUNTIME_VERSION=tpu-vm-tf-2.10.0-v4
$ export PROJECT_ID=project-id
- TPU_NAME: A user-assigned name of the TPU node.
- ZONE: The location of the TPU node. Currently, only us-central2-b is supported.
- PROJECT_ID: Your project ID.
- RUNTIME_VERSION:
  - If you are using JAX, use tpu-vm-v4-base.
  - If you are using PyTorch, use tpu-vm-v4-pt-2.0.
  - If you are using TensorFlow on a v4-8 TPU, use tpu-vm-tf-2.10.0-v4.
  - If you are using TensorFlow on a larger Pod slice, use tpu-vm-tf-2.10.0-pod-v4.
Create a v4 TPU VM
There are two ways to specify the TPU Pod slice you want to create:

- Using the accelerator-type flag. This way is recommended when you are not specifying any topology.
- Using the type and topology flags. This way is recommended when you want to customize the physical topology, which is generally required for performance tuning with slices of 256 chips or more.

Create a v4 TPU using the accelerator-type flag in gcloud:

Set your accelerator-type variable:

$ export ACCELERATOR_TYPE=v4-8

- ACCELERATOR_TYPE: See the TPU type column in regions and zones for supported accelerator types.
Create the v4 TPU:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone us-central2-b \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --version ${RUNTIME_VERSION} \
  --subnetwork=tpusubnet
Alternatively, you can create a v4 TPU VM using the type and topology flags in gcloud:

Set your type and topology variables:

$ export TPU_TYPE=v4
$ export TOPOLOGY=2x2x1

- TPU_TYPE: See the chip-topology-based topologies documentation for more information.
- TOPOLOGY: See Types and topologies for information on v4 topologies.
Create the v4 TPU:
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --zone=us-central2-b \
  --subnetwork=tpusubnet \
  --type=${TPU_TYPE} \
  --topology=${TOPOLOGY} \
  --version=${RUNTIME_VERSION}
Required flags

- tpu-name: The name of the TPU VM you are creating.
- zone: The zone where you are creating your Cloud TPU.
- subnet: The subnet you created previously.
- tpu-type: See the topology section for the supported TPU types.
- topology: See the topology section for the supported topologies.
- version: The runtime version you want to use.

Optional flags

- preemptible: Create a preemptible TPU. It may be preempted to free up resources. See preemptible TPUs for more details.
- PROJECT_ID: The project you are using to set up the TPU.
- enable_external_ips: When set to true, adds access configs to the TPU VMs when the TPU is created. Refer to Private Google Access for more information.

See Types and topologies for more details on the supported TPU types and topologies.
Create a v4 TPU VM using curl:

$ curl -X POST \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  -H "Content-Type: application/json" \
  -d "{accelerator_type: '${ACCELERATOR_TYPE}', runtime_version:'${RUNTIME_VERSION}', network_config: {enable_external_ips: true}}" \
  https://tpu.googleapis.com/v2/projects/${PROJECT_ID}/locations/us-central2-b/nodes?node_id=${TPU_NAME}
SSH into the TPU VM
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}
Train ResNet with TensorFlow
You can train any TPU-compatible model with TensorFlow on a v4 Pod slice. This section shows how to train ResNet on a TPU.
Train ResNet on a v4-8 TPU
Complete the instructions in Project setup and TPU setup to set up a v4-8 slice. Then, run the following command on your TPU VM to train ResNet:
(vm)$ export PYTHONPATH=/usr/share/tpu/tensorflow/resnet50_keras
(vm)$ python3 /usr/share/tpu/tensorflow/resnet50_keras/resnet50.py --tpu=local --data=gs://cloud-tpu-test-datasets/fake_imagenet
Train ResNet on a TPU Pod slice
To train ResNet on a v4 Pod slice, you must create a TPU v4 Pod slice. To do this, follow the instructions in Project setup and TPU setup, specify a Pod type (for example, --accelerator-type=v4-32, or --type=v4 with --topology=2x2x4) as the accelerator type or accelerator config, respectively, and specify the Pod runtime version (tpu-vm-tf-2.10.0-pod-v4).
Export a required environment variable
SSH to any of the TPU workers (for example, worker 0) and export the following environment variable.
export TPU_LOAD_LIBRARY=0
Run the following commands to train the model. Substitute the TPU name you have chosen into the TPU_NAME variable.

export PYTHONPATH=/usr/share/tpu/tensorflow/resnet50_keras
export TPU_NAME=tpu-name
python3 /usr/share/tpu/tensorflow/resnet50_keras/resnet50.py --tpu=${TPU_NAME} --data=gs://cloud-tpu-test-datasets/fake_imagenet
You can check the logs of the TPU worker with:
sudo docker logs tpu-runtime
For other TF 2.x examples, you can follow TPU VM tutorials in the Cloud TPU documentation, for example, BERT on TensorFlow 2.x.
Train ML workloads with PyTorch / XLA
This section describes how to run a simple calculation on a v4-8 TPU using PyTorch / XLA. The following sections extend this to training ResNet on a v4-8 TPU and on a larger v4 Pod slice.
Set XRT TPU device configuration:
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
# Set the environment variable to visible devices
export TPU_NUM_DEVICES=4
# Allow LIBTPU LOAD by multiple processes
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
For models that have sizable, frequent allocations, memory allocation using tcmalloc significantly improves performance compared to the default malloc implementation. Therefore, tcmalloc is the default malloc used on TPU VMs. However, depending on your workload (for example, DLRM, which has very large allocations for its embedding tables), tcmalloc might cause a slowdown. In this case, you can change the default to malloc by unsetting the following variable:

unset LD_PRELOAD
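If you are unsure which allocator is in effect, one way to check (a small illustrative sketch, not an official tool) is to inspect LD_PRELOAD and the libraries mapped into the current process:

# Illustrative check: is tcmalloc preloaded for this process?
# Inspects LD_PRELOAD and /proc/self/maps on Linux; not an official utility.
import os

preload = os.environ.get('LD_PRELOAD', '')
print('LD_PRELOAD:', preload or '(not set)')

with open('/proc/self/maps') as maps:
    tcmalloc_loaded = any('tcmalloc' in line for line in maps)
print('tcmalloc mapped into this process:', tcmalloc_loaded)

Running this in a fresh Python process before and after unset LD_PRELOAD shows which allocator newly launched processes will pick up.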
Perform a simple calculation
Start the Python 3 interpreter:
python3
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t1 = torch.randn(3, 3, device=dev)
t2 = torch.randn(3, 3, device=dev)
print(t1 + t2)
This generates the following output:

tensor([[-0.3689, -1.1727,  0.6910],
        [ 0.0431,  1.0158,  1.6740],
        [-0.8026,  2.5862, -1.5649]], device='xla:1')
Use exit() or Ctrl-D (EOF) to exit the Python 3 interpreter.
Train ResNet on a v4-8 TPU with PyTorch / XLA
A v4-8 is the smallest supported v4 configuration. The following section shows how to train ResNet on a larger configuration.
Follow the instructions for setting up a v4-8 TPU VM and perform the following steps on the TPU VM.
Export environment variables
As you continue these instructions, run each command that begins with (vm)$ in your TPU VM.
After you create the v4-8 TPU and SSH into the TPU VM, export the following variables to the TPU VM:
(vm)$ export TPU_NAME=tpu-name
(vm)$ export ZONE=us-central2-b
(vm)$ export XRT_TPU_CONFIG='localservice;0;localhost:51011'
(vm)$ export TPU_NUM_DEVICES=4
Clone PyTorch and PyTorch/XLA
(vm)$ cd /usr/share/
(vm)$ sudo git clone -b release/1.13 --recursive https://github.com/pytorch/pytorch
(vm)$ cd pytorch/
(vm)$ sudo git clone -b r1.13 --recursive https://github.com/pytorch/xla.git
Run the training
(vm)$ python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
The training takes approximately 7 minutes to run and generates output similar to the following:
Epoch 1 test end 17:04:13, Accuracy=100.00 Max Accuracy: 100.00%
Clean up your TPU VM resources
Disconnect from the TPU VM, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.

Delete your TPU VM:

$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
  --zone=${ZONE}
Train ResNet on a larger v4 Pod slice with PyTorch / XLA
The previous section specified a v4-8 configuration. This section specifies a v4-32 configuration.
Perform the following steps:
Export TPU configuration variables
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export PROJECT_ID=project-id
Create a v4 Pod slice
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --accelerator-type v4-32 \
  --project ${PROJECT_ID} \
  --version tpu-vm-v4-pt-1.13
To run the example, first clone the XLA repo to all workers using this command:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} --project=${PROJECT_ID} \
  --worker=all --command="git clone -b r1.13 https://github.com/pytorch/xla.git"
SSH to the TPU VM
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID}
As you continue these instructions, run each command that begins with (vm)$ in your TPU VM.
Configure SSH for Pod use
(vm)$ gcloud compute config-ssh
This command asks for a passphrase; you can press Enter twice to use an empty passphrase.
Run the training script
(vm)$ export TPU_NAME=tpu-name
(vm)$ python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} \
  --restart-tpuvm-pod-server \
  -- python3 ~/xla/test/test_train_mp_imagenet.py \
  --fake_data \
  --model=resnet50 \
  --num_epochs=1 2>&1 | tee ~/logs.txt
The training takes approximately 3 minutes to run and generates output similar to:
Epoch 1 test end 15:50:56, Accuracy=100.00 Max Accuracy: 100.00%
Clean up your TPU VM resources
Disconnect from the TPU VM, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.

Delete your TPU VM:

$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
  --zone=${ZONE}
Set up and train ML workloads on JAX
To train ResNet on a TPU v4, create a TPU Pod slice. To do this, follow the instructions in TPU setup, specify a Pod type as the accelerator type (for example, v4-32), and specify the Pod runtime version (tpu-vm-v4-base).
Basic JAX setup
Install JAX and jaxlib on a Cloud TPU VM:

(vm)$ sudo pip uninstall jax jaxlib libtpu-nightly libtpu -y
(vm)$ pip3 install -U pip
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
(vm)$ python3
Python 3.6.9 (default, Jul 17 2020, 12:50:27)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.device_count()
4
>>> jax.numpy.add(1, 1)
DeviceArray(2, dtype=int32)
At this point, you're ready to run any JAX code you please! The flax examples are a great place to start running standard ML models in JAX. For instance, to train a basic MNIST convolutional model:
# Run flax mnist example (optional)
(vm)$ pip install --user tensorflow-datasets==3.1.0 ml_collections clu
(vm)$ git clone https://github.com/google/flax.git
(vm)$ pip install --user -e flax
(vm)$ cd flax/examples/mnist
(vm)$ mkdir /tmp/mnist
(vm)$ python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.learning_rate=0.05 --config.num_epochs=5
JAX training on TPU Pods
TPU Pod slices give you access to even more networked TPU TensorCores (see the System Architecture documentation for more information on what Pods are). The main difference when running JAX code on Pods is that a Pod includes multiple host machines. In general, you should run your JAX program on each host in the Pod, using jax.pmap to perform cross-Pod computation and communication. See the pmap documentation for more details, especially the "Multi-host platforms" section.
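For example, a common multi-host pattern is to pmap a training step on every host and use a collective such as jax.lax.pmean to average gradients across all devices in the Pod. The sketch below is illustrative only (the model, shapes, learning rate, and the names step and loss_fn are made up, not from an official example); each host runs the same program, pmap maps over that host's local devices, and the collective reaches across the whole slice.

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Toy least-squares loss, purely for illustration.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

def step(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Average gradients over every device in the Pod slice.
    grads = jax.lax.pmean(grads, axis_name='batch')
    return w - 0.01 * grads

p_step = jax.pmap(step, axis_name='batch')

n_local = jax.local_device_count()  # devices attached to this host
key = jax.random.PRNGKey(jax.process_index())

# Replicate the weights across local devices; give each device its own data shard.
w = jnp.broadcast_to(jnp.zeros((8, 1)), (n_local, 8, 1))
x = jax.random.normal(key, (n_local, 16, 8))
y = jax.random.normal(key, (n_local, 16, 1))

w = p_step(w, x, y)
if jax.process_index() == 0:
    print('updated weights on device 0:', w[0].ravel())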
Pod training setup
This section shows how to set up a v4-16 Pod slice and run a small program on each Pod host.
If you haven't already done so for your project, create a TPU service account using the following command:
$ gcloud compute tpus tpu-vm service-identity create --zone=us-central2-b
If you haven't already done so for your project, create a subnet for TPUs
$ gcloud compute networks subnets create tpusubnet \
  --network=default \
  --range=10.12.0.0/20 \
  --region=us-central2 \
  --enable-private-ip-google-access
Set up Cloud TPU variables
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export ACCELERATOR_TYPE=v4-16
$ export RUNTIME_VERSION=tpu-vm-v4-base
$ export PROJECT_ID=project-id
Create a TPU VM using gcloud:

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --version ${RUNTIME_VERSION} \
  --subnetwork=tpusubnet
Set up a firewall for SSH
The default network comes preconfigured to allow SSH access to all VMs. If you don't use the default network, or the default network was edited, you may need to explicitly enable SSH access by adding a firewall-rule:
$ gcloud compute firewall-rules create \
  --network=NETWORK allow-ssh \
  --allow=tcp:22
Install JAX into the TPU VM
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} --worker=all \
  --command="pip install --upgrade 'jax[tpu]>0.3.0' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" \
  --project=${PROJECT_ID}
Create a local file named example.py containing the following code to run on each Pod host.
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU chips in the Pod
device_count = jax.device_count()

# The number of TPU chips attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
Copy the example file to all Pod hosts
$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME}: --worker=all --zone=${ZONE}
Launch the example.py program on each host (TPU worker) in the Pod:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} --worker=all --command "python3 example.py"
Running example.py on all the Pod hosts generates the following output:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
global device count: 8
local device count: 4
pmap result: [8. 8. 8. 8.]
Manage TPUs
You can manage TPUs with the gcloud CLI or with curl calls.
Create TPUs
TPU setup shows an example of how to create a TPU v4.
See the gcloud API documentation for details on the gcloud create command.
Get TPU details
You can get the details of a node through TPU API requests.
$ gcloud compute tpus tpu-vm describe ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT_ID}
Using a curl call:
curl -H "Authorization: Bearer $(gcloud auth print-access-token)" https://tpu.googleapis.com/v2/projects/${PROJECT_ID}/locations/${ZONE}/nodes/${TPU_NAME}
The response body contains an instance of Node.
List TPUs
You can get a list of Cloud TPUs through TPU API requests.
Using the gcloud CLI:
$ gcloud compute tpus tpu-vm list \
--zone ${ZONE} \
--project ${PROJECT_ID}
Using a curl call:
curl -H "Authorization: Bearer $(gcloud auth print-access-token)" https://tpu.googleapis.com/v2/projects/${PROJECT_ID}/locations/${ZONE}/nodes/
Delete TPUs
You can delete created Cloud TPUs through TPU API requests.
Using the gcloud CLI:
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT_ID}
Using a curl call:
curl -X DELETE -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/${PROJECT_ID}/locations/${ZONE}/nodes/${TPU_NAME}
Access TPU VMs with SSH
(Optional) Set up a firewall for SSH
The default network comes preconfigured to allow SSH access to all VMs. If you don't use the default network, or the default network was edited, you may need to explicitly enable SSH access by adding a firewall-rule:
$ gcloud compute firewall-rules create \
  --network=NETWORK allow-ssh \
  --allow=tcp:22
SSH into the TPU VMs
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT_ID}
Required fields
- TPU_NAME: Name of the TPU node.
- ZONE: The location of the TPU node. Currently, only us-central2-b is supported.
- PROJECT_ID: The project you created above.
See the gcloud API documentation for a list of optional fields.
Use TPUs with Advanced Configs
- Custom Network Resources
- Private Google Access
- Custom Service Account
- Custom VM SSH methods
Accelerator Configuration
With TPU v4, types and topologies can be specified using one of two gcloud flags when creating a TPU: AcceleratorType or AcceleratorConfig.

AcceleratorType is the current flag, consisting of a TPU type followed by the number of TensorCores. For example, v3-128 specifies a TPU v3 with 128 TensorCores.
The v4 AcceleratorConfig feature offers more configuration options. See the Topology section for more information on AcceleratorConfig topology options.
AcceleratorType and AcceleratorConfig are both supported for all TPU versions (v2, v3, v4).
V4 TPU types
The v4 type must be v4.
V4 Topology
The v4 topology is specified in chips (unlike AcceleratorType, which uses TensorCores), and a v4 topology has three dimensions, for example, 4x4x4. Note that there are two TensorCores per chip.
Large slices can be built from one or more 4x4x4 "cubes" of chips. Refer to the Types and topologies document for details on the possible v4 topologies.
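As a worked example of this arithmetic (illustrative only; the helper name topology_to_counts is ours, not part of any gcloud or TPU API), a 4x4x4 topology is 64 chips, which is 128 TensorCores, the same slice you would request with AcceleratorType v4-128.

# Illustrative arithmetic: relate a v4 topology string to chip and TensorCore counts.
# topology_to_counts is a hypothetical helper, not part of any Google API.
import math

TENSORCORES_PER_CHIP = 2  # TPU v4 has two TensorCores per chip

def topology_to_counts(topology: str):
    dims = [int(d) for d in topology.split('x')]
    chips = math.prod(dims)
    return chips, chips * TENSORCORES_PER_CHIP

for topo in ['2x2x1', '2x2x4', '4x4x4']:
    chips, cores = topology_to_counts(topo)
    # For example: 4x4x4 -> 64 chips -> 128 TensorCores, i.e. v4-128.
    print(f'{topo}: {chips} chips, {cores} TensorCores (v4-{cores})')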
Other Information
Request More TPU quota
The default quota allocation for Cloud TPU v4 is zero for all projects. Request quota by contacting your sales representative, following the instructions in the quota policy.
Troubleshooting
["gcloud auth login" cannot open browser]
When you run:

$ gcloud auth login

the command attempts to open a browser window over SSH and prints a link that leads to an invalid localhost URL. Use the --no-launch-browser flag instead:
$ gcloud auth login --no-launch-browser
Cannot SSH to TPU VM
When running
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE}
Example error message:
Waiting for SSH key to propagate.
ssh: connect to host 34.91.136.59 port 22: Connection timed out
ssh: connect to host 34.91.136.59 port 22: Connection timed out
ssh: connect to host 34.91.136.59 port 22: Connection timed out
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.ssh) Could not SSH into the instance.
It is possible that your SSH key has not propagated to the instance yet.
Try running this command again. If you still cannot connect, verify that
the firewall and instance are set to
accept SSH traffic.
Something might be wrong with the SSH key propagation. Try moving the automatically generated keys to a backup location to force gcloud to recreate them:

mv ~/.ssh/google_compute_engine ~/.ssh/old-google_compute_engine
mv ~/.ssh/google_compute_engine.pub ~/.ssh/old-google_compute_engine.pub
Clean up
You should delete your TPUs when they are no longer needed. Follow the TPU deletion instructions to delete your TPU.
Q/A
JAX jobs on Pods
We're working on tools and recommendations for orchestrating JAX jobs on Pods,
but we'd also like it to be possible for users to bring their own if their lab
already uses a multi-machine job scheduler or cluster manager (e.g., SLURM or
Kubernetes).
Can I use V1Alpha1 and V1 APIs to manage direct-access Cloud TPU VMs?
Get/List is allowed, but mutations are only available in the V2 API version.
Requesting help
Contact Cloud TPU support. If you have an active Google Cloud project, be prepared to provide the following information:
- Your Google Cloud project ID
- Your TPU node name, if one exists
- Other information you want to provide