Train on a single-host TPU using Pax

This document provides a brief introduction to working with Pax on a single-host TPU (v2-8, v3-8, v4-8).

Pax is a framework to configure and run machine learning experiments on top of JAX. Pax focuses on simplifying ML at scale by sharing infrastructure components with existing ML frameworks and utilizing the Praxis modeling library for modularity.


  • Set up TPU resources for training
  • Install Pax on a single-host TPU
  • Train a transformer-based SPMD model using Pax

Before you begin

Run the following commands to configure gcloud to use your Cloud TPU project and install components needed to train a model running Pax on a single-host TPU.

Install the Google Cloud CLI

The Google Cloud CLI contains tools and libraries for interacting with Google Cloud products and services. If you haven't installed it previously, install it now using the instructions in Installing the Google Cloud CLI.

Configure the gcloud command

Set the account and project that gcloud uses. (Run gcloud auth list to see your currently available accounts.)

$ gcloud config set account account

$ gcloud config set project project-id

Enable the Cloud TPU API

Enable the Cloud TPU API using the following gcloud command in Cloud Shell. You can also enable it from the Google Cloud console.

$ gcloud services enable

Run the following command to create a service identity (a service account).

$ gcloud beta services identity create --service

Create a TPU VM

With Cloud TPU VMs, your model and code run directly on the TPU VM. You SSH directly into the TPU VM. You can run arbitrary code, install packages, view logs, and debug code directly on the TPU VM.

Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.

Set the zone based on availability in your contract. See TPU Regions and Zones if needed.

Set the accelerator-type variable to v2-8, v3-8, or v4-8.

Set the version variable to tpu-vm-base for v2 and v3 TPU versions or tpu-vm-v4-base for v4 TPUs.

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version
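The mapping from accelerator type to software version described above can be sketched as a small shell helper. This is only an illustration under the assumptions of this guide (v2-8, v3-8, and v4-8 only). The function names runtime_version_for and print_create_command are hypothetical, and the command is printed rather than run so that you can review it first.

```shell
#!/bin/sh
# Map an accelerator type to the software version named in this guide.
# Only the three single-host types covered here are handled.
runtime_version_for() {
  case "$1" in
    v2-8|v3-8) echo "tpu-vm-base" ;;
    v4-8)      echo "tpu-vm-v4-base" ;;
    *)         echo "unsupported accelerator type: $1" >&2; return 1 ;;
  esac
}

# Print (do not run) the create command for review.
# Arguments: TPU name, zone, accelerator type (all supplied by you).
print_create_command() {
  tpu_name=$1
  zone=$2
  accelerator_type=$3
  version=$(runtime_version_for "$accelerator_type") || return 1
  echo "gcloud compute tpus tpu-vm create $tpu_name --zone $zone --accelerator-type $accelerator_type --version $version"
}
```

For example, print_create_command tpu-name zone v4-8 prints the create command with --version tpu-vm-v4-base filled in. Replace tpu-name and zone with your own values before running the printed command.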

Connect to your Google Cloud TPU VM

SSH into your TPU VM by using the following command:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

When you are logged into the VM, your shell prompt changes from username@projectname to username@vm-name.

Install Pax on the Google Cloud TPU VM

Install Pax, JAX and libtpu on your TPU VM using the following commands:

(vm)$ python3 -m pip install -U pip
(vm)$ python3 -m pip install paxml jax[tpu]

System check

Test that everything is installed correctly by checking that JAX sees the TPU cores:

(vm)$ python3 -c "import jax; print(jax.device_count())"

The number of TPU cores is displayed. This should be 8 if you are using a v2-8 or v3-8, or 4 if you are using a v4-8.
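If you want the check to fail loudly instead of comparing numbers by eye, the expected counts can be encoded in a short helper. This is a minimal sketch that assumes only the three accelerator types covered in this guide. The function names expected_device_count and check_device_count are hypothetical.

```shell
#!/bin/sh
# Expected value of jax.device_count() for each accelerator type,
# using the numbers stated in this guide.
expected_device_count() {
  case "$1" in
    v2-8|v3-8) echo 8 ;;
    v4-8)      echo 4 ;;
    *)         echo "unknown accelerator type: $1" >&2; return 1 ;;
  esac
}

# Compare a reported device count against the expected one.
# Arguments: accelerator type, count reported by JAX.
check_device_count() {
  accelerator_type=$1
  actual=$2
  expected=$(expected_device_count "$accelerator_type") || return 1
  if [ "$actual" -eq "$expected" ]; then
    echo "OK: $actual devices"
  else
    echo "MISMATCH: expected $expected devices, got $actual" >&2
    return 1
  fi
}
```

On the TPU VM, one way to use it is check_device_count v4-8 "$(python3 -c 'import jax; print(jax.device_count())')", which returns a non-zero exit status when the counts differ.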

Running Pax code on a TPU VM

You can now run any Pax code you wish. The lm_cloud examples are a good place to start running models in Pax.

The following command trains a 2B-parameter, transformer-based SPMD language model on synthetic data. It trains for 300 steps in approximately 20 minutes.

(vm)$ python3 .local/lib/python3.8/site-packages/paxml/  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

On a v4-8 slice, the output should include:

Losses and step times

summary tensor at step=step_# loss = loss
summary tensor at step=step_# Steps/sec x
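To follow the loss over time, you can filter the loss lines out of captured training output. The sketch below assumes that each loss line has the shape shown above, with an integer step number and a single loss value, and that you saved the output to a file yourself. The function name extract_losses and the file name train.log are hypothetical.

```shell
#!/bin/sh
# Read training output on stdin and print one "step loss" pair per line.
# Assumes lines shaped like: summary tensor at step=<integer> loss = <value>
# Lines that do not match this shape (for example Steps/sec) are skipped.
extract_losses() {
  sed -n -E 's/.*summary tensor at step=([0-9]+) loss = ([^ ]+).*/\1 \2/p'
}
```

For example, extract_losses < train.log prints the step number and loss for each matching line. If the actual log format differs from the pattern shown in this guide, adjust the expression accordingly.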

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

When you are done with your TPU VM, follow these steps to clean up your resources.

Disconnect from the Compute Engine instance, if you have not already done so:

(vm)$ exit

Delete your Cloud TPU.

$ gcloud compute tpus tpu-vm delete tpu-name --zone zone
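To confirm that the TPU VM is gone, one option is to try describing it and check the exit status. This is a rough sketch: it treats any failure of the describe command as "not found", so an authentication or permission error would be reported the same way. The function names tpu_vm_exists and report_cleanup are hypothetical.

```shell
#!/bin/sh
# Succeeds (exit status 0) only while the TPU VM can still be described.
# Arguments: TPU name, zone.
tpu_vm_exists() {
  gcloud compute tpus tpu-vm describe "$1" --zone "$2" >/dev/null 2>&1
}

# Print a one-line summary of whether the TPU VM is still present.
report_cleanup() {
  if tpu_vm_exists "$1" "$2"; then
    echo "TPU VM $1 still exists in $2"
  else
    echo "TPU VM $1 not found in $2"
  fi
}
```

Run report_cleanup tpu-name zone with the same name and zone you used when creating the TPU VM.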

What's next

For more information about Cloud TPU, see: