Run JAX code on TPU Pod slices
After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU Pod slice. TPU Pod slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU Pod slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.
If you want to use mounted NFS for data storage, you must enable OS Login for all TPU VMs in the Pod slice. For more information, see Using an NFS for data storage.
Set up your environment
In the Cloud Shell, run the following command to make sure you are running the current version of gcloud:
$ gcloud components update
If you need to install gcloud, use the following command:
$ sudo apt install -y google-cloud-sdk
Create some environment variables:
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export RUNTIME_VERSION=tpu-ubuntu2204-base
$ export ACCELERATOR_TYPE=v4-32
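If you prefer to script the setup, the four variables map one-to-one onto the flags of the `gcloud compute tpus tpu-vm create` command used to create the Pod slice. The following is a minimal sketch, assuming Python 3.8 or later; the helper names `read_tpu_config` and `create_command` are hypothetical and not part of gcloud. It only builds and prints the argument list, so nothing is created until you pass it to `subprocess.run` yourself.

```python
import os
import shlex

# Hypothetical defaults that mirror the values exported above; adjust them to your project.
DEFAULTS = {
    "TPU_NAME": "tpu-name",
    "ZONE": "us-central2-b",
    "RUNTIME_VERSION": "tpu-ubuntu2204-base",
    "ACCELERATOR_TYPE": "v4-32",
}


def read_tpu_config(env=None):
    """Return the four settings, falling back to DEFAULTS for any unset variable."""
    env = os.environ if env is None else env
    return {key: env.get(key, default) for key, default in DEFAULTS.items()}


def create_command(config):
    """Build the argument list for `gcloud compute tpus tpu-vm create`."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", config["TPU_NAME"],
        f"--zone={config['ZONE']}",
        f"--accelerator-type={config['ACCELERATOR_TYPE']}",
        f"--version={config['RUNTIME_VERSION']}",
    ]


if __name__ == "__main__":
    cmd = create_command(read_tpu_config())
    # Print instead of executing; pass `cmd` to subprocess.run(cmd, check=True) to run it.
    print(shlex.join(cmd))
```

Passing an argument list rather than a single string avoids local shell quoting issues if any value contains spaces.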
Create a TPU Pod slice
Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project. Run the following commands on your local machine.
Create a TPU Pod slice using the gcloud command. For example, to create a v4-32 Pod slice, use the following command:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION}
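For TPU v4, the number in the accelerator type counts TensorCores, not hosts. The sketch below assumes the standard v4 layout: two TensorCores per chip, four chips per host (TPU VM), and JAX exposing one device per chip. Under those assumptions a v4-32 slice has 16 chips spread over 4 hosts, which is why every gcloud command that targets the slice uses --worker=all. The function name `v4_slice_topology` is hypothetical, and other TPU generations use different layouts.

```python
def v4_slice_topology(accelerator_type):
    """Return (tensorcores, chips, hosts) for a TPU v4 accelerator type such as 'v4-32'.

    Assumes 2 TensorCores per v4 chip and 4 chips per host; not valid for other generations.
    """
    generation, _, size = accelerator_type.partition("-")
    if generation != "v4":
        raise ValueError("this sketch only covers TPU v4 accelerator types")
    tensorcores = int(size)
    chips = tensorcores // 2      # each v4 chip has two TensorCores
    hosts = max(1, chips // 4)    # each v4 host drives four chips
    return tensorcores, chips, hosts


if __name__ == "__main__":
    cores, chips, hosts = v4_slice_topology("v4-32")
    print(f"{cores} TensorCores, {chips} chips, {hosts} hosts")
```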
Install JAX on the Pod slice
After creating the TPU Pod slice, you must install JAX on all hosts in the TPU Pod slice. You can install JAX on all hosts with a single command using the --worker=all option:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--worker=all \
--command="pip install \
--upgrade 'jax[tpu]>0.3.0' \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
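The same `ssh --worker=all` pattern is used both to install JAX and, later, to launch the program on every VM. A minimal Python wrapper for it might look like the following; `ssh_all_workers` and `INSTALL_JAX` are hypothetical names, and the sketch defaults to a dry run that only returns the argument list.

```python
import shlex
import subprocess


def ssh_all_workers(tpu_name, zone, remote_command, dry_run=True):
    """Run `remote_command` on every worker of a TPU Pod slice through gcloud.

    With dry_run=True (the default) the argument list is returned without running it.
    """
    args = [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        "--worker=all",
        f"--command={remote_command}",
    ]
    if not dry_run:
        # check=True raises CalledProcessError if gcloud exits with a non-zero status.
        subprocess.run(args, check=True)
    return args


# The remote shell still interprets this string, so the quotes around jax[tpu] are kept.
INSTALL_JAX = (
    "pip install --upgrade 'jax[tpu]>0.3.0' "
    "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
)

if __name__ == "__main__":
    print(shlex.join(ssh_all_workers("tpu-name", "us-central2-b", INSTALL_JAX)))
```

Because the arguments are passed as a list, only the remote shell on each worker parses `remote_command`; no local shell sits in between.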
Run JAX code on the Pod slice
To run JAX code on a TPU Pod slice, you must run the code on each host in the TPU Pod slice. The jax.device_count() call blocks until it has been called on every host in the Pod slice. The following example illustrates how to run a JAX calculation on a TPU Pod slice.
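The blocking behavior resembles a barrier: no host proceeds until all hosts have arrived. The pure-Python analogy below illustrates this with `threading.Barrier`; it does not use JAX and is not how JAX implements multi-host initialization. Threads stand in for hosts, and the `timeout` parameter exists only so the demo terminates, whereas a real host that is missing simply leaves the others waiting. The function name `simulate_rendezvous` is hypothetical.

```python
import threading


def simulate_rendezvous(num_hosts, participating, timeout=1.0):
    """Model a start-up step that every host must reach before any can continue.

    Returns a dict mapping each participating host index to True if it passed
    the barrier, or False if it gave up waiting for the missing hosts.
    """
    barrier = threading.Barrier(num_hosts)
    results = {}

    def host(index):
        try:
            barrier.wait(timeout=timeout)
            results[index] = True
        except threading.BrokenBarrierError:
            # Raised on timeout, and for every other waiter once the barrier breaks.
            results[index] = False

    threads = [threading.Thread(target=host, args=(i,)) for i in participating]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(simulate_rendezvous(4, range(4)))   # every host calls in
    print(simulate_rendezvous(4, [0, 1, 2]))  # host 3 never starts
```

This is why the program must be launched on every worker, not just one.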
Prepare the code
You need gcloud version >= 344.0.0 (for the scp command). Use gcloud --version to check your gcloud version, and run gcloud components update if needed.
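If you want to check the version requirement in a script, a sketch such as the following parses the output of `gcloud --version`. It assumes the output contains a line of the form "Google Cloud SDK X.Y.Z"; the helper names are hypothetical.

```python
import re


def parse_gcloud_version(version_output):
    """Extract (major, minor, patch) from the 'Google Cloud SDK X.Y.Z' line."""
    match = re.search(r"Google Cloud SDK (\d+)\.(\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError("could not find a 'Google Cloud SDK X.Y.Z' line")
    return tuple(int(part) for part in match.groups())


def supports_tpu_vm_scp(version_output, minimum=(344, 0, 0)):
    """Return True if the installed gcloud is at least `minimum` (344.0.0 for scp)."""
    return parse_gcloud_version(version_output) >= minimum


if __name__ == "__main__":
    import subprocess

    out = subprocess.run(["gcloud", "--version"], capture_output=True, text=True).stdout
    print("scp supported:", supports_tpu_vm_scp(out))
```

Tuples compare element by element, so (344, 1, 0) >= (344, 0, 0) and (400, 0, 0) >= (344, 0, 0) both hold without any string handling.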
Create a file called example.py with the following code:
# The following code snippet will be run on all TPU hosts
import jax
import jax.numpy as jnp

# The total number of TPU devices in the Pod slice
device_count = jax.device_count()

# The number of TPU devices attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod slice
xs = jnp.ones(local_device_count)
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', device_count)
    print('local device count:', local_device_count)
    print('pmap result:', r)
Copy example.py to all TPU worker VMs in the Pod slice:
$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \
--worker=all \
--zone=${ZONE}
If you have not previously used the scp command, you might see an error similar to the following:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
To resolve the error, run the ssh-add command as displayed in the error message and rerun the command.
Run the code on the Pod slice
Launch the example.py program on every VM:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--worker=all \
--command="python3 example.py"
Output (produced with a v4-32 Pod slice):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
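Every entry of the pmap result is 16 because psum adds one value from each of the 16 devices (4 hosts with 4 devices each) and returns that total to every device. The pure-Python model below reproduces only this arithmetic; it does not use JAX, represents hosts and devices as plain lists, and the function name `simulate_pmap_psum` is hypothetical.

```python
def simulate_pmap_psum(num_hosts, local_devices):
    """Model pmap(lambda x: psum(x, 'i')) applied to ones on every device.

    Each host contributes one 1.0 per local device; psum sums over *all*
    devices in the slice, and every device receives the same total.
    Returns one result vector per host.
    """
    inputs = [[1.0] * local_devices for _ in range(num_hosts)]
    total = sum(value for host_values in inputs for value in host_values)
    return [[total] * local_devices for _ in range(num_hosts)]


if __name__ == "__main__":
    results = simulate_pmap_psum(num_hosts=4, local_devices=4)
    print(results[0])  # prints [16.0, 16.0, 16.0, 16.0], the vector host 0 reports
```

On a single-host slice, the same program would report a value equal to that host's local device count instead.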
Clean up
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the Compute Engine instance, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.
Delete your Cloud TPU and Compute Engine resources.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--zone=${ZONE}
Verify that the resources have been deleted. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:
$ gcloud compute tpus tpu-vm list --zone=${ZONE}
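To automate this check, you can search the list output for your TPU name. The sketch below assumes one TPU per output line with the name in the first column, given either as a short name or as a full resource path ending in "/<name>" (for example when you add the global --format="value(name)" flag). The function name `tpu_still_listed` is hypothetical.

```python
def tpu_still_listed(list_output, tpu_name):
    """Return True if `tpu_name` appears in `gcloud compute tpus tpu-vm list` output.

    Assumes the name is the first whitespace-separated field of each line,
    either bare or as a resource path whose last '/' segment is the name.
    """
    for line in list_output.splitlines():
        fields = line.split()
        if not fields:
            continue
        name = fields[0].rsplit("/", 1)[-1]
        if name == tpu_name:
            return True
    return False


if __name__ == "__main__":
    sample = "NAME      ZONE           ACCELERATOR_TYPE\n"
    print("still listed:", tpu_still_listed(sample, "tpu-name"))
```

Because deletion can take several minutes, you might rerun the list command and this check periodically until it returns False.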