Run a calculation on a Cloud TPU VM using JAX
This document provides a brief introduction to working with JAX and Cloud TPU.
Before you begin
Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.
Create a Cloud TPU VM using gcloud
Define some environment variables to make commands easier to use.
export PROJECT_ID=your-project
export ACCELERATOR_TYPE=v5p-8
export ZONE=us-east5-a
export RUNTIME_VERSION=v2-alpha-tpuv5
export TPU_NAME=your-tpu-name
Environment variable descriptions
PROJECT_ID
- Your Google Cloud project ID.
ACCELERATOR_TYPE
- The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
ZONE
- The zone where you plan to create your Cloud TPU.
RUNTIME_VERSION
- The Cloud TPU runtime version. For more information, see TPU VM images.
TPU_NAME
- The user-assigned name for your Cloud TPU.
Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.
$ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION
Connect to your Cloud TPU VM
Connect to your TPU VM over SSH by using the following command:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
Install JAX on your Cloud TPU VM
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
System check
Verify that JAX can access the TPU and can run basic operations:
Start the Python 3 interpreter:
(vm)$ python3
>>> import jax
Display the number of TPU cores available:
>>> jax.device_count()
The number of TPU cores is displayed; it depends on the TPU version you are using. For more information, see TPU versions.
Perform a calculation:
>>> jax.numpy.add(1, 1)
The result of the numpy add is displayed:

Array(2, dtype=int32, weak_type=True)
Exit the Python interpreter:
>>> exit()
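If you prefer a non-interactive check, a minimal sketch like the following can be saved to a file on the TPU VM (check_tpu.py is only an example name, not part of the standard setup) and run with python3. It prints the visible devices, the device count, and the result of a small computation:

# check_tpu.py -- example filename
import jax
import jax.numpy as jnp

# List the TPU devices JAX can see and report how many there are.
print("devices:", jax.devices())
print("device count:", jax.device_count())

# Run a small computation to confirm the TPU executes work.
print("1 + 1 =", jnp.add(1, 1))

(vm)$ python3 check_tpu.py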
Running JAX code on a TPU VM
You can now run any JAX code you want. The Flax examples are a great place to start with running standard ML models in JAX, and the rest of this section shows how to train a basic MNIST convolutional network. Before the Flax walkthrough, the short sketch below illustrates running plain JAX code directly on the TPU.
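This is a minimal sketch of a jit-compiled computation; the function, shapes, and values are arbitrary examples rather than part of the Flax walkthrough:

import jax
import jax.numpy as jnp

@jax.jit
def predict(w, x):
    # Toy "model": a matrix multiply followed by a nonlinearity.
    return jnp.tanh(x @ w)

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (512, 512))
x = jax.random.normal(key, (128, 512))

out = predict(w, x)            # compiled for and executed on the TPU
print(out.shape, out.dtype)    # (128, 512) float32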
Install Flax examples dependencies
(vm)$ pip install --upgrade clu
(vm)$ pip install tensorflow
(vm)$ pip install tensorflow_datasets
Install Flax
(vm)$ git clone https://github.com/google/flax.git
(vm)$ pip install --user flax
Run the Flax MNIST training script
(vm)$ cd flax/examples/mnist
(vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
The script downloads the dataset and starts training. The script output should look like this:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Clean up
To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps when you are done with your TPU VM.
Disconnect from the Compute Engine instance, if you have not already done so:
(vm)$ exit
Delete your Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.
$ gcloud compute tpus tpu-vm list \
    --zone=$ZONE
Performance notes
Here are a few important details that are particularly relevant to using TPUs in JAX.
Padding
One of the most common causes for slow performance on TPUs is introducing inadvertent padding:
- Arrays in Cloud TPU memory are tiled. This entails padding one of the dimensions to a multiple of 8 and a different dimension to a multiple of 128.
- The matrix multiplication unit (MXU) performs best with pairs of large matrices that minimize the need for padding; see the sketch after this list.
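As a rough illustration of the tiling rule above, compare an aligned and a misaligned array shape. The padded shape in the comment is inferred from the stated multiples of 8 and 128, not measured:

>>> import jax.numpy as jnp
>>> aligned = jnp.zeros((256, 128))     # dimensions already multiples of 8 and 128: no internal padding
>>> misaligned = jnp.zeros((257, 129))  # padded internally, roughly to (264, 256), wasting memory and compute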
bfloat16 dtype
By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision argument on relevant jax.numpy function calls (matmul, dot, einsum, etc). In particular:
precision=jax.lax.Precision.DEFAULT
- Uses mixed bfloat16 precision (fastest).
precision=jax.lax.Precision.HIGH
- Uses multiple MXU passes to achieve higher precision.
precision=jax.lax.Precision.HIGHEST
- Uses even more MXU passes to achieve full float32 precision.
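For example, the precision argument can be passed directly to a call such as jax.numpy.dot; the shapes below are arbitrary:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((1024, 1024))
>>> y = jnp.ones((1024, 1024))
>>> fast = jnp.dot(x, y)                                          # default: mixed bfloat16 precision, fastest
>>> precise = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)  # full float32 precision, more MXU passes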
JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16, for example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
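For instance, with arbitrary example values:

>>> import jax.numpy as jnp
>>> x = jnp.arange(4.0)                                      # float32 by default
>>> x_bf16 = jnp.asarray(x, dtype=jnp.bfloat16)              # explicit cast to bfloat16
>>> y = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.bfloat16)  # or create a bfloat16 array directly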
What's next
For more information about Cloud TPU, see the Cloud TPU documentation.