Trillium (v6e) introduction

In this documentation, the TPU API, and logs, Trillium is referred to as v6e. v6e is Google's 6th generation of TPU.

With 256 chips per Pod, v6e shares many similarities with v5e. This system is optimized to be the highest value product for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving.

v6e system architecture

For information about Cloud TPU configuration, see the v6e documentation.

This document focuses on the setup process for model training using JAX, PyTorch, or TensorFlow frameworks. With each framework, you can provision TPUs using queued resources or Google Kubernetes Engine (GKE). GKE setup can be done using XPK or GKE commands.

Prepare a Google Cloud project

  1. Sign in to your Google Account. If you haven't already, sign up for a new account.
  2. In the Google Cloud console, select or create a Cloud project from the project selector page.
  3. Enable billing for your Google Cloud project. Billing is required for all Google Cloud usage.
  4. Install the gcloud alpha components.
  5. Run the following command to install the latest version of the gcloud components.

    gcloud components update
    
  6. Enable the TPU API through the following gcloud command in Cloud Shell. You can also enable it from the Google Cloud console.

    gcloud services enable tpu.googleapis.com
    
  7. Enable permissions with the TPU service account for Compute Engine API

    Service accounts allow the Cloud TPU service to access other Google Cloud services. A user-managed service account is a recommended Google Cloud practice. Follow these guides to create and grant roles. The following roles are necessary:

    • TPU Admin
    • Storage Admin
    • Logs Writer
    • Monitoring Metric Writer

    a. If you use XPK with GKE, set up the XPK permissions for your user account. See XPK.

  8. Sign in, then set the default project and zone. These commands assume that the PROJECT_ID and ZONE environment variables are set.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Create a service identity for the TPU VM.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Secure capacity

Contact your Cloud TPU sales or account team to request TPU quota and to answer any questions about capacity.

Provision the Cloud TPU environment

v6e TPUs can be provisioned and managed with GKE, with GKE and XPK (a wrapper CLI tool over GKE), or as queued resources.

Prerequisites

  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • v6e has been tested with the following configuration:
    • Python 3.10 or later
    • Nightly software versions:
      • nightly JAX 0.4.32.dev20240912
      • nightly LibTPU 0.1.dev20240912+nightly
    • Stable software versions:
      • JAX and jaxlib v0.4.35
  • Verify that your project has enough TPU quota for:
    • TPU VM quota
    • IP address quota
    • Hyperdisk Balanced quota
  • Verify that your user account has the required project permissions.
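You can check the Python side of the tested configuration before installing anything else. The following is a minimal sketch, assuming it runs on the TPU VM with the interpreter you plan to use; python_ok and installed_versions are hypothetical helpers, not part of any Cloud TPU tool.

```python
import sys
from importlib import metadata

# Minimum Python version from the tested configuration above.
MIN_PYTHON = (3, 10)


def python_ok(version_info=None, minimum=MIN_PYTHON):
    """Return True if the interpreter meets the tested minimum Python version."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= tuple(minimum)


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python OK:", python_ok())
    # Compare these with the nightly or stable versions listed above.
    print(installed_versions())
```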

Environment variables

In Cloud Shell, create the following environment variables:

export NODE_ID=TPU_NODE_ID # TPU name
export TPU_NAME=${NODE_ID} # the commands in this document refer to the TPU as ${TPU_NAME}
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance and to avoid overloading the
# default network.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Command flag descriptions

NODE_ID: The user-assigned ID of the TPU, which is created when the queued resource request is allocated.
PROJECT_ID: Google Cloud project name. Use an existing project or create a new one.
ZONE: See the TPU regions and zones document for the supported zones.
ACCELERATOR_TYPE: See Accelerator Types.
RUNTIME_VERSION: v2-alpha-tpuv6e
SERVICE_ACCOUNT: The email address of your service account, which you can find in the Google Cloud console under IAM > Service Accounts. For example: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com
NUM_SLICES: The number of slices to create (needed for Multislice only).
QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request.
VALID_DURATION: The duration for which the queued resource request is valid.
NETWORK_NAME: The name of a secondary network to use.
NETWORK_FW_NAME: The name of a secondary network firewall to use.
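Because the commands in this document expand these variables directly, a forgotten placeholder usually surfaces later as a confusing gcloud error. The following hypothetical pre-flight check assumes that placeholders look like the examples above: a value equal to the variable name (PROJECT_ID=PROJECT_ID) or prefixed with YOUR_ (YOUR_SERVICE_ACCOUNT). The REQUIRED list is an assumption that you should adjust to the variables you actually use.

```python
import os

# Variables that the gcloud commands in this document expand. This list is an
# assumption; add NUM_SLICES, NETWORK_NAME, and so on if you use them.
REQUIRED = ("PROJECT_ID", "ZONE", "ACCELERATOR_TYPE", "RUNTIME_VERSION",
            "SERVICE_ACCOUNT", "QUEUED_RESOURCE_ID", "VALID_DURATION")


def unresolved(env=None, names=REQUIRED):
    """Return the names that are unset, empty, or still hold a placeholder."""
    if env is None:
        env = os.environ
    missing = []
    for name in names:
        value = env.get(name, "")
        # Placeholders on this page look like NAME=NAME or NAME=YOUR_NAME.
        if value in ("", name, "YOUR_" + name):
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Unresolved variables:", ", ".join(unresolved()) or "none")
```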

Network performance optimizations

For best performance, use a network with an MTU (maximum transmission unit) of 8,896 bytes.

By default, a Virtual Private Cloud (VPC) network provides an MTU of only 1,460 bytes, which results in suboptimal network performance. You can set a VPC network's MTU to any value between 1,300 bytes and 8,896 bytes (inclusive). Common custom MTU sizes are 1,500 bytes (standard Ethernet) or 8,896 bytes (the maximum possible). For more information, see Valid VPC network MTU sizes.
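The MTU rules above amount to a simple range check. The following sketch, with hypothetical names, encodes the valid 1,300 to 8,896 byte range and reports whether a value is the 8,896-byte maximum recommended here for v6e.

```python
MIN_MTU = 1300      # smallest VPC network MTU, in bytes
MAX_MTU = 8896      # largest VPC network MTU, in bytes (recommended for v6e)
DEFAULT_MTU = 1460  # VPC default, which gives suboptimal performance


def is_recommended_mtu(mtu):
    """Validate a VPC MTU; return True only for the recommended 8,896 bytes."""
    if not MIN_MTU <= mtu <= MAX_MTU:
        raise ValueError(f"MTU {mtu} is outside the valid range {MIN_MTU}-{MAX_MTU}")
    return mtu == MAX_MTU


if __name__ == "__main__":
    for candidate in (DEFAULT_MTU, 1500, MAX_MTU):
        print(candidate, is_recommended_mtu(candidate))
```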

For more information about changing the MTU setting for an existing or default network, see Change the MTU setting of a VPC network.

The following example creates a network with an MTU of 8,896 bytes and a firewall rule for that network:

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=${PROJECT_ID}
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \
  --allow tcp,icmp,udp --project=${PROJECT}

Using multi-NIC (Option for Multislice)

The following environment variables are needed for a secondary subnet when you are using a Multislice environment.

export NETWORK_NAME_2=${RESOURCE_NAME}-2
export SUBNET_NAME_2=${RESOURCE_NAME}-2
export FIREWALL_RULE_NAME=${RESOURCE_NAME}-2
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Use the following commands to create the secondary network and subnet, a firewall rule for the subnet, and a Cloud Router with Cloud NAT:

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project="${PROJECT}"
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging
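Before creating the secondary subnet, it can help to confirm how large 10.10.0.0/18 is and that it doesn't collide with the primary network. The following sketch uses Python's ipaddress module and assumes that the primary network was created in auto mode, whose subnets are allocated from 10.128.0.0/9; if your primary network uses custom ranges, substitute them. The describe helper is hypothetical.

```python
import ipaddress

# Secondary subnet range from the commands above.
SECONDARY = ipaddress.ip_network("10.10.0.0/18")
# Assumption: the primary network uses auto mode, whose subnets are allocated
# from 10.128.0.0/9. Replace this if your primary network uses custom ranges.
PRIMARY = ipaddress.ip_network("10.128.0.0/9")
# Google Cloud reserves four addresses in each primary subnet range.
RESERVED_PER_SUBNET = 4


def describe(subnet, primary=PRIMARY):
    """Summarize a subnet's size and whether it collides with the primary range."""
    return {
        "first": str(subnet.network_address),
        "last": str(subnet.broadcast_address),
        "addresses": subnet.num_addresses,
        "usable": subnet.num_addresses - RESERVED_PER_SUBNET,
        "overlaps_primary": subnet.overlaps(primary),
    }


if __name__ == "__main__":
    print(describe(SECONDARY))
```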

Once a multi-network slice has been created, you can validate that both NICs are in use by passing --command "ifconfig" to the XPK workload. Then, look at the printed output of that XPK workload in the Cloud console logs and check that both eth0 and eth1 have mtu=8896.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"
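Checking both interfaces by eye in the logs is error-prone. The following is a minimal sketch of the same check, assuming that the ifconfig output has been copied out of the Cloud console logs without per-line prefixes and that it uses the net-tools 2.x format (eth0: flags=...  mtu 8896). The helper names are hypothetical.

```python
import re

# Interface header lines in the net-tools 2.x ifconfig format, for example:
# "eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 8896"
_HEADER = re.compile(r"^(\S+): flags=\S+\s+mtu (\d+)", re.MULTILINE)


def interface_mtus(ifconfig_output):
    """Return {interface_name: mtu} for every interface header found."""
    return {name: int(mtu) for name, mtu in _HEADER.findall(ifconfig_output)}


def multi_nic_ok(ifconfig_output, interfaces=("eth0", "eth1"), expected=8896):
    """True only if every listed interface is present with the expected MTU."""
    mtus = interface_mtus(ifconfig_output)
    return all(mtus.get(name) == expected for name in interfaces)
```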


Improved TCP settings

For TPUs created using the queued resources interface, you can run the following command to improve network performance by changing the default TCP settings for rto_min and quickack.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all
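The remote command reads every route, skips routes marked linkdown, removes the first occurrence of the string lock (the ${route/lock/} substitution), and appends rto_min 5ms quickack 1 to an ip route change command. The following sketch reproduces only that string handling in Python so that you can preview the commands before running them; it doesn't execute anything, and the function name is hypothetical. Re-joining on whitespace mimics the word splitting that the shell applies to the unquoted ${route/lock/} expansion.

```python
def tcp_tuning_commands(route_table, rto_min="5ms", quickack=1):
    """Build the `ip route change` commands that the remote loop would run."""
    commands = []
    for route in route_table.splitlines():
        # `grep -q linkdown` skips routes whose interface link is down.
        if not route.strip() or "linkdown" in route:
            continue
        # ${route/lock/} removes the first occurrence of "lock"; splitting and
        # re-joining mimics the shell's word splitting of the unquoted result.
        words = route.replace("lock", "", 1).split()
        commands.append(
            "sudo ip route change " + " ".join(words)
            + f" rto_min {rto_min} quickack {quickack}")
    return commands
```

To preview the commands for a VM, pass it the text printed by ip route show on that VM.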

Provisioning with queued resources (Cloud TPU API)

Capacity can be provisioned using the queued-resource create command.

  1. Create a TPU queued resource request.

    The --reserved flag is only needed for reserved resources, not for on-demand resources.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    After the queued resource request is created, its state in the "response" field is either "WAITING_FOR_RESOURCES" or "FAILED". In the "WAITING_FOR_RESOURCES" state, the request has been enqueued and will be provisioned when enough TPU capacity is available. In the "FAILED" state, the output includes the reason for the failure. If a v6e isn't provisioned within the specified duration, the request expires and its state becomes "FAILED". See the Queued Resources public documentation for more information.

    When your queued resource request is in the "ACTIVE" state, you can connect to your TPU VMs using SSH. Use the list or describe commands to query the status of your queued resource.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    When the queued resource is in the "ACTIVE" state, the output is similar to the following:

      state:
       state: ACTIVE
    
  2. Manage your TPU VMs. For the available options, see managing TPU VMs.

  3. Connect to your TPU VMs using SSH

    You can install binaries on each TPU VM in your TPU slice and run code. See the VM Types section to determine how many VMs your slice will have.

    To install the binaries or run code, you can use SSH to connect to a VM using the tpu-vm ssh command.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    To use SSH to connect to a specific VM, use the --worker flag which follows a 0-based index:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    If you have slice shapes greater than 8 chips, you will have multiple VMs in one slice. In this case use the --worker=all and --command parameters in your gcloud alpha compute tpus tpu-vm ssh command to run a command on all VMs simultaneously. For example:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Delete a queued resource

    Delete a queued resource at the end of your session, and remove queued resource requests that are in the "FAILED" state. To delete a queued resource, first delete the slice and then delete the queued resource request:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
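    Alternatively, to delete the queued resource and its slice with a single command, add the --force flag:

    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force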
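Instead of rerunning describe by hand until the request leaves the WAITING_FOR_RESOURCES state, you can poll it. The following is a minimal sketch, assuming that gcloud is installed and that the --format projection value(state.state) selects the nested "state: state:" field shown in the describe output earlier in this section. Only the ACTIVE and FAILED states named on this page are treated as final. The state reader is passed in as a function so that the polling loop can be tried without a TPU; all names are hypothetical.

```python
import subprocess
import time

# States named on this page after which polling should stop.
FINAL_STATES = {"ACTIVE", "FAILED"}


def gcloud_state(queued_resource_id, project, zone):
    """Return the queued resource state via gcloud (requires the gcloud CLI)."""
    result = subprocess.run(
        ["gcloud", "alpha", "compute", "tpus", "queued-resources", "describe",
         queued_resource_id, f"--project={project}", f"--zone={zone}",
         # Assumed projection for the nested "state: state:" field.
         "--format=value(state.state)"],
        check=True, capture_output=True, text=True)
    return result.stdout.strip()


def wait_for_queued_resource(get_state, poll_seconds=60, timeout_seconds=3600,
                             sleep=time.sleep):
    """Poll get_state() until it returns ACTIVE or FAILED, or until timeout."""
    waited = 0
    while True:
        state = get_state()
        if state in FINAL_STATES:
            return state
        if waited >= timeout_seconds:
            raise TimeoutError(f"queued resource still {state} after {waited}s")
        sleep(poll_seconds)
        waited += poll_seconds
```

For example, wait_for_queued_resource(lambda: gcloud_state(qr_id, project, zone)) blocks until the request becomes ACTIVE or FAILED, where qr_id, project, and zone hold your own values.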
    

Using GKE with v6e

If you are using GKE commands with v6e, you can use Kubernetes commands or XPK to provision TPUs and train or serve models. See Plan for TPUs in GKE to learn how to use GKE with TPUs and v6e.

Framework setup

This section describes the general setup process for ML model training using JAX, PyTorch, or TensorFlow frameworks. You can provision TPUs using queued resources or GKE. GKE setup can be done using XPK or Kubernetes commands.

Set up JAX using queued resources

Install JAX on all TPU VMs in your slice or slices simultaneously using gcloud alpha compute tpus tpu-vm ssh. For Multislice, add --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

You can run the following Python code to check how many TPU cores are available in your slice and to test that everything is installed correctly (the outputs shown here were produced with a v6e-16 slice):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

The output is similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.
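The expected values follow from the slice shape. The following hypothetical helper assumes, based on the v6e-16 output above, that multi-VM slices place 4 chips on each VM; that slices of 8 chips or fewer run on a single VM, as described in the SSH instructions; and that with Multislice, jax.device_count() counts the chips in every slice.

```python
def parse_v6e(accelerator_type):
    """Return the chip count encoded in a type such as "v6e-16"."""
    generation, _, chips = accelerator_type.partition("-")
    if generation != "v6e" or not chips.isdigit():
        raise ValueError(f"not a v6e accelerator type: {accelerator_type!r}")
    return int(chips)


def chips_per_vm(chips):
    # Assumption: slices of up to 8 chips run on one VM; larger slices use
    # 4 chips per VM, as in the v6e-16 output above (16 chips, 4 per VM).
    return chips if chips <= 8 else 4


def expected_jax_counts(accelerator_type, num_slices=1):
    """Expected (jax.device_count(), jax.local_device_count()) on each VM."""
    chips = parse_v6e(accelerator_type)
    # Assumption: with Multislice, jax.device_count() spans all slices.
    return chips * num_slices, chips_per_vm(chips)


def vm_count(accelerator_type):
    """Number of VMs (SSH workers) in one slice."""
    chips = parse_v6e(accelerator_type)
    return chips // chips_per_vm(chips)
```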

To install MaxDiffusion on all workers, run the following command:

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   pip install -r requirements.txt && pip install . '

Troubleshooting JAX setups

As a general troubleshooting step, enable verbose logging in your GKE workload manifest by setting the following environment variables, and then provide the logs to GKE support:

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Error messages

no endpoints available for service 'jobset-webhook-service'

This error means the jobset wasn't installed properly. Check whether the Kubernetes Pods of the jobset-controller-manager deployment are running. For more information, see the JobSet troubleshooting documentation.

TPU initialization failed: Failed to connect

Make sure your GKE node version is 1.30.4-gke.1348000 or later (GKE 1.31 is not supported).

Setup for PyTorch

This section describes how to start using PJRT on v6e with PyTorch/XLA. Python 3.10 is the recommended Python version.

Set up PyTorch using GKE with XPK

You can use the following Docker image with XPK. It has the PyTorch dependencies already installed:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

To create an XPK workload, use the following command:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Using --base-docker-image builds a new Docker image that includes the current working directory.
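A common mistake is splitting the workload name with spaces, which turns one --workload value into several arguments. The following sketch assembles the same xpk.py workload create invocation as an argument list, composes the name as USER-xpk-ACCELERATOR_TYPE-NUM_SLICES without spaces, and requires exactly one of --docker-image or --base-docker-image. It uses only flags that appear on this page; the helper names and example values are hypothetical.

```python
import shlex


def workload_name(user, accelerator_type, num_slices):
    """USER-xpk-ACCELERATOR_TYPE-NUM_SLICES, with no spaces."""
    return f"{user}-xpk-{accelerator_type}-{num_slices}"


def xpk_workload_create(cluster, workload, accelerator_type, num_slices, zone,
                        project, command, docker_image=None,
                        base_docker_image=None, debug_logs=True):
    """Build the argv for `python3 xpk.py workload create` as used on this page."""
    if (docker_image is None) == (base_docker_image is None):
        raise ValueError("pass exactly one of docker_image or base_docker_image")
    argv = ["python3", "xpk.py", "workload", "create", "--cluster", cluster]
    if docker_image is not None:
        argv += ["--docker-image", docker_image]
    else:
        argv += ["--base-docker-image", base_docker_image]
    argv += ["--workload", workload,
             f"--tpu-type={accelerator_type}",
             f"--num-slices={num_slices}",
             "--on-demand",
             "--zone", zone,
             "--project", project]
    if debug_logs:
        argv.append("--enable-debug-logs")
    # Keep the command as one argument; shlex.join quotes it when printed.
    argv += ["--command", command]
    return argv


if __name__ == "__main__":
    argv = xpk_workload_create(
        "my-cluster", workload_name("alice", "v6e-16", 1), "v6e-16", 1,
        "us-east5-b", "my-project",  # hypothetical values
        'python3 -c "import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"',
        docker_image="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028")
    print(shlex.join(argv))
```

Printing with shlex.join gives a line that you can paste into a shell; alternatively, pass argv to subprocess.run.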

Set up PyTorch using queued resources

Follow these steps to install PyTorch using queued resources and run a small script on v6e.

Install dependencies using SSH to access the VMs.

For Multislice, add --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base
    pip3 install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Improve performance of models with sizable, frequent allocations

For models that have sizable, frequent allocations, we've observed that tcmalloc improves performance significantly compared to the default malloc implementation, so tcmalloc is the default malloc on TPU VMs. However, depending on your workload, tcmalloc may cause a slowdown (for example, with DLRM, which has very large allocations for its embedding tables). In that case, unset the following variable to use the default malloc instead:

unset LD_PRELOAD
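If you only want the default malloc for one process rather than for the whole shell, you can remove LD_PRELOAD from that process's environment. The following is a minimal sketch with hypothetical helper names; it assumes, as described above, that tcmalloc is enabled through LD_PRELOAD.

```python
import os
import subprocess


def env_without_ld_preload(base=None):
    """Copy of an environment with LD_PRELOAD removed; the original is untouched."""
    env = dict(os.environ if base is None else base)
    env.pop("LD_PRELOAD", None)
    return env


def run_with_default_malloc(argv):
    """Run argv in a child process that does not inherit LD_PRELOAD."""
    return subprocess.run(argv, env=env_without_ld_preload(), check=True)
```

For example, run_with_default_malloc(["python3", "train.py"]) runs a training script (train.py is a placeholder) without tcmalloc while leaving the current shell unchanged.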

Run a Python script to perform a calculation on the v6e VM:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

This generates output similar to the following:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Setup for TensorFlow

For v6e Public Preview, only the tf-nightly runtime version is supported.

To configure tpu-runtime to use the v6e-compatible TensorFlow version, run the following commands:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Use SSH to access worker-0:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Install TensorFlow on worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Export the TPU_NAME environment variable:

export TPU_NAME=v6e-16

You can run the following Python script to check how many TPU cores are available in your slice and to test that everything is installed correctly (the outputs shown were produced with a v6e-16 slice):

import tensorflow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
def add_fn(x, y):
  z = x + y
  return z

cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)

x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x, y))
print(z)

The output is similar to the following:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e with SkyPilot

You can use TPU v6e with SkyPilot. Use the following steps to add v6e-related location/pricing information to SkyPilot.

  1. Add the following to the end of ~/.sky/catalogs/v5/gcp/vms.csv:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Specify the following resources in a YAML file:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Launch a cluster with TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Connect to the TPU v6e using SSH: ssh tpu_v6
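The 24 catalog rows in step 1 follow a regular pattern: each slice size from tpu-v6e-1 to tpu-v6e-256 is listed once for each of three zones. The following hypothetical generator reproduces exactly those rows. This page doesn't describe the meaning of the individual columns, so the script copies them verbatim.

```python
# Slice sizes and (region, zone) pairs exactly as listed in step 1.
SIZES = (1, 4, 8, 16, 32, 64, 128, 256)
LOCATIONS = (
    ("us-south1", "us-south1-a"),
    ("europe-west4", "europe-west4-a"),
    ("us-east5", "us-east5-b"),
)


def v6e_catalog_rows(sizes=SIZES, locations=LOCATIONS):
    """Return the vms.csv rows, one per slice size and zone."""
    rows = []
    for size in sizes:
        accelerator = f"tpu-v6e-{size}"
        for region, zone in locations:
            # Column layout copied verbatim from the rows in step 1.
            rows.append(f",,,{accelerator},1,{accelerator},{region},{zone},0,0")
    return rows


if __name__ == "__main__":
    print("\n".join(v6e_catalog_rows()))
```

Saved as, for example, v6e_rows.py, its output can be appended with python3 v6e_rows.py >> ~/.sky/catalogs/v5/gcp/vms.csv.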

Inference tutorials

The following sections provide tutorials for serving MaxText and PyTorch models using JetStream, as well as serving MaxDiffusion models on TPU v6e.

MaxText on JetStream

This tutorial shows how to use JetStream to serve MaxText (JAX) models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you will run the inference benchmark for the Llama2-7B model.

Before you begin

  1. Create a TPU v6e with 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connect to the TPU using SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Run the tutorial

To set up JetStream and MaxText, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository.

Clean up

Delete the TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM on PyTorch TPU

The following tutorial shows how to get started with vLLM on a TPU VM. A GKE user guide with best practices for deploying vLLM on Trillium in production is planned.

Before you begin

  1. Create a TPU v6e with 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

  2. Connect to the TPU using SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

Set up vLLM on TPU

  1. Clone the vLLM repository and navigate to the vLLM directory:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Clean up the existing torch and torch-xla packages:

    pip uninstall torch torch-xla -y
    
  3. Install PyTorch and PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Install JAX and Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Install other build dependencies:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Get access to the model

You must sign the consent agreement to use the Llama 3 family of models in the Hugging Face repository.

Generate a new Hugging Face token if you don't already have one:

  1. Click Your Profile > Settings > Access Tokens.
  2. Select New Token.
  3. Specify a Name of your choice and a Role of at least Read.
  4. Select Generate a token.
  5. Copy the generated token to your clipboard, set it as an environment variable and authenticate with the huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Download Benchmarking Data

  1. Create a ~/data directory and download the ShareGPT dataset from Hugging Face.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

Launch the vLLM server

The following command downloads the model weights from the Hugging Face Model Hub to the TPU VM's /tmp directory, pre-compiles a range of input shapes, and writes the compiled model to ~/.cache/vllm/xla_cache.

For more details, see the vLLM documentation.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &
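Before running the full benchmark, you can send a single request to confirm that the server is up. vllm serve exposes an OpenAI-compatible HTTP API. The following sketch assumes that the server is listening on its default port, 8000, on the same VM, and uses the /v1/completions endpoint. The helper names and sampling parameters are illustrative.

```python
import json
import urllib.request


def completion_request(prompt, model="meta-llama/Meta-Llama-3.1-8B",
                       base_url="http://localhost:8000", max_tokens=64):
    """Build a POST request for the OpenAI-compatible /v1/completions endpoint."""
    body = {"model": model, "prompt": prompt,
            "max_tokens": max_tokens, "temperature": 0.0}
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST")


def complete(prompt, timeout=300, **kwargs):
    """Send one prompt and return the generated text (server must be running)."""
    with urllib.request.urlopen(completion_request(prompt, **kwargs),
                                timeout=timeout) as response:
        return json.load(response)["choices"][0]["text"]
```

After serve.log shows that the server has started, complete("The capital of France is") returns the generated continuation. The model name must match the one passed to vllm serve.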

Run vLLM Benchmarks

Run the vLLM benchmarking script:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Clean up

Delete the TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch on JetStream

This tutorial shows how to use JetStream to serve PyTorch models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you will run the inference benchmark for the Llama2-7B model.

Before you begin

  1. Create a TPU v6e with 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connect to the TPU using SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Run the tutorial

To set up JetStream-PyTorch, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository.

Clean up

Delete the TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

MaxDiffusion inference

This tutorial shows how to serve MaxDiffusion models on TPU v6e. In this tutorial, you will generate images using the Stable Diffusion XL model.

Before you begin

  1. Create a TPU v6e with 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connect to the TPU using SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Create a Conda environment

  1. Create a directory for Miniconda:

    mkdir -p ~/miniconda3
  2. Download the Miniconda installer script:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Install Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Remove the Miniconda installer script:

    rm -rf ~/miniconda3/miniconda.sh
  5. Add Miniconda to your PATH variable:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Reload ~/.bashrc to apply the changes to the PATH variable:

    source ~/.bashrc
  7. Create a new Conda environment:

    conda create -n tpu python=3.10
  8. Activate the Conda environment:

    source activate tpu

Set up MaxDiffusion

  1. Clone the MaxDiffusion repository and navigate to the MaxDiffusion directory:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Switch to the mlperf4.1 branch:

    git checkout mlperf4.1
  3. Install MaxDiffusion:

    pip install -e .
  4. Install dependencies:

    pip install -r requirements.txt
  5. Install JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Generate images

  1. Set environment variables to configure the TPU runtime:

    export LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Generate images using the prompt and configurations defined in src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
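LIBTPU_INIT_ARGS is a single space-separated string of --name=value flags, and it only takes effect if it is in the environment of the process that initializes the TPU. The following is a minimal sketch that builds the string from the three flags in step 1 and sets it from Python. It assumes that you set the variable before JAX initializes the TPU backend; the helper names are hypothetical, and this page doesn't describe the effect of each flag.

```python
import os

# Flags copied from step 1; their effects are not described on this page.
MAXDIFFUSION_FLAGS = {
    "xla_tpu_rwb_fusion": False,
    "xla_tpu_dot_dot_fusion_duplicated": True,
    "xla_tpu_scoped_vmem_limit_kib": 65536,
}


def libtpu_init_args(flags):
    """Render {name: value} as "--name=value ..." with lowercase booleans."""
    parts = []
    for name, value in flags.items():
        if isinstance(value, bool):  # check bool before int: True is an int
            value = "true" if value else "false"
        parts.append(f"--{name}={value}")
    return " ".join(parts)


def set_libtpu_init_args(flags, env=None):
    """Store the flags in LIBTPU_INIT_ARGS; call before JAX uses the TPU."""
    if env is None:
        env = os.environ
    env["LIBTPU_INIT_ARGS"] = libtpu_init_args(flags)
    return env["LIBTPU_INIT_ARGS"]
```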

Clean up

Delete the TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Training tutorials

The following sections provide tutorials for training MaxText, MaxDiffusion, and PyTorch models on TPU v6e.

MaxText and MaxDiffusion

The following sections cover the training lifecycle of the MaxText and MaxDiffusion models.

In general, the high-level steps are:

  1. Build the workload base image.
  2. Run your workload using XPK.
    1. Build the training command for the workload.
    2. Deploy the workload.
  3. Follow the workload and view metrics.
  4. Delete the XPK workload if it isn't needed.
  5. Delete the XPK cluster when it's no longer needed.

Build base image

Install MaxText or MaxDiffusion and build the Docker image:

  1. Clone the repository you want to use and change to the directory for the repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configure Docker to use the Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Build the Docker image using the following command or using JAX Stable Stack. For more information about JAX Stable Stack, see Build a Docker image with JAX Stable Stack.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. If you're launching the workload from a machine that doesn't have the image built locally, upload the image:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Build a Docker image with JAX Stable Stack

You can build the MaxText and MaxDiffusion Docker images using the JAX Stable Stack base image.

JAX Stable Stack provides a consistent environment for MaxText and MaxDiffusion by bundling JAX with core packages like orbax, flax, and optax, along with libtpu.so and other essential tools. These libraries are tested together for compatibility, which gives you a stable foundation for building and running MaxText and MaxDiffusion and eliminates conflicts caused by incompatible package versions.

JAX Stable Stack includes a fully released and qualified libtpu.so, the core library that drives TPU program compilation, execution, and ICI network configuration. The libtpu release replaces the nightly build previously used by JAX and ensures consistent functionality of XLA computations on TPU with PJRT-level qualification tests in HLO/StableHLO IRs.

To build the MaxText and MaxDiffusion Docker image with JAX Stable Stack, when you run the docker_build_dependency_image.sh script, set the MODE variable to stable_stack and set the BASEIMAGE variable to the base image you want to use.

The following example specifies us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 as the base image:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
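If you script your image builds, you can compose the BASEIMAGE value from a JAX version and a revision number. This is a sketch under an assumption: it supposes every tag follows the same jax<VERSION>-rev<REVISION> pattern as the jax0.4.35-rev1 example above, so check the Artifact Registry list for the tags that actually exist. The helper names are hypothetical.

```python
import shlex

# Repository path from the example above. The tag pattern
# "jax<VERSION>-rev<REVISION>" is an assumption based on that one example.
STABLE_STACK_REPO = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu"


def stable_stack_image(jax_version: str, revision: int = 1) -> str:
    """Return a JAX Stable Stack base image URI (hypothetical helper)."""
    return f"{STABLE_STACK_REPO}:jax{jax_version}-rev{revision}"


def build_command(jax_version: str, revision: int = 1) -> str:
    """Return the docker_build_dependency_image.sh invocation as a shell string."""
    args = [
        "bash", "docker_build_dependency_image.sh",
        "MODE=stable_stack",
        f"BASEIMAGE={stable_stack_image(jax_version, revision)}",
    ]
    return shlex.join(args)  # Quotes any argument that needs it.


if __name__ == "__main__":
    print(build_command("0.4.35"))
```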

For a list of available JAX Stable Stack base images, see JAX Stable Stack images in Artifact Registry.

Run your workload using XPK

  1. Set the following environment variables if you're not using the default values set by MaxText or MaxDiffusion:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Build your model script to be copied as a training command in the next step. Don't execute the model script yet.

    MaxText

    MaxText is a high-performance, highly scalable, open-source LLM codebase written in pure Python and JAX that targets Google Cloud TPUs and GPUs for training and inference.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}

    Optionally, add attention=dot_product to the command to use dot-product attention.
    

    Gemma2

    Gemma is a family of open-weights large language models (LLMs) developed by Google DeepMind, based on Gemini research and technology.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama is a family of open-weights large language models (LLMs) developed by Meta.

    # Suggested: PER_DEVICE_BATCH_SIZE=4
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python and JAX that run on XLA devices including Cloud TPUs and GPUs. Stable Diffusion is a latent text-to-image model that generates photo-realistic images from any text input.

    You need to check out a specific commit to run MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && \
    cd maxdiffusion && \
    git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && \
    pip install -r requirements.txt && \
    pip install .
    

    Training script:

        cd maxdiffusion && \
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} \
            write_metrics=True

    Optional: if you train across two slices, add dcn_data_parallelism=2 to the command.
        
  3. Run the model using the script you created in the previous step. You must either specify the --base-docker-image flag to use the MaxText base image or specify the --docker-image flag and the image you want to use.

    Optional: You can enable debug logging by including the --enable-debug-logs flag. For more information, see Debug JAX on MaxText.

    Optional: You can create a Vertex AI Experiment to upload data to Vertex AI TensorBoard by including the --use-vertex-tensorboard flag. For more information, see Monitor JAX on MaxText using Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command "YOUR_MODEL_SCRIPT"

    Replace the following variables:

    • CLUSTER_NAME: The name of your XPK cluster.
    • ACCELERATOR_TYPE: The version and size of your TPU. For example, v6e-256.
    • NUM_SLICES: The number of TPU slices.
    • YOUR_MODEL_SCRIPT: The model script to execute as a training command.

    The output includes a link to follow your workload, similar to the following:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Open the link and click the Logs tab to track your workload in real time.
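Because --command takes the whole model script as a single argument, it is easy to break the quoting when pasting a multi-line script. The following Python sketch (the helper name and all example values are hypothetical) joins the script's lines into one string and uses shlex.join to build a correctly quoted xpk.py workload create invocation with the flags shown above, using --docker-image. Because shlex.join single-quotes the script, shell variables such as ${NUM_STEPS} inside it are not expanded, so substitute concrete values first.

```python
import shlex


def xpk_workload_command(cluster, accelerator_type, num_slices, zone, project,
                         user, docker_image, model_script):
    """Return a shell-safe `python3 xpk.py workload create` command line.

    Blank lines and lines starting with '#' in model_script are skipped.
    """
    # Collapse the multi-line script into one command string; trailing
    # backslashes are line continuations, so they are dropped.
    script = " ".join(
        line.strip().rstrip("\\").strip()
        for line in model_script.strip().splitlines()
        if line.strip() and not line.strip().startswith("#")
    )
    args = [
        "python3", "xpk.py", "workload", "create",
        "--cluster", cluster,
        "--docker-image", docker_image,
        "--workload", f"{user}-xpk-{accelerator_type}-{num_slices}",
        f"--tpu-type={accelerator_type}",
        f"--num-slices={num_slices}",
        "--on-demand",
        "--zone", zone,
        "--project", project,
        "--command", script,
    ]
    return shlex.join(args)
```

For example, passing accelerator_type="v6e-256", num_slices=1, and user="alice" produces the workload name alice-xpk-v6e-256-1, matching the ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES pattern above.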

Debug JAX on MaxText

Use supplemental XPK commands, such as xpk workload list and xpk inspector, to diagnose why the cluster or workload isn't running.

Monitor JAX on MaxText using Vertex AI

View scalar and profile data through Vertex AI's managed TensorBoard.

  1. Increase the resource management (CRUD) requests quota for the zone you're using from 600 to 5000. This might not be necessary for small workloads that use fewer than 16 VMs.
  2. Install dependencies such as cloud-accelerator-diagnostics for Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Create your XPK cluster using the --create-vertex-tensorboard flag, as documented in Create Vertex AI TensorBoard. You can also run this command on existing clusters.

  4. Create your Vertex AI experiment when running your XPK workload using the --use-vertex-tensorboard flag and the optional --experiment-name flag. For the full list of steps, see Create Vertex AI Experiment to upload data to Vertex AI TensorBoard.

The logs include a link to a Vertex AI TensorBoard, similar to the following:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

You can also find the Vertex AI TensorBoard link in the Google Cloud console: go to Vertex AI Experiments and select the appropriate region from the drop-down.

The TensorBoard directory is also written to the Cloud Storage bucket that you specified with ${BASE_OUTPUT_DIR}.

Delete XPK workloads

Use the xpk workload delete command to delete one or more workloads based on the job prefix or job status. This command might be useful if you sent XPK workloads that no longer need to be run, or if you have jobs that are stuck in the queue.

Delete XPK cluster

Use the xpk cluster delete command to delete a cluster:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama and PyTorch

This tutorial describes how to train Llama models using PyTorch/XLA on TPU v6e with the WikiText dataset. PyTorch TPU model recipes are also available as Docker images.

Installation

Install the pytorch-tpu/transformers fork of Hugging Face Transformers and dependencies in a virtual environment:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Set up model configs

The training command in the next section, Build your model script, uses two JSON config files to define the model parameters and the FSDP (Fully Sharded Data Parallel) configuration. FSDP shards the model weights across devices so that a larger batch size fits in memory during training. When training smaller models, data parallelism alone, with the weights replicated on each device, might be enough. For more details on how to shard tensors across devices in PyTorch/XLA, see the PyTorch/XLA SPMD User Guide.

  1. Create the model parameter config file. The following is the model parameter config for Llama3-8B. For other models, find the config on Hugging Face. For example, see the Llama2-7B config.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Create the FSDP config file:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Refer to FSDPv2 for more details about FSDP.

  3. Upload the config files to your TPU VMs using the following command:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT_ID \
            --zone $ZONE

    You can also create the config files in your current working directory and use the --base-docker-image flag in XPK.
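To see why sharding the weights helps, you can estimate the model size from the config file. The following sketch counts the weights of a Llama-style decoder from the Llama3-8B config above, assuming no bias terms (attention_bias is false), RMSNorm, a gated three-matrix MLP, and untied input and output embeddings (tie_word_embeddings is false). The helper names are hypothetical, and the memory figure covers bf16 weights only, not gradients, optimizer state, or activations.

```python
import json

# Values copied from the Llama3-8B model parameter config file above.
CONFIG = json.loads("""
{
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "vocab_size": 128256,
    "tie_word_embeddings": false
}
""")


def llama_param_count(cfg: dict) -> int:
    """Count weights in a Llama-style decoder (no biases, RMSNorm, gated MLP)."""
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]
    kv_dim = cfg["num_key_value_heads"] * head_dim  # grouped-query attention
    attn = 2 * h * h + 2 * h * kv_dim               # q/o and k/v projections
    mlp = 3 * h * cfg["intermediate_size"]          # gate, up, and down projections
    norms = 2 * h                                   # two RMSNorm weight vectors
    per_layer = attn + mlp + norms
    embed = cfg["vocab_size"] * h
    lm_head = 0 if cfg["tie_word_embeddings"] else cfg["vocab_size"] * h
    final_norm = h
    return cfg["num_hidden_layers"] * per_layer + embed + lm_head + final_norm


def bf16_weight_gib_per_chip(num_params: int, num_chips: int) -> float:
    """Approximate bf16 weight memory per chip if weights are evenly sharded."""
    return num_params * 2 / num_chips / 2**30


if __name__ == "__main__":
    n = llama_param_count(CONFIG)
    print(f"{n:,} parameters")
    print(f"{bf16_weight_gib_per_chip(n, 4):.1f} GiB per chip over 4 chips")
```

For this config the count is 8,030,261,248 parameters, or about 15 GiB of bf16 weights; fully sharded across the 4 chips of a v6e-4, that drops to roughly 3.7 GiB per chip, which leaves more memory for larger batches.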

Build your model script

Build your model script, specifying the model parameter config file using the --config_name flag and the FSDP config file using the --fsdp_config flag. You will run this script on your TPU in the next section, Run the model. Don't execute the model script yet.

    export PJRT_DEVICE=TPU
    export XLA_USE_SPMD=1
    export ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=100000
    # A local VM path or a Cloud Storage path such as gs://my-bucket/profile_path
    export PROFILE_LOGDIR=PROFILE_PATH
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Run the model

Run the model using the script you created in the previous step, Build your model script.

If you're using a single-host TPU VM (such as v6e-4), you can run the training command directly on the TPU VM. If you're using a multi-host TPU VM, use the following command to run the script simultaneously on all the hosts:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID \
    --zone $ZONE \
    --worker=all \
    --command="YOUR_COMMAND"

Troubleshooting PyTorch/XLA

If you set the optional variables for debugging in the previous section, the profile for the model is stored at the location specified by PROFILE_LOGDIR. You can extract the xplane.pb file stored at this location and view the profiles in your browser with TensorBoard by following the TensorBoard instructions. If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging, profiling, and optimizing your models.

DLRM DCN v2 tutorial

This tutorial shows you how to train the DLRM DCN v2 model on TPU v6e.

If you're running on a multi-host TPU, reset tpu-runtime with the appropriate TensorFlow version by running the following two commands. If you're running on a single host, skip them.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH into worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

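Set the TPU name

Variables from your local shell aren't available on the TPU VM, so on worker-0 set TPU_NAME explicitly, replacing YOUR_TPU_NAME with the name of your TPU:

export TPU_NAME=YOUR_TPU_NAME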

Run DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://YOUR_BUCKET/dlrm/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

If you saved the preceding commands in a file named script.sh, make the script executable and run it:

chmod +x script.sh
./script.sh

The following flags are necessary to run recommendation workloads (DLRM DCN). They are shown here as Dockerfile ENV instructions:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true'
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
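The four per-table lists in the params_override for the DLRM DCN v2 command must line up entry by entry, one entry per embedding table. The following Python sketch (the helper name is hypothetical, and the checks are illustrative sanity checks rather than documented TensorFlow Recommenders requirements) verifies that the lists have the same length and that max_unique_ids_per_table never exceeds max_ids_per_table, since a table can't receive more unique IDs than total IDs.

```python
# Per-table settings copied from the params_override above
# (one entry per Criteo categorical feature).
VOCAB_SIZES = [40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63,
               40000000, 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14,
               40000000, 40000000, 40000000, 590152, 12973, 108, 36]
MULTI_HOT_SIZES = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1,
                   12, 100, 27, 10, 3, 1, 1]
MAX_IDS_PER_TABLE = [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328,
                     304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720,
                     112, 320, 256]
MAX_UNIQUE_IDS_PER_TABLE = [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192,
                            32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120,
                            80, 32, 32]


def check_table_config(vocab_sizes, multi_hot_sizes, max_ids, max_unique_ids):
    """Return a list of problems found; an empty list means no issues."""
    problems = []
    lengths = {len(vocab_sizes), len(multi_hot_sizes),
               len(max_ids), len(max_unique_ids)}
    if len(lengths) != 1:
        problems.append(f"per-table lists differ in length: {sorted(lengths)}")
    # Unique IDs are a subset of all IDs, so this bound should always hold.
    for table, (ids, unique) in enumerate(zip(max_ids, max_unique_ids)):
        if unique > ids:
            problems.append(f"table {table}: max_unique_ids {unique} > max_ids {ids}")
    return problems


if __name__ == "__main__":
    issues = check_table_config(VOCAB_SIZES, MULTI_HOT_SIZES,
                                MAX_IDS_PER_TABLE, MAX_UNIQUE_IDS_PER_TABLE)
    print(issues or f"OK: {len(VOCAB_SIZES)} tables")
```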

Benchmarking results

The following section contains benchmarking results for DLRM DCN v2 and MaxDiffusion on v6e.

DLRM DCN v2

The DLRM DCN v2 training script was run at different scales. See the throughputs in the following table.

                           v6e-64     v6e-128    v6e-256
Training steps             7000       7000       7000
Global batch size          131072     262144     524288
Throughput (examples/sec)  2975334    5111808    10066329
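To compare the DLRM DCN v2 runs across slice sizes, you can normalize the measured throughput by chip count. The following sketch is derived arithmetic on the table values, not an additional benchmark result; the helper names are hypothetical, and it assumes the chip count is the number after "v6e-" in the slice name.

```python
# Measured DLRM DCN v2 throughputs from the table above (examples/sec).
THROUGHPUT = {"v6e-64": 2975334, "v6e-128": 5111808, "v6e-256": 10066329}


def chips(slice_name: str) -> int:
    """Number of chips in a v6e slice name such as 'v6e-128'."""
    return int(slice_name.split("-")[1])


def per_chip_throughput(throughput: dict) -> dict:
    """Examples/sec per chip for each slice."""
    return {name: value / chips(name) for name, value in throughput.items()}


def scaling_efficiency(throughput: dict, baseline: str) -> dict:
    """Per-chip throughput relative to a baseline slice (1.0 = linear scaling)."""
    per_chip = per_chip_throughput(throughput)
    return {name: value / per_chip[baseline] for name, value in per_chip.items()}


if __name__ == "__main__":
    for name, eff in scaling_efficiency(THROUGHPUT, "v6e-64").items():
        print(f"{name}: {eff:.1%} of linear scaling relative to v6e-64")
```

With these numbers, v6e-64 delivers about 46,490 examples/sec per chip, and v6e-128 and v6e-256 reach roughly 86% and 85% of linear scaling relative to v6e-64.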

MaxDiffusion

We ran the training script for MaxDiffusion on a v6e-4, a v6e-16, and two v6e-16 slices. See the step times and throughputs in the following table.

                           v6e-4      v6e-16     Two v6e-16
Step time (seconds)        0.069      0.073      0.13
Global batch size          8          32         64
Throughput (examples/sec)  115.9      438.4      492.3
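The MaxDiffusion throughput row follows from the other two rows: throughput is the global batch size divided by the step time in seconds. The following sketch recomputes it from the table values; the helper name is hypothetical.

```python
# Step time (seconds) and global batch size from the MaxDiffusion table above.
RUNS = {
    "v6e-4": (0.069, 8),
    "v6e-16": (0.073, 32),
    "Two v6e-16": (0.13, 64),
}


def throughput(step_time_s: float, global_batch: int) -> float:
    """Examples processed per second for a given step time and batch size."""
    return global_batch / step_time_s


for name, (step_time, batch) in RUNS.items():
    print(f"{name}: {throughput(step_time, batch):.1f} examples/sec")
```

Running it prints 115.9, 438.4, and 492.3 examples/sec, matching the table.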

Collections

v6e introduces a new feature named collections for the benefit of users who run serving workloads. The collections feature only applies to v6e.

Collections lets you indicate to Google Cloud which of your TPU nodes form part of a serving workload. This enables the underlying Google Cloud infrastructure to limit and streamline interruptions that might be applied to training workloads in the normal course of operations.

Use collections from the Cloud TPU API

A single-host collection on the Cloud TPU API is a queued resource on which a special flag (--workload-type=availability-optimized) is set to indicate to the underlying infrastructure that it is meant to be used for serving workloads.

The following command provisions a single-host collection using the Cloud TPU API:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=PROJECT_ID \
   --zone=ZONE \
   --accelerator-type=ACCELERATOR_TYPE \
   --node-count=NODE_COUNT \
   --workload-type=availability-optimized

Monitor and profile

Cloud TPU v6e supports monitoring and profiling using the same methods as previous generations of Cloud TPU. For more information about monitoring, see Monitor TPU VMs.