Run a calculation on a Cloud TPU VM using JAX

This document provides a brief introduction to working with JAX and Cloud TPU.

Before you follow this quickstart, you must create a Google Cloud Platform account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up an account and a Cloud TPU project.

Install the Google Cloud CLI

The Google Cloud CLI contains tools and libraries for interacting with Google Cloud products and services. For more information, see Installing the Google Cloud CLI.

Configure the gcloud command

Run the following commands to configure gcloud to use your Google Cloud account and project.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Enable the Cloud TPU API

  1. Enable the Cloud TPU API by running the following gcloud command in Cloud Shell. (You can also enable it from the Google Cloud console.)

    $ gcloud services enable tpu.googleapis.com
    
  2. Run the following command to create a service identity.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Create a Cloud TPU VM with gcloud

With Cloud TPU VMs, your model and code run directly on the TPU host machine. You SSH directly into the TPU host, where you can run arbitrary code, install packages, view logs, and debug code.

  1. Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.

    $ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central1-a \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base
    

    Required fields

    zone
      The zone where you plan to create your Cloud TPU.
    accelerator-type
      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    version
      The Cloud TPU software version. For all TPU types, use tpu-ubuntu2204-base.

Connect to your Cloud TPU VM

SSH into your TPU VM by using the following command:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a

Required fields

tpu-name
  The name of the TPU VM to which you are connecting.
zone
  The zone where you created your Cloud TPU.

Install JAX on your Cloud TPU VM

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

System check

Verify that JAX can access the TPU and run basic operations.

Start the Python 3 interpreter:

(vm)$ python3
>>> import jax

Display the number of TPU cores available:

>>> jax.device_count()

The number of TPU cores is displayed. If you are using a v4 TPU, this should be 4. If you are using a v2 or v3 TPU, this should be 8.

Perform a simple calculation:

>>> jax.numpy.add(1, 1)

The result of the add operation is displayed:

Array(2, dtype=int32, weak_type=True)

Exit the Python interpreter:

>>> exit()
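
The interactive check above can also be collected into a small script. The following is a minimal sketch, not part of the official quickstart: the function name system_check is hypothetical, and the script only uses the same calls shown in the interpreter session plus jax.default_backend(), which reports which backend JAX selected.

```python
import jax
import jax.numpy as jnp


def system_check():
    """Report the JAX backend and device count, then run a tiny computation.

    system_check is a hypothetical helper name used only for this sketch.
    """
    backend = jax.default_backend()  # "tpu" on a TPU VM; "cpu" or "gpu" elsewhere
    n_devices = jax.device_count()   # total number of devices visible to JAX
    result = jnp.add(1, 1)           # same calculation as in the interpreter session
    return backend, n_devices, result


if __name__ == "__main__":
    backend, n_devices, result = system_check()
    print(f"backend: {backend}")
    print(f"device count: {n_devices}")
    print(f"1 + 1 = {int(result)}")
```

If the reported backend is not tpu, JAX has likely fallen back to CPU, which usually means the TPU build of JAX was not installed or the TPU is in use by another process.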

Running JAX code on a TPU VM

You can now run any JAX code you want. The Flax examples are a good starting point for running standard ML models in JAX. For example, to train a basic MNIST convolutional network:

  1. Install Flax examples dependencies

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Install Flax

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. Run the Flax MNIST training script

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

The script downloads the dataset and starts training. The script output should look like this:

  I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
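
If you want to try something smaller than the Flax example, the following is a minimal sketch of plain JAX code that runs unchanged on a TPU VM. It is not taken from the Flax repository: the linear model, the synthetic data, and the names loss_fn and train_step are all assumptions made for illustration. It uses jax.jit to compile the training step and jax.value_and_grad to compute the loss and its gradients.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a linear model: pred = x @ w + b."""
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y, learning_rate):
    """One gradient-descent step, compiled by XLA for the active backend."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# Synthetic data (illustrative only): targets follow a known linear rule.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 4))
true_w = jnp.array([1.0, -2.0, 0.5, 3.0])
y = x @ true_w + 0.25

params = (jnp.zeros(4), jnp.zeros(()))  # (weights, bias)
first_loss = None
for step in range(200):
    params, loss = train_step(params, x, y, 0.1)
    if first_loss is None:
        first_loss = loss

print(f"first loss: {float(first_loss):.4f}, final loss: {float(loss):.6f}")
```

The first call to train_step includes compilation time; later calls reuse the compiled program, which is why the loop runs quickly after the first step.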

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the TPU VM, if you have not already done so:

    (vm)$ exit
    
  2. Delete your Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central1-a
    
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central1-a
    

Performance notes

Here are a few important details that are particularly relevant to using TPUs in JAX.

Padding

One of the most common causes of slow performance on TPUs is inadvertent padding:

  • Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.
  • The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
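
To get a feel for how much padding a given shape incurs, you can estimate it with a few lines of arithmetic. The following is a rough sketch under a stated assumption: it pads the second-to-last dimension to a multiple of 8 and the last dimension to a multiple of 128. The actual layout is chosen by the compiler and can differ by dtype and operation, so treat the numbers as an approximation. The helper names round_up, padded_shape, and utilization are hypothetical.

```python
import math


def round_up(n, multiple):
    """Round n up to the nearest multiple of `multiple`."""
    return math.ceil(n / multiple) * multiple


def padded_shape(shape, sublane=8, lane=128):
    """Estimate the tiled shape of an array with at least two dimensions.

    ASSUMPTION: the second-to-last dimension is padded to a multiple of 8
    and the last dimension to a multiple of 128.
    """
    *leading, rows, cols = shape
    return (*leading, round_up(rows, sublane), round_up(cols, lane))


def utilization(shape):
    """Fraction of the padded array that holds real data (1.0 means no padding)."""
    return math.prod(shape) / math.prod(padded_shape(shape))


for shape in [(100, 200), (128, 256), (9, 129)]:
    print(shape, "->", padded_shape(shape), f"utilization {utilization(shape):.2f}")
```

Shapes whose trailing dimensions are already multiples of 8 and 128 give a utilization of 1.0 in this estimate, while a shape that barely exceeds a multiple, such as (9, 129), wastes most of the padded array.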

bfloat16 dtype

By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. You can control this with the precision argument on the relevant jax.numpy functions (matmul, dot, einsum, etc.). In particular:

  • precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
  • precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
  • precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision
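
The following sketch shows one way to compare the precision settings listed above. The matrix sizes and random inputs are arbitrary choices for illustration. On a TPU the two results typically differ slightly; on CPU the precision argument generally has no effect, so the difference is usually zero.

```python
import jax
import jax.numpy as jnp

# Two random float32 matrices (sizes chosen arbitrarily for this sketch).
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)

# Fastest setting: mixed bfloat16 precision on TPU.
fast = jnp.dot(a, b, precision=jax.lax.Precision.DEFAULT)

# Slowest setting: additional MXU passes for full float32 precision on TPU.
precise = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

max_abs_diff = jnp.max(jnp.abs(fast - precise))
print(f"backend: {jax.default_backend()}")
print(f"max abs difference: {float(max_abs_diff):.6f}")
```

Both results have dtype float32 regardless of the setting; the precision argument changes how the multiplication is carried out internally, not the output type.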

JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16, for example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
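
As a small illustration of the explicit cast, the sketch below converts a float32 array to bfloat16 and back to show the rounding that bfloat16 introduces. The input values are arbitrary examples chosen for this sketch.

```python
import jax.numpy as jnp

# Example float32 values (arbitrary, for illustration only).
x = jnp.array([1.001, 3.14159265, 1000.5], dtype=jnp.float32)

# Explicit cast to bfloat16, as described above.
x_bf16 = jnp.array(x, dtype=jnp.bfloat16)

# Cast back to float32 to inspect how much precision was lost.
round_trip = x_bf16.astype(jnp.float32)

print("dtype:", x_bf16.dtype)
print("bytes per element:", x_bf16.dtype.itemsize)
print("round trip:", round_trip)
print("abs error:", jnp.abs(round_trip - x))
```

bfloat16 keeps the exponent range of float32 but stores far fewer mantissa bits, so large values stay representable while small differences between nearby values are rounded away.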

Running JAX in a Colab

When you run JAX code in a Colab notebook, Colab automatically creates a legacy TPU node. TPU nodes have a different architecture from TPU VMs. For more information, see System Architecture.

What's next

For more information about Cloud TPU, see: