After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU Pod slice. TPU Pod slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU Pod slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.
If you want to use mounted NFS for data storage you must set OS Login for all TPU VMs in the Pod slice. For more information, see Using an NFS for data storage.
Create a TPU Pod slice
Before running the commands in this document, make sure you have followed the instructions in Set up an account and a Cloud TPU project. Run the following commands on your local machine.
Create a TPU Pod slice using the gcloud command. For example, to create a v2-32 Pod slice, use the following command:
$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha
Install JAX on the Pod slice
After creating the TPU Pod slice, you must install JAX on all hosts in the TPU Pod slice. You can install JAX on all hosts with a single gcloud command:
$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Run JAX code on the Pod slice
To run JAX code on a TPU Pod slice, you must run the code on each host in the TPU Pod slice. This means you must connect to each host over SSH and execute the JAX code on each one. The following Python code illustrates how to run a simple JAX calculation on a TPU Pod slice using jax.pmap.
Prepare the code
Make sure your gcloud version is >= 344.0.0 (needed for the scp command). Run gcloud --version to check your gcloud version, and run gcloud components upgrade, if needed.

Copy the following code into a file named example.py on the local machine:
cat > example.py << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
EOF
Copy example.py to all worker VMs in the Pod slice:
$ gcloud alpha compute tpus tpu-vm scp example.py tpu-name: --worker=all --zone=europe-west4-a
If this is the first time you are using the scp command, you might see an error similar to the following:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run "ssh-add /.../.ssh/google_compute_engine" to add it, and try again.

To resolve the error, run the ssh-add command as displayed in the error message, then rerun the scp command.
Run the code on the Pod slice
Run the example.py program on every VM:
$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 example.py"
Output (produced with a v2-32 Pod slice):
global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]
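The output values follow directly from the slice topology. As a rough illustration, here is a plain-Python sketch of the arithmetic the psum performs; it runs without a TPU, and the 4-host layout is inferred from the output above (32 global devices / 8 local devices per host):

```python
# Plain-Python sketch of the pmap/psum arithmetic; runs without a TPU.
# Assumed v2-32 topology: 4 hosts x 8 local TPU cores = 32 global devices.
NUM_HOSTS = 4
LOCAL_DEVICES = 8
GLOBAL_DEVICES = NUM_HOSTS * LOCAL_DEVICES  # jax.device_count() reports 32

# Each device contributes one element of jax.numpy.ones(local_device_count).
contributions = [1.0] * GLOBAL_DEVICES

# psum sums over all mapped devices across the Pod, and every device
# receives the same reduced value.
psum_result = sum(contributions)

# Each host therefore holds 8 local copies of the reduced value:
local_result = [psum_result] * LOCAL_DEVICES
print('pmap result:', local_result)  # eight entries, each 32.0
```

This is why every entry of the pmap result equals the global device count, not the local one: the reduction spans the entire Pod slice.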
This is one way to run JAX Python code on each host, but you can use any method you like. However you run it, the jax.device_count() call will hang until it's called on each host in the Pod slice, because all hosts must be present in order to initialize the TPU runtime.
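This blocking behavior can be pictured as a barrier: initialization completes only once every host has called in. Here is a small sketch using Python threads to stand in for hosts (the thread and barrier setup is illustrative only, not how the TPU runtime is actually implemented):

```python
import threading

# Sketch: 4 threads stand in for the 4 hosts of a v2-32 slice. Each "host"
# blocks at the barrier, just as jax.device_count() blocks until every host
# in the Pod slice has joined TPU runtime initialization.
NUM_HOSTS = 4
barrier = threading.Barrier(NUM_HOSTS)
joined = []

def host(idx):
    barrier.wait()      # returns only after all NUM_HOSTS threads arrive
    joined.append(idx)  # past this point, "initialization" has completed

threads = [threading.Thread(target=host, args=(i,)) for i in range(NUM_HOSTS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(joined))  # all four hosts made it past the barrier
```

If one of the threads never calls barrier.wait(), the others block forever, which is the analogue of forgetting to run the script on one of the worker VMs.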
When you are done, you can release your TPU VM resources using the gcloud delete command:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a