Run JAX code on TPU slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.

After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.

If you want to use mounted NFS for data storage, you must enable OS Login for all TPU VMs in the slice. For more information, see Using an NFS for data storage.

Create a Cloud TPU slice

  1. Create some environment variables:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Environment variable descriptions

    PROJECT_ID
    Your Google Cloud project ID.
    ACCELERATOR_TYPE
    The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    ZONE
    The zone where you plan to create your Cloud TPU.
    RUNTIME_VERSION
    The Cloud TPU runtime version.
    TPU_NAME
    The user-assigned name for your Cloud TPU.
  2. Create a TPU slice using the gcloud command. For example, to create a v5p-32 slice use the following command:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 
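The ACCELERATOR_TYPE suffix determines how many hosts the slice has. This matters later, because the program must run on every host. The following Python sketch converts a v4 or v5p accelerator type into chip and host counts. It assumes, for v4 and v5p only, that the suffix counts TensorCores, that each chip has two TensorCores, and that each host (TPU VM) has four chips. Check TPU versions before relying on it, and treat the function name slice_topology as hypothetical.

```python
import re

# Assumptions (v4 and v5p only): the number after the hyphen counts
# TensorCores, each chip has 2 TensorCores, and each host has 4 chips.
TENSORCORES_PER_CHIP = 2
CHIPS_PER_HOST = 4


def slice_topology(accelerator_type: str) -> tuple[int, int]:
    """Return (chips, hosts) for an accelerator type such as 'v5p-32'."""
    match = re.fullmatch(r"(v4|v5p)-(\d+)", accelerator_type)
    if match is None:
        raise ValueError(f"unsupported accelerator type: {accelerator_type!r}")
    tensorcores = int(match.group(2))
    chips = tensorcores // TENSORCORES_PER_CHIP
    # A slice smaller than one full host still runs on a single host.
    hosts = max(1, chips // CHIPS_PER_HOST)
    return chips, hosts


if __name__ == "__main__":
    chips, hosts = slice_topology("v5p-32")
    print(f"v5p-32: {chips} chips across {hosts} hosts")
```

Under these assumptions, v5p-32 gives 16 chips on 4 hosts. This is consistent with the 16 global and 4 local devices in the sample output, because JAX exposes one device per chip on v4 and v5p.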

Install JAX on your slice

After creating the TPU slice, you must install JAX on all hosts in the TPU slice. You can do this with the gcloud compute tpus tpu-vm ssh command and its --worker=all and --command parameters.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
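If you script your setup, you can build the same ssh invocation in Python and pass it to subprocess as an argument list. The list form avoids a second round of shell quoting around the --command value. The following is a minimal sketch. It uses only the flags shown in the command above. The helper name tpu_vm_ssh_argv and the placeholder values are hypothetical.

```python
import shlex


def tpu_vm_ssh_argv(tpu_name, zone, project, command, worker="all"):
    """Build the argv for `gcloud compute tpus tpu-vm ssh` on the given workers."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        f"--project={project}",
        f"--worker={worker}",
        # Passed as one argv element, so the embedded quotes reach the remote shell intact.
        f"--command={command}",
    ]


INSTALL_JAX = (
    'pip install -U "jax[tpu]" '
    "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
)

if __name__ == "__main__":
    argv = tpu_vm_ssh_argv("your-tpu-name", "europe-west4-b", "your-project", INSTALL_JAX)
    # Print the equivalent shell command. To execute it instead:
    #   import subprocess; subprocess.run(argv, check=True)
    print(shlex.join(argv))
```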

Run JAX code on the slice

To run JAX code on a TPU slice, you must run the code on each host in the TPU slice. The jax.device_count() call blocks until it has been called on every host in the slice, so a program started on only some of the hosts stops responding. The following example illustrates how to run a JAX calculation on a TPU slice.

Prepare the code

You need gcloud version 344.0.0 or later to use the scp command. Use gcloud --version to check your gcloud version, and run gcloud components upgrade if needed.
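To check the version requirement in a setup script, you can parse the output of gcloud --version. The following sketch assumes that the output contains a line of the form "Google Cloud SDK X.Y.Z". The function names are hypothetical. Only installed_gcloud_version_output actually runs gcloud, and it requires gcloud on your PATH.

```python
import re
import subprocess

# Minimum version that supports `gcloud compute tpus tpu-vm scp`.
MIN_GCLOUD_VERSION = (344, 0, 0)


def parse_sdk_version(version_output: str) -> tuple[int, ...]:
    """Extract (major, minor, patch) from the 'Google Cloud SDK X.Y.Z' line."""
    match = re.search(r"^Google Cloud SDK (\d+)\.(\d+)\.(\d+)", version_output, re.MULTILINE)
    if match is None:
        raise ValueError("no 'Google Cloud SDK' version line found")
    return tuple(int(part) for part in match.groups())


def gcloud_is_recent_enough(version_output: str) -> bool:
    # Tuples compare element by element, so (470, 0, 0) >= (344, 0, 0).
    return parse_sdk_version(version_output) >= MIN_GCLOUD_VERSION


def installed_gcloud_version_output() -> str:
    """Run `gcloud --version` and return its standard output."""
    result = subprocess.run(["gcloud", "--version"], capture_output=True, text=True, check=True)
    return result.stdout
```

On a machine with gcloud installed, gcloud_is_recent_enough(installed_gcloud_version_output()) returns False when you need to run gcloud components upgrade.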

Create a file called example.py with the following code:

import jax
import jax.numpy as jnp

# The total number of TPU devices in the slice, across all hosts
device_count = jax.device_count()

# The number of TPU devices attached to this host
local_device_count = jax.local_device_count()

# Each host creates one value per local device. pmap maps the leading axis
# over the local devices, and psum sums over all mapped devices across the slice.
xs = jnp.ones(local_device_count)
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', device_count)
    print('local device count:', local_device_count)
    print('pmap result:', r)

Copy example.py to all TPU worker VMs in the slice

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

If you have not previously used the scp command, you might see an error similar to the following:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

To resolve the error, run the ssh-add command as displayed in the error message and rerun the command.

Run the code on the slice

Launch the example.py program on every VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Output (produced with a v4-32 slice):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
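The pmap result follows from the arithmetic in example.py. Each of the 4 hosts contributes one 1.0 per local device. psum adds these values over all 16 devices and returns the sum to every device. The following pure-Python sketch reproduces only that arithmetic. It does not use JAX and is not how JAX implements collectives. The function name simulate_pmap_psum is hypothetical.

```python
def simulate_pmap_psum(num_hosts: int, local_device_count: int) -> list[list[float]]:
    """Mimic example.py: each device holds 1.0, and psum sums over every device."""
    # Each host builds ones(local_device_count): one 1.0 per local device.
    per_host_inputs = [[1.0] * local_device_count for _ in range(num_hosts)]
    # psum over axis 'i' reduces across all devices on all hosts.
    total = sum(x for host in per_host_inputs for x in host)
    # Every device receives the same total, so each host sees a vector of totals.
    return [[total] * local_device_count for _ in range(num_hosts)]


if __name__ == "__main__":
    results = simulate_pmap_psum(num_hosts=4, local_device_count=4)
    print("pmap result on host 0:", results[0])
```

With 4 hosts and 4 local devices, host 0 sees [16.0, 16.0, 16.0, 16.0], which matches the output above.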

Clean up

When you are done with your TPU slice, follow these steps to clean up your resources.

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU and Compute Engine resources.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Verify that the resources have been deleted. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}
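Because deletion can take several minutes, a script might poll the list command until your TPU no longer appears. The following sketch runs the list command with gcloud's --format="value(name)" projection. It assumes that each non-empty output line holds one name, possibly as a full resource path ending in the node name. The function names are hypothetical, and list_tpu_names requires gcloud on your PATH.

```python
import subprocess


def parse_tpu_names(listing: str) -> list[str]:
    # Assumption: one name per non-empty line, possibly a resource path such as
    # projects/PROJECT/locations/ZONE/nodes/NAME; keep only the last segment.
    return [line.strip().rsplit("/", 1)[-1] for line in listing.splitlines() if line.strip()]


def list_tpu_names(zone: str, project: str) -> list[str]:
    """Return the TPU names in a zone by running `gcloud compute tpus tpu-vm list`."""
    result = subprocess.run(
        ["gcloud", "compute", "tpus", "tpu-vm", "list",
         f"--zone={zone}", f"--project={project}", "--format=value(name)"],
        capture_output=True, text=True, check=True,
    )
    return parse_tpu_names(result.stdout)


def tpu_deleted(tpu_name: str, listing: str) -> bool:
    """True if tpu_name does not appear in the list output."""
    return tpu_name not in parse_tpu_names(listing)
```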