This document provides a brief introduction to working with Pax on a single-host TPU (v2-8, v3-8, v4-8).
Pax is a framework to configure and run machine learning experiments on top of JAX. Pax focuses on simplifying ML at scale by sharing infrastructure components with existing ML frameworks and utilizing the Praxis modeling library for modularity.
Objectives
- Set up TPU resources for training
- Install Pax on a single-host TPU
- Train a transformer-based SPMD model using Pax
Before you begin
Run the following commands to configure gcloud to use your Cloud TPU project and to install the components needed to train a model running Pax on a single-host TPU.
Install the Google Cloud CLI
The Google Cloud CLI contains tools and libraries for interacting with Google Cloud products and services. If you haven't installed it previously, install it now using the instructions in Installing the Google Cloud CLI.
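If you're on Linux or macOS, one common route is the interactive installer sketched below; treat this as an illustration only, since the linked instructions are the authoritative source and the steps may differ for your platform.
$ curl https://sdk.cloud.google.com | bash
$ exec -l $SHELL
$ gcloud init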
Configure the gcloud command
(Run gcloud auth list to see your available accounts.)
$ gcloud config set account account
$ gcloud config set project project-id
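For example, with a hypothetical account and project ID (substitute your own values):
$ gcloud config set account alice@example.com
$ gcloud config set project my-tpu-project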
Enable the Cloud TPU API
Enable the Cloud TPU API using the following gcloud command in Cloud Shell. (You may also enable it from the Google Cloud console.)
$ gcloud services enable tpu.googleapis.com
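Optionally, you can confirm the API is enabled by listing enabled services and searching for it; this check is illustrative and not a required step.
$ gcloud services list --enabled | grep tpu.googleapis.com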
Run the following command to create a service identity (a service account).
$ gcloud beta services identity create --service tpu.googleapis.com
Create a TPU VM
With Cloud TPU VMs, your model and code run directly on the TPU VM itself. You SSH into the VM, where you can run arbitrary code, install packages, view logs, and debug.
Create your TPU VM by running the following command from Cloud Shell or from a computer terminal where the Google Cloud CLI is installed.
- Set the zone based on availability in your contract; see TPU Regions and Zones if needed.
- Set the accelerator-type variable to v2-8, v3-8, or v4-8.
- Set the version variable to tpu-vm-base for v2 and v3 TPUs, or tpu-vm-v4-base for v4 TPUs.
$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version
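For example, assuming a hypothetical VM name of my-pax-tpu and the us-central2-b zone (one of the zones that offers v4 capacity), the command might look like this:
$ gcloud compute tpus tpu-vm create my-pax-tpu \
--zone us-central2-b \
--accelerator-type v4-8 \
--version tpu-vm-v4-base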
Connect to your Google Cloud TPU VM
SSH into your TPU VM by using the following command:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone
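Continuing the hypothetical my-pax-tpu example from the previous step:
$ gcloud compute tpus tpu-vm ssh my-pax-tpu --zone us-central2-b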
When you are logged into the VM, your shell prompt changes from username@projectname to username@vm-name.
Install Pax on the Google Cloud TPU VM
Install Pax, JAX, and libtpu on your TPU VM using the following commands:
(vm)$ python3 -m pip install -U pip
(vm)$ python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
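Optionally, you can confirm the packages installed cleanly by asking pip for their metadata; this is an extra check, not part of the required steps.
(vm)$ python3 -m pip show paxml jax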
System check
Test that everything is installed correctly by checking that JAX sees the TPU cores:
(vm)$ python3 -c "import jax; print(jax.device_count())"
The number of TPU cores is displayed. This should be 8 if you are using a v2-8 or v3-8, or 4 if you are using a v4-8.
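If you want more detail than a count, jax.devices() lists each device JAX can see; this is an optional extra check.
(vm)$ python3 -c "import jax; print(jax.devices())"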
Running Pax code on a TPU VM
You can now run any Pax code you wish. The lm_cloud examples are a great place to start running models in Pax. For example, the following command trains a 2B-parameter, transformer-based SPMD language model on synthetic data. It runs for 300 steps in approximately 20 minutes.
(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir
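For example, assuming a hypothetical Cloud Storage bucket named gs://my-pax-bucket that your VM's service account can write to (a local path on the VM should also work for a single-host run), the command might look like this:
(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://my-pax-bucket/lm_cloud_2b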
On a v4-8 slice, the output should include losses and step times, such as:
summary tensor at step=step_# loss = loss
summary tensor at step=step_# Steps per second x
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the Compute Engine instance, if you have not already done so:
(vm)$ exit
Delete your Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name --zone zone
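For example, deleting the hypothetical my-pax-tpu VM and then listing the zone's TPUs to verify it is gone:
$ gcloud compute tpus tpu-vm delete my-pax-tpu --zone us-central2-b
$ gcloud compute tpus tpu-vm list --zone us-central2-b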
What's next
For more information about Cloud TPU, see the Cloud TPU documentation.