This document provides a brief introduction to working with Pax on a single-host TPU (v2-8, v3-8, v4-8).
Pax is a framework to configure and run machine learning experiments on top of JAX. Pax focuses on simplifying ML at scale by sharing infrastructure components with existing ML frameworks and utilizing the Praxis modeling library for modularity.
In this tutorial, you will:
- Set up TPU resources for training
- Install Pax on a single-host TPU
- Train a transformer-based SPMD model using Pax
Before you begin
Run the following commands to configure gcloud to use your Cloud TPU project and install the components needed to train a model running Pax on a single-host TPU.
Install the Google Cloud CLI
The Google Cloud CLI contains tools and libraries for interacting with Google Cloud products and services. If you haven't installed it previously, install it now using the instructions in Installing the Google Cloud CLI.
Set the account and project that gcloud uses (run gcloud auth list to see your currently available accounts):
$ gcloud config set account account
$ gcloud config set project project-id
Enable the Cloud TPU API
$ gcloud services enable tpu.googleapis.com
Run the following command to create a service identity (a service account).
$ gcloud beta services identity create --service tpu.googleapis.com
Create a TPU VM
With Cloud TPU VMs, your model and code run directly on the TPU VM. You SSH directly into the TPU VM. You can run arbitrary code, install packages, view logs, and debug code directly on the TPU VM.
Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.
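In the following command, replace the placeholders as follows:
- tpu-name: a name for your TPU VM.
- zone: a zone chosen based on availability in your contract. For more information, see TPU Regions and Zones.
- accelerator-type: v2-8, v3-8, or v4-8.
- version: the TPU VM software version for v2 and v3 TPUs, or tpu-vm-v4-base for v4 TPUs.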
$ gcloud compute tpus tpu-vm create tpu-name \
    --zone zone \
    --accelerator-type accelerator-type \
    --version version
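If you script this step, the version rule above can be made explicit. The following Python sketch builds the gcloud command as an argument list and runs it only when dry_run is False. The function names (select_version, create_tpu_command, create_tpu) are hypothetical helpers and are not part of gcloud or Pax. The v2/v3 version string is supplied by the caller rather than assumed.

```python
import shlex
import subprocess

# Software version for v4 TPUs, as stated in this guide.
V4_VERSION = "tpu-vm-v4-base"
SUPPORTED_TYPES = {"v2-8", "v3-8", "v4-8"}


def select_version(accelerator_type: str, v2_v3_version: str) -> str:
    """Pick the --version value for a single-host accelerator type.

    v2_v3_version is whatever TPU VM software version you use for v2/v3 TPUs;
    it is passed in rather than hard-coded.
    """
    if accelerator_type not in SUPPORTED_TYPES:
        raise ValueError(f"unsupported accelerator type: {accelerator_type!r}")
    return V4_VERSION if accelerator_type.startswith("v4") else v2_v3_version


def create_tpu_command(tpu_name, zone, accelerator_type, v2_v3_version):
    """Build the argv for `gcloud compute tpus tpu-vm create` without running it."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", tpu_name,
        "--zone", zone,
        "--accelerator-type", accelerator_type,
        "--version", select_version(accelerator_type, v2_v3_version),
    ]


def create_tpu(tpu_name, zone, accelerator_type, v2_v3_version, dry_run=True):
    """Print the command and, unless dry_run is True, run it (requires gcloud on PATH)."""
    cmd = create_tpu_command(tpu_name, zone, accelerator_type, v2_v3_version)
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)
    return cmd
```

For example, create_tpu("my-pax-tpu", "your-zone", "v4-8", v2_v3_version="unused-for-v4") prints the command for a v4-8 VM without creating anything. The names shown are placeholders.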
Connect to your Google Cloud TPU VM
SSH into your TPU VM by using the following command:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone
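When you are logged into the VM, your shell prompt changes to show that you are working on the TPU VM. In this document, commands that run on the VM are shown with a (vm)$ prompt.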
Install Pax on the Google Cloud TPU VM
Install Pax, JAX, and libtpu on your TPU VM using the following commands:
(vm)$ python3 -m pip install -U pip
(vm)$ python3 -m pip install paxml "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
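Optionally, you can confirm which package versions pip installed. This minimal Python sketch reads installed package metadata with the standard library's importlib.metadata. It neither imports JAX nor touches the TPU, so a successful result here does not by itself mean that JAX can see the TPU cores. The helper name installed_versions is hypothetical.

```python
from importlib import metadata

# Distributions pulled in by the install commands (jax[tpu] also installs jaxlib).
PACKAGES = ("paxml", "jax", "jaxlib")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```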
Test that everything is installed correctly by checking that JAX sees the TPU cores:
(vm)$ python3 -c "import jax; print(jax.device_count())"
The number of TPU cores is displayed. This should be 8 if you are using a v2-8 or v3-8, or 4 if you are using a v4-8.
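To make this check scriptable, the sketch below compares jax.device_count() with the counts stated above (8 for v2-8 and v3-8, 4 for v4-8). The helper name check_device_count is hypothetical. JAX is imported only when no count is passed in, so the comparison logic can be exercised without a TPU.

```python
# Expected jax.device_count() per single-host slice, as stated in this guide.
EXPECTED_DEVICES = {"v2-8": 8, "v3-8": 8, "v4-8": 4}


def check_device_count(accelerator_type, device_count=None):
    """Return True if the visible device count matches the accelerator type.

    If device_count is None, JAX is imported and queried on this machine;
    otherwise the given number is checked.
    """
    expected = EXPECTED_DEVICES[accelerator_type]
    if device_count is None:
        import jax  # lazy import: only needed when querying real devices

        device_count = jax.device_count()
    ok = device_count == expected
    status = "OK" if ok else "MISMATCH"
    print(f"{status}: {accelerator_type} expects {expected} device(s), found {device_count}")
    return ok
```

On the TPU VM, calling check_device_count("v4-8") queries JAX directly and returns True when 4 devices are visible.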
Running Pax code on a TPU VM
You can now run any Pax code you wish. The lm_cloud examples are a good place to start running models in Pax. For example, the following command trains a 2B-parameter transformer-based SPMD language model on synthetic data. It trains for 300 steps in approximately 20 minutes.
(vm)$ python3 .local/lib/python3.8/site-packages/paxml/main.py --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir
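The path in the command above assumes a user-level pip install under Python 3.8. On a VM with a different Python version, the site-packages directory differs. The sketch below locates paxml/main.py for the current interpreter with importlib.util.find_spec, without importing paxml. The helper names are hypothetical. The sketch assumes main.py sits at the top level of the installed paxml package, as the path above implies.

```python
import importlib.util
import os
import sys


def paxml_main_path():
    """Return the path to paxml/main.py for this interpreter, or None if not found.

    find_spec on a top-level package locates it without executing its code.
    """
    spec = importlib.util.find_spec("paxml")
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = list(spec.submodule_search_locations)[0]
    path = os.path.join(package_dir, "main.py")
    return path if os.path.isfile(path) else None


def train_command(exp, job_log_dir):
    """Build the argv for a Pax training run using the current interpreter."""
    main_py = paxml_main_path()
    if main_py is None:
        raise FileNotFoundError("paxml/main.py not found for " + sys.executable)
    return [sys.executable, main_py, f"--exp={exp}", f"--job_log_dir={job_log_dir}"]
```

For example, subprocess.run(train_command("tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps", "job_log_dir"), check=True) launches the same experiment. Using sys.executable keeps the launcher on the same interpreter in which paxml was found.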
On a v4-8 slice, the training output should include losses and step times similar to the following, where step_#, loss, and x stand for the actual values:
summary tensor at step=step_# loss = loss
summary tensor at step=step_# Steps/sec x
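To track loss and throughput across a run, you can save the console output (for example, by appending 2>&1 | tee train.log to the training command) and parse the summary lines. The regular expression in this sketch assumes lines shaped like the example output above. The exact Pax log format may differ between versions, so treat the pattern as an assumption to adjust. parse_summaries is a hypothetical helper.

```python
import re

# Assumed line shapes, based on the example output in this guide:
#   ... summary tensor at step=<n> loss = <value>
#   ... summary tensor at step=<n> Steps/sec <value>
# search() tolerates log prefixes; the "=" after the metric name is optional.
PATTERN = re.compile(
    r"summary tensor at step=(\d+)\s+(loss|Steps/sec)\s*=?\s*"
    r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
)


def parse_summaries(lines):
    """Collect {step: {"loss": float, "Steps/sec": float}} from log lines."""
    metrics = {}
    for line in lines:
        match = PATTERN.search(line)
        if match:
            step, name, value = match.groups()
            metrics.setdefault(int(step), {})[name] = float(value)
    return metrics
```

For example, opening train.log and passing the file object to parse_summaries returns a dictionary keyed by step number. Lines that do not match the pattern are skipped.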
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the Compute Engine instance, if you have not already done so:
(vm)$ exit
Delete your Cloud TPU:
$ gcloud compute tpus tpu-vm delete tpu-name --zone zone
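To confirm that the deletion finished, you can list the TPU VMs remaining in the zone. The sketch below calls gcloud compute tpus tpu-vm list with --format=json. It assumes that each entry has a name field ending in /nodes/ followed by the TPU name. tpu_names and tpu_exists are hypothetical helpers.

```python
import json
import subprocess


def tpu_names(list_json):
    """Return the short TPU names found in `tpu-vm list --format=json` output.

    Assumes each entry's "name" looks like projects/<p>/locations/<zone>/nodes/<tpu-name>.
    """
    return {node["name"].rsplit("/", 1)[-1] for node in json.loads(list_json)}


def tpu_exists(tpu_name, zone):
    """Query gcloud and report whether tpu_name still exists in zone."""
    result = subprocess.run(
        ["gcloud", "compute", "tpus", "tpu-vm", "list", "--zone", zone, "--format=json"],
        check=True,
        capture_output=True,
        text=True,
    )
    return tpu_name in tpu_names(result.stdout)
```

After the delete command completes, tpu_exists("tpu-name", "zone") should return False. Replace the placeholders with the values you used when creating the TPU VM.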
For more information about Cloud TPU, see the Cloud TPU documentation.