This document provides a brief introduction to working with JAX and Cloud TPU.
Before you follow this quickstart, you must create a Google Cloud Platform account, install the Cloud SDK, and configure the gcloud command. For more information, see Set up an account and a Cloud TPU project.
Install the Cloud SDK
The Cloud SDK contains tools and libraries for interacting with Google Cloud products and services. For more information, see Installing the Cloud SDK.
Run the following commands to configure gcloud to use your GCP project and install components needed for the TPU VM preview.
$ gcloud config set account your-email-account
$ gcloud config set project project-id
Enable the Cloud TPU API
$ gcloud services enable tpu.googleapis.com
Run the following command to create a service identity.
$ gcloud beta services identity create --service tpu.googleapis.com
Create a Cloud TPU VM with gcloud
With Cloud TPU VMs, your model and code run directly on the TPU host machine. You SSH directly into the TPU host, where you can run arbitrary code, install packages, view logs, and debug code.
Create your TPU VM by running the following command from Cloud Shell or from a terminal on a computer where the Cloud SDK is installed.
$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v3-8 \
  --version v2-alpha
Connect to your Cloud TPU VM
SSH into your TPU VM by using the following command:
$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a
- tpu-name: The name of the TPU VM to which you are connecting.
- --zone: The zone where you created your Cloud TPU.
Install JAX on your Cloud TPU VM
(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Test that everything is installed correctly by checking that JAX sees the Cloud TPU cores and can run basic operations:
Start the Python 3 interpreter:
(vm)$ python3
>>> import jax
Display the number of TPU cores available:
>>> jax.device_count()
The number of TPU cores is displayed. For the v3-8 TPU created above, this should be 8.
Perform a simple calculation:
>>> jax.numpy.add(1, 1)
The result of the addition, 2, is displayed.
Exit the Python interpreter:
>>> exit()
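The interactive checks above can also be collected into a small script that you run with python3 on the TPU VM. This is a minimal sketch, not part of the original quickstart; the file name you save it under (for example, a hypothetical check_tpu.py) is up to you. It only uses jax.devices(), jax.device_count(), and jax.default_backend(), and on a machine without a TPU it simply reports the CPU backend instead.

```python
import jax
import jax.numpy as jnp

# List every device JAX can see. On a TPU VM these should report platform "tpu";
# elsewhere JAX falls back to CPU (or GPU) devices.
devices = jax.devices()
print("backend:", jax.default_backend())
print("device count:", jax.device_count())
for d in devices:
    print(d.id, d.platform)

# Same sanity check as the interactive session: a trivial addition on the default device.
result = jnp.add(1, 1)
print("1 + 1 =", int(result))
```

If the backend prints as cpu on the TPU VM, JAX did not pick up the TPU runtime, which usually points back to the installation step.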
Running JAX code on a TPU VM
You can now run any JAX code you like. The Flax examples are a great place to start with running standard ML models in JAX. For instance, to train a basic MNIST convolutional network:
Install TensorFlow Datasets:
(vm)$ pip install --upgrade clu
(vm)$ git clone https://github.com/google/flax.git
(vm)$ pip install --user -e flax
Run the Flax MNIST training script:
(vm)$ cd flax/examples/mnist
(vm)$ python3 main.py --workdir=/tmp/mnist \
  --config=configs/default.py \
  --config.learning_rate=0.05 \
  --config.num_epochs=5
The script output should look like this:
I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
I0513 21:09:37.321380
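Beyond the Flax examples, your own JAX code can spread work across all the TPU cores that jax.device_count() reports. The following is a minimal sketch using jax.pmap with a cross-device jax.lax.psum; the function name normalize and the axis name "devices" are hypothetical, chosen for illustration. It runs unchanged on a single CPU device, where the leading axis has length 1.

```python
from functools import partial

import jax
import jax.numpy as jnp

# One row of data per local device: pmap maps the function over the leading axis,
# placing each row on a different core.
n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)


@partial(jax.pmap, axis_name="devices")
def normalize(row):
    # Each core sums its own row, then psum adds those partial sums across all cores,
    # so every core divides by the same global total.
    total = jax.lax.psum(jnp.sum(row), axis_name="devices")
    return row / total


out = normalize(x)
print("devices:", n)
print("sum of normalized values:", float(jnp.sum(out)))
```

The collective (psum) is what distinguishes this from running n independent computations: without it, each core would only see its own row.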
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the TPU VM, if you have not already done so:
(vm)$ exit
Delete your Cloud TPU.
$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a
Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.
$ gcloud alpha compute tpus tpu-vm list --zone europe-west4-a
Here are a few important details that are particularly relevant to using TPUs in JAX.
One of the most common causes for slow performance on TPUs is introducing inadvertent padding:
- Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.
- The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
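The padding cost described above is easy to estimate before you run anything. The sketch below assumes, as a simplification, that the tiling applies to the two minor-most dimensions of an array (second-to-last padded to a multiple of 8, last padded to a multiple of 128); the actual layout XLA chooses can differ by dtype and operation. The helpers round_up, padded_shape, and padding_overhead are hypothetical, written only for this illustration.

```python
import math


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple


def padded_shape(shape, sublane=8, lane=128):
    """Hypothetical helper: pad the second-to-last dim to `sublane`, the last to `lane`.

    Assumes the (8, 128) tiling applies to the two minor-most dimensions.
    """
    if len(shape) < 2:
        raise ValueError("padded_shape expects at least 2 dimensions")
    *batch, rows, cols = shape
    return (*batch, round_up(rows, sublane), round_up(cols, lane))


def padding_overhead(shape, sublane=8, lane=128):
    """Fraction of the padded buffer that holds padding rather than real data."""
    real = math.prod(shape)
    padded = math.prod(padded_shape(shape, sublane, lane))
    return 1.0 - real / padded


for shape in [(1024, 1024), (1000, 130), (7, 129)]:
    print(shape, "->", padded_shape(shape), f"{padding_overhead(shape):.1%} padding")
```

Under this assumption, a (1024, 1024) array needs no padding, while (1000, 130) becomes (1000, 256), so roughly half of the buffer (and of the matrix unit's work) is wasted. Choosing feature sizes that are multiples of 128 avoids this.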
By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision argument on relevant jax.numpy function calls (matmul, dot, einsum, etc.). In particular:
- precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
- precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
- precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision
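The precision levels listed above are selected per call. This minimal sketch computes the same float32 product with each level through jnp.dot and compares it against a float64 NumPy reference; the 256 x 256 shapes are arbitrary choices for illustration. On a TPU you should see the error shrink from DEFAULT to HIGHEST; on CPU all three levels typically compute in full float32, so the errors look alike.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Random float32 inputs; shapes are arbitrary multiples of 128 to avoid padding.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)

# The same product at each of the three precision levels.
out_default = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)
out_high = jnp.dot(a, b, precision=jax.lax.Precision.HIGH)
out_highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

# float64 reference computed on the host with NumPy.
reference = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)

for name, out in [("DEFAULT", out_default), ("HIGH", out_high), ("HIGHEST", out_highest)]:
    max_err = float(np.max(np.abs(np.asarray(out, dtype=np.float64) - reference)))
    print(f"{name:8s} max abs error vs float64: {max_err:.2e}")
```

Because the cost difference between levels only matters on the MXU, a common pattern is to keep DEFAULT for bulk training matmuls and request HIGHEST only where numerical accuracy is critical.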
JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16, for example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
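A short sketch of explicit bfloat16 casts: astype and the dtype argument of jnp.array both produce a bfloat16 array that uses 2 bytes per element instead of float32's 4. The sample values are arbitrary; converting back to float32 shows the rounding that bfloat16's shorter mantissa introduces.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 1.0 / 3.0, 1000.1], dtype=jnp.float32)

# Two equivalent ways to cast explicitly to bfloat16.
x_bf16 = x.astype(jnp.bfloat16)
x_bf16_alt = jnp.array(x, dtype=jnp.bfloat16)

# bfloat16 keeps float32's 8-bit exponent (same range) but only 7 explicit mantissa
# bits, so it halves memory at the cost of precision.
print(x_bf16.dtype, x_bf16.dtype.itemsize, "bytes per element")
print(x_bf16.astype(jnp.float32))  # values after rounding to bfloat16
```

Storing large activations or weights in bfloat16 halves their memory footprint and bandwidth, which is often a bigger win on TPU than the arithmetic speedup itself.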
Running JAX in a Colab
When you run JAX code in a Colab notebook, Colab automatically creates a legacy TPU node. TPU nodes have a different architecture. For more information, see System Architecture.