Running JAX code on TPU Pod slices

After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU Pod slice. TPU Pod slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU Pod slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.

Create a TPU Pod slice

You create a TPU Pod slice using the gcloud command. For example, to create a v2-32 Pod slice, use the following command:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha
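To make the accelerator type more concrete, the following minimal sketch splits a name such as v2-32 into its parts and estimates the number of hosts in the slice. The helper slice_shape and its cores_per_host default are hypothetical, not part of gcloud or JAX. The default of 8 is an assumption taken from the local device count reported for the v2-32 slice in the sample output in this document, and may not hold for other TPU generations.

```python
def slice_shape(accelerator_type, cores_per_host=8):
    """Split a name like 'v2-32' into (generation, total_cores, host_count).

    Assumption: every host exposes `cores_per_host` TPU cores. The value 8
    matches the local device count shown for v2-32 in this document.
    """
    generation, cores = accelerator_type.split("-")
    total_cores = int(cores)
    hosts, remainder = divmod(total_cores, cores_per_host)
    if remainder:
        raise ValueError(
            f"{accelerator_type} does not divide evenly into hosts of "
            f"{cores_per_host} cores"
        )
    return generation, total_cores, hosts


print(slice_shape("v2-32"))  # ('v2', 32, 4)
```

Under that assumption, a v2-32 slice has four hosts, which is why every later step must run on more than one machine.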

Install JAX on the Pod slice

After creating the TPU Pod slice, you must install JAX on all hosts in the TPU Pod slice. You can install JAX on all hosts with a single command using the --worker=all option:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install --upgrade 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
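Because the install runs on several hosts at once, a failure on one host is easy to miss. One possible check, sketched here with the Python standard library only, reports whether the jax and jaxlib packages can be found by the interpreter. The helper name is_installed is hypothetical. The script could be run on every host with the same --worker=all option used above.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Report the status of each package this guide relies on.
for name in ("jax", "jaxlib"):
    print(name, "installed" if is_installed(name) else "MISSING")
```

The check only confirms that the packages are present. It does not initialize the TPU runtime, so it returns immediately even when run on a single host.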

Run JAX code on the Pod slice

To run JAX code on a TPU Pod slice, you must run the code on each host in the TPU Pod slice. That means connecting to every host over SSH and executing the same program on each one. The following example runs a simple JAX calculation on all hosts of a TPU Pod slice at once using the gcloud command's --worker=all option.

Prepare code

$ read -r -d '' PYTHON_CMD << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod slice
xs = jax.numpy.ones(local_device_count)
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', device_count)
    print('local device count:', local_device_count)
    print('pmap result:', r)
EOF

Run the code on the Pod slice

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 -c \"$PYTHON_CMD\""

Output (produced with a v2-32 Pod slice):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]
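The result above can be reproduced by hand. The following sketch is a plain-Python simulation, not JAX code: it assumes four hosts with eight devices each (32 global devices divided by 8 local devices), gives every device the value 1.0, and mimics what psum does by handing each device the sum over all devices. The name simulated_psum is hypothetical.

```python
NUM_HOSTS = 4       # assumption: 32 global devices / 8 local devices
LOCAL_DEVICES = 8   # local device count from the output above

# Each host supplies one 1.0 per local device, like jax.numpy.ones(8).
per_host_inputs = [[1.0] * LOCAL_DEVICES for _ in range(NUM_HOSTS)]


def simulated_psum(inputs):
    """Give every device the sum of the values held by all devices."""
    total = sum(sum(host_values) for host_values in inputs)
    return [[total] * len(host_values) for host_values in inputs]


results = simulated_psum(per_host_inputs)

# Host 0 sees only its own eight devices, each holding the global sum.
print(results[0])
```

Each host receives eight copies of 32.0, which matches the pmap result printed by the host whose process index is 0.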

This is one way to run JAX Python code on each host; any other method works as long as the code runs on every host. However you launch it, the jax.device_count() call above blocks until it has been called on every host in the Pod slice, because all hosts must be present to initialize the TPU runtime.
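As one alternative launch method, the gcloud invocation can be assembled from a Python script. The sketch below is an illustration under stated assumptions: the helpers build_ssh_command and run_on_all_workers are hypothetical, the gcloud arguments are the ones used earlier in this document, and the 600-second timeout is an arbitrary value chosen so that a launch that blocks does not block the local script forever.

```python
import shlex
import subprocess


def build_ssh_command(tpu_name, zone, python_code):
    """Build the gcloud argument list that runs python_code on every host."""
    # Quote the program so the remote shell passes it to python3 unchanged.
    remote_command = "python3 -c " + shlex.quote(python_code)
    return [
        "gcloud", "alpha", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        "--zone", zone,
        "--worker=all",
        "--command", remote_command,
    ]


def run_on_all_workers(tpu_name, zone, python_code, timeout_seconds=600):
    """Run the command locally; raises on a non-zero exit or on timeout."""
    command = build_ssh_command(tpu_name, zone, python_code)
    subprocess.run(command, check=True, timeout=timeout_seconds)


# Build and display the command without executing it.
command = build_ssh_command(
    "tpu-name", "europe-west4-a", "import jax; print(jax.device_count())"
)
print(shlex.join(command))
```

Calling run_on_all_workers with the same arguments would execute the command. The timeout stops only the local gcloud process and should not be relied on to clean up programs already started on the hosts.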

Clean up

When you are done, you can release your TPU VM resources using the gcloud command:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a