Cloud TPU v5e Inference [Public Preview]
Get Started
Overview and Benefits
Cloud TPU v5e is Google Cloud's latest-generation AI accelerator. With a smaller, 256-chip footprint per Pod, v5e Pods are optimized for training, fine-tuning, and serving transformer, text-to-image, and CNN-based models.
Concepts
If you are new to Cloud TPUs, check out the TPU docs home.
Chips
A single v5e Pod contains 256 chips, with 8 chips per host. See System architecture for more details.
Cores
TPU chips have one or two TensorCores to run matrix multiplication. Similar to v2 and v3 Pods, v5e has one TensorCore per chip. By contrast, v4 Pods have 2 TensorCores per chip. See System architecture for more details on v5e TensorCores. Additional information about TensorCores can be found in this ACM article.
Host
A host is a physical computer (CPU) that runs VMs. A host can run multiple VMs at once.
Batch Inference
Batch or offline inference refers to running inference outside of production pipelines, typically on a large set of inputs at once. Batch inference is used for offline tasks such as data labeling and for evaluating the trained model. Latency SLOs are not a priority for batch inference.
Serving
Serving refers to the process of deploying a trained machine learning model to a production environment, where it can be used to make predictions or decisions. Latency SLOs are a priority for serving.
Single Host versus Multi Host
Slices using fewer than 8 chips use at most 1 host. Slices with more than 8 chips have access to more than one host and can run distributed training across multiple hosts. The preview does not support multi-host serving; the maximum number of chips allowed in a single serving job is 8.
Queued Resource
A representation of TPU resources, used to enqueue and manage a request for a single-slice or multi-slice TPU environment. See the Queued Resources user guide for more information.
Cloud TPU VM
A virtual machine running Linux that has access to the underlying TPUs. For v5e TPUs, each TPU VM has direct access to 1, 4, or 8 chips depending on the user-specified accelerator type. A TPU VM is also known as a worker.
System Architecture
Each v5e chip contains one TensorCore. Each TensorCore has 4 Matrix Multiply Units (MXU), a vector unit, and a scalar unit. The following table shows the key specifications and their values for a v5e. Pod specifications are included in the table that follows the chip specifications.
| Key chip specifications | v5e values |
| --- | --- |
| Peak compute per chip (bf16) | 197 TFLOPs |
| Peak compute per chip (Int8) | 393 TFLOPs |
| HBM2 capacity and bandwidth | 16 GB, 819 GBps |
| Interchip Interconnect BW | 1600 Gbps |
| All-reduce bandwidth per Pod | 51.2 TB/s |
| Bisection bandwidth per Pod | 1.6 TB/s |
| Data center network bandwidth per Pod | 6.4 Tbps |
| Key Pod specifications | v5e values |
| --- | --- |
| TPU Pod size | 256 chips |
| Interconnect topology | 2D Torus |
| Peak compute per Pod (Int8) | 100 PetaOps |
| All-reduce bandwidth per Pod | 51.2 TB/s |
| Bisection bandwidth per Pod | 1.6 TB/s |
| Data center network bandwidth per Pod | 6.4 Tbps |
TPU v5e chip
The following diagram illustrates a TPU v5e chip.
Accelerator Types
Cloud TPU v5e is a combined training and inference product. The AcceleratorType flag is used to differentiate between the TPU environment provisioned for training and the environment provisioned for serving. Training jobs are optimized for throughput and availability, while serving jobs are optimized for latency. As a result, a training job run on TPUs provisioned for serving could have lower availability, and similarly, a serving job executed on TPUs provisioned for training could have higher latency.

AcceleratorType uses the number of TensorCores in the node to describe the size of the slice. AcceleratorType is a formatted string 'v$VERSION_NUMBER-$CORES_COUNT'. For example, v5e-8 describes a v5e slice with 8 TensorCores (8 chips, because v5e has one TensorCore per chip).
The following 2D slice shapes are supported for v5e:
| Topology | Number of TPU chips | Number of hosts |
| --- | --- | --- |
| 1x1 | 1 | 1/8 |
| 2x2 | 4 | 1/2 |
| 2x4 | 8 | 1 |
| 4x4 | 16 | 2 |
| 4x8 | 32 | 4 |
| 8x8 | 64 | 8 |
| 8x16 | 128 | 16 |
| 16x16 | 256 | 32 |
Cloud TPU v5e types for serving
There are v5e TPU types that offer better availability for serving. These types support up to 8 v5e chips (a single host). The following configurations are supported: 1x1, 2x2, and 2x4 slices, with 1, 4, and 8 chips respectively.
Serving on more than 8 v5e chips, that is, multi-host serving, is on the Cloud TPU roadmap and will be available later.
To provision TPUs for a serving job, use the following accelerator types in your CLI or API TPU creation request. All v5e serving capacity is currently in zone us-west4-a:
- v5e-1
- v5e-4
- v5e-8
VM Types
Each VM in a TPU slice can contain 1, 4, or 8 chips. Slices with 4 or fewer chips share the same NUMA node (for more on NUMA nodes, see System Architecture). For 8-chip TPU VMs, CPU-TPU communication is more efficient within a NUMA partition. For example, in the figure below, CPU0-Chip0 communication is faster than CPU0-Chip4 communication.
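As an illustrative sketch (not part of the original guide; numactl may need to be installed first and the script name is a placeholder), you can inspect the NUMA layout of an 8-chip TPU VM and pin a process to a single NUMA node with standard Linux tools:

# Install numactl to inspect and control NUMA placement (assumption: apt-based VM image).
sudo apt-get update && sudo apt-get install -y numactl

# Show the NUMA nodes, their CPUs, and their memory.
numactl --hardware

# Pin a process (for example, a serving script) to NUMA node 0 so its CPU threads
# and memory allocations stay close to the TPU chips attached to that node.
numactl --cpunodebind=0 --membind=0 python3 my_serving_process.py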
VM type comparison:
For Public Preview, inference customers have access to the 1-chip (v5e-1), 4-chip (v5e-4), and 8-chip (v5e-8) VM types.
| VM type | n2d-48-24-v5lite-tpu | n2d-192-112-v5lite-tpu | n2d-384-224-v5lite-tpu |
| --- | --- | --- | --- |
| # of v5e chips | 1 | 4 | 8 |
| # of vCPUs | 24 | 112 | 224 |
| RAM (GB) | 48 | 192 | 384 |
| # of NUMA nodes | 1 | 1 | 2 |
| Applies to | v5e-1 | v5e-4 | v5e-8 |
| Disruption | High | Medium | Low |
To make room for VMs with more chips, schedulers may preempt VMs with fewer chips, so 8-chip VMs are likely to preempt 1-chip and 4-chip VMs.
Get started
Securing capacity
Cloud TPU v5e is now in Public Preview. Please contact Cloud Sales to start using Cloud TPU v5e for your AI workloads.
Prepare a Google Cloud Project
Sign in to your Google Account. If you haven't already, sign up for a new account.
In the Cloud Console, select or create a Cloud project from the project selector page.
Billing setup is required for all Google Cloud usage so make sure billing is enabled for your project.
Billing for Public Preview usage follows standard regional pricing shown on the Cloud TPU pricing page.
Install gcloud alpha components.
Enable the TPU API using the following gcloud command in Cloud Shell. (You may also enable it from the Google Cloud Console.)

gcloud services enable tpu.googleapis.com
Enable the TPU service account.
Service accounts allow the Cloud TPU service to access other Google Cloud services. A user-managed service account is a recommended Google Cloud practice. Follow these guides to create and grant the following roles to your service account (an illustrative gcloud example follows the list). The following roles are necessary:
- TPU Admin
- Storage Admin
- Logs Writer
- Monitoring Metric Writer
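As a sketch (the service account name below is an example; the role IDs are the standard IAM names for the roles listed above), you can create a user-managed service account and grant it these roles with gcloud:

# Example service account name; substitute your own.
export SA_NAME=my-tpu-service-account
gcloud iam service-accounts create ${SA_NAME} --project=${PROJECT_ID}

# Grant the roles required for Cloud TPU serving.
for ROLE in roles/tpu.admin roles/storage.admin roles/logging.logWriter roles/monitoring.metricWriter; do
  gcloud projects add-iam-policy-binding ${PROJECT_ID} \
    --member="serviceAccount:${SA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" \
    --role="${ROLE}"
done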
Configure the project and zone.
Your project ID is the name of your project shown on the Cloud console. The default zone for Cloud TPU v5e is us-west4-a.

export PROJECT_ID=project-ID
export ZONE=us-west4-a
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
gcloud auth login
gcloud config set project ${PROJECT_ID}
gcloud config set compute/zone ${ZONE}
Request access to Cloud TPU Inference Public Preview Artifacts.
Send the service account you plan to use for the TPU VM to the TPU team. The default is the Compute Engine default service account (PROJECT_NUMBER-compute@developer.gserviceaccount.com), which can be found on your IAM page.
Provision the Cloud TPU v5e environment.
A v5e is managed as a Queued Resource. Capacity can be provisioned using the queued-resource create command.

Create environment variables for the project ID, accelerator type, zone, runtime version, service account, TPU name, and queued resource ID:

export PROJECT_ID=project_ID
export ACCELERATOR_TYPE=v5e-1
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=service_account
export TPU_NAME=tpu-name
export QUEUED_RESOURCE_ID=queued_resource_id
Variable descriptions

| Variable | Description |
| --- | --- |
| PROJECT_ID | Your Google Cloud project name. |
| ACCELERATOR_TYPE | See the Accelerator Types section for supported accelerator types. |
| ZONE | All Public Preview capacity is in us-west4-a. |
| RUNTIME_VERSION | v2-alpha-tpuv5-lite |
| SERVICE_ACCOUNT | The email address of your service account, found in Google Cloud Console -> IAM -> Service Accounts. For example: tpu-service-account@myprojectID.iam.gserviceaccount.com |
| TPU_NAME | The user-assigned ID of the TPU, which is created when the queued resource request is allocated. |
| QUEUED_RESOURCE_ID | The user-assigned ID of the queued resource request. |
Create a TPU resource.
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --service-account ${SERVICE_ACCOUNT} \
  --reserved
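As an optional, illustrative follow-up (not part of the original steps), you can check the request until the queued resource becomes ACTIVE before continuing:

# Show the state of the queued resource request (for example, WAITING_FOR_RESOURCES, PROVISIONING, ACTIVE).
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}

# List all queued resource requests in the zone.
gcloud alpha compute tpus queued-resources list --project ${PROJECT_ID} --zone ${ZONE}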
If you would like to delete the resources you have reserved, delete the TPU (TPU_NAME) first and then delete the queued resource request.
gcloud alpha compute tpus delete ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
SSH into TPU VMs
To run code on your TPU VMs, you need to SSH into each TPU VM. In this example, with a v5e-1, there is only one TPU VM.

gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}
Manage your TPU VMs
For all TPU management options for your TPU VMs, see Managing TPUs.
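For example (an illustrative sketch using standard gcloud commands, not an exhaustive list of management options), you can list and inspect the TPU VMs in your zone:

# List TPU VMs in the configured zone.
gcloud compute tpus tpu-vm list --zone ${ZONE} --project ${PROJECT_ID}

# Show details for a specific TPU VM, including its state and accelerator type.
gcloud compute tpus tpu-vm describe ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}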
Develop and Run
This section describes the general setup process for custom model inference using JAX or PyTorch on Cloud TPU v5e. TensorFlow support will be enabled soon.
For v5e training instructions, refer to the v5e training guide.
Running Inference on v5e
Inference Software Stack
Details of the inference software stack are covered in the following sections. This document focuses on single-host serving for models trained with JAX, TensorFlow (TF), and PyTorch.
This section assumes that you have already set up your Google Cloud project according to the instructions in Prepare a Google Cloud project.
JAX Model Inference and Serving
The following section walks through the workflow for JAX model Inference.
There are two paths for JAX inference as shown in the diagram.
This section covers the production path for JAX models through jax2tf and TensorFlow Serving:

- Use jax2tf to convert the model to TensorFlow 2 and save the model
- Use the Inference Converter to convert the saved model
- Use TensorFlow Serving to serve the model

Use jax2tf to convert the model and save it
Please refer to JAX and TensorFlow interoperation to convert and save your JAX model to TensorFlow.
import tensorflow as tf
from jax.experimental import jax2tf

# Inference function
def model_jax(params, inputs):
  return params[0] + params[1] * inputs

# `params` are the trained parameters of your JAX model.
# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)

# Build the prediction function by closing over the `params_vars`. If you
# instead were to close over `params` your SavedModel would have no variables
# and the parameters will be included in the function graph.
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)

# Save the SavedModel to an export directory of your choice.
tf.saved_model.save(my_model, '/tmp/jax_exported_model')
Use the Inference Converter to convert the saved model
The steps for Inference Converter are described in the Inference converter guide.
Use TensorFlow Serving
The steps for TensorFlow Serving are described in TensorFlow serving.
E2E JAX Model Serving Example
Prerequisite:
You need to set up your Docker credentials and pull the Inference Converter and TensorFlow Serving Docker image. If you have not already done so, run the following commands:
sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Download the Demo code: SSH to your TPU VM and download the inference demo code.

gsutil -m cp -r "gs://cloud-tpu-inference-public/demo" .
Install the JAX demo dependencies
On your TPU VM, install the dependencies in requirements.txt:
pip install -r ./demo/jax/requirements.txt
Run JAX BERT E2E Serving demo
The pretrained BERT model is from Hugging Face.
Export a TPU-compatible TF2 saved model from a Flax BERT model:
cd demo/jax/bert
python3 export_bert_model.py
Launch the TensorFlow model server container for the model:
docker run -t --rm --privileged -d \
  -p 8500:8500 -p 8501:8501 \
  --mount type=bind,source=/tmp/jax/bert_tpu,target=/models/bert \
  -e MODEL_NAME=bert \
  us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Check the model server container log and make sure the gRPC and HTTP servers are up:

CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker logs ${CONTAINER_ID}
If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 30 seconds.
2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
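As an optional, illustrative check (not part of the original demo), the TensorFlow Serving REST model-status endpoint can also confirm that the model has loaded:

# Returns a model_version_status with state AVAILABLE once the bert model is ready.
curl http://localhost:8501/v1/models/bert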
Send the request to the model server.
python3 bert_request.py
The output will be similar to the following:
For input "The capital of France is [MASK].", the result is ". the capital of france is paris.." For input "Hello my name [MASK] Jhon, how can I [MASK] you?", the result is ". hello my name is jhon, how can i help you?."
Clean up.
Make sure to clean up the Docker container before running other demos.
CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker stop ${CONTAINER_ID}
Clean up the model artifacts:
sudo rm -rf /tmp/jax/
Run JAX Stable Diffusion E2E Serving demo
The pretrained Stable Diffusion model is from Hugging Face.
Export a TPU-compatible TF2 saved model from the Flax Stable Diffusion model:
cd demo/jax/stable_diffusion
python3 export_stable_diffusion_model.py
Launch the TensorFlow model server container for the model:
docker run -t --rm --privileged -d \
  -p 8500:8500 -p 8501:8501 \
  --mount type=bind,source=/tmp/jax/stable_diffusion_tpu,target=/models/stable_diffusion \
  -e MODEL_NAME=stable_diffusion \
  us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Check the model server container log and make sure the gRPC and HTTP servers are up:

CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker logs ${CONTAINER_ID}
If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 2 minutes.
2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
Send the request to the model server.
python3 stable_diffusion_request.py
The prompt is "Painting of a squirrel skating in New York" and the output image will be saved as
stable_diffusion_images.jpg
in your current directory.Clean up.
Make sure to clean up the Docker container before running other demos.
CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker stop ${CONTAINER_ID}
Clean up the model artifacts
sudo rm -rf /tmp/jax/
TensorFlow Model Inference and Serving
The following sections walk through the workflow for TensorFlow Model Inference.
- Use the Inference Converter to convert the model
- Use TensorFlow Serving to serve the model
Inference Converter
Cloud TPU Inference Converter prepares and optimizes a model exported from TensorFlow or JAX for TPU inference. The converter runs in a local shell or in the TPU VM shell. The TPU VM shell is recommended because it comes preinstalled with the command line tools needed for the converter. For more details on the Inference Converter please refer to the Inference Converter User Guide.
Prerequisites
The model must be exported from TensorFlow or JAX in the SavedModel format.
The model must have a function alias for the TPU function. See the code examples in the Inference Converter User Guide for instructions on how to do this; a short illustrative sketch also follows these prerequisites. The following examples use tpu_func as the TPU function alias.

Make sure your machine's CPU supports Advanced Vector eXtensions (AVX) instructions, as the TensorFlow library (a dependency of the Cloud TPU Inference Converter) is compiled to use AVX instructions. Most CPUs have AVX support.

- You can run lscpu | grep avx to check whether the AVX instruction set is supported.
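As an illustrative sketch of attaching a function alias when exporting a model (the model here is a toy example; only the function_aliases option and the tpu_func name mirror the converter example used later), you can pass tf.saved_model.SaveOptions at save time:

import tensorflow as tf

# Toy model whose `f` tf.function holds the computation you want to run on TPU.
model = tf.Module()
model.f = tf.function(
    lambda x: x * 2.0,
    input_signature=[tf.TensorSpec([None], tf.float32)])

# Register the alias so the Inference Converter can locate the TPU function.
save_options = tf.saved_model.SaveOptions(function_aliases={'tpu_func': model.f})
tf.saved_model.save(model, '/tmp/exported_model', options=save_options)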
Getting Started
Set up the TPU VM environment using the following steps, depending on the shell you are using:
TPU VM Shell
- In your TPU VM shell, run the following commands to allow non-root docker usage:
sudo usermod -a -G docker ${USER}
newgrp docker
- Initialize your Docker Credential helpers:
gcloud auth configure-docker us-docker.pkg.dev
Local Shell
In your local shell, set up the environment using the following steps:
Install the Cloud SDK, which includes the gcloud command-line tool.

Install Docker.
Allow non-root Docker usage:
sudo usermod -a -G docker ${USER}
newgrp docker
Log in to your environment:
gcloud auth login
Initialize your Docker Credential helpers:
gcloud auth configure-docker us-docker.pkg.dev
Pull the Inference Converter Docker image:
CONVERTER_IMAGE=us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
docker pull ${CONVERTER_IMAGE}
Converter Image
The image performs one-time model conversions. Set the model paths and adjust the converter options to fit your needs. The Usage Examples section in the Inference Converter User Guide provides several common use cases.
docker run \
  --mount type=bind,source=${MODEL_PATH},target=/tmp/input,readonly \
  --mount type=bind,source=${CONVERTED_MODEL_PATH},target=/tmp/output \
  ${CONVERTER_IMAGE} \
  --input_model_dir=/tmp/input \
  --output_model_dir=/tmp/output \
  --converter_options_string='
    tpu_functions {
      function_alias: "tpu_func"
    }
    batch_options {
      num_batch_threads: 2
      max_batch_size: 8
      batch_timeout_micros: 5000
      allowed_batch_sizes: 2
      allowed_batch_sizes: 4
      allowed_batch_sizes: 8
      max_enqueued_batches: 10
    }
  '
The following section shows how to run this model with a TensorFlow model Server.
TensorFlow Serving
The following instructions demonstrate how you can serve your TensorFlow model on TPU VMs.
Prerequisites:

- Set up your Docker credentials, if you have not already done so.
- Download the TensorFlow Serving Docker image for your TPU VM.
Set sample environment variables
export YOUR_LOCAL_MODEL_PATH=model-path
export MODEL_NAME=model-name
# Note: this image name may change later.
export IMAGE_NAME=us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Download the Docker image
docker pull ${IMAGE_NAME}
Serve your TensorFlow model using the TensorFlow Serving Docker Image on your TPU VM.
# Port 8500 is for the gRPC model server and 8501 is for the HTTP/REST model server.
docker run -t --rm --privileged -d \
  -p 8500:8500 -p 8501:8501 \
  --mount type=bind,source=${YOUR_LOCAL_MODEL_PATH},target=/models/${MODEL_NAME} \
  -e MODEL_NAME=${MODEL_NAME} \
  ${IMAGE_NAME}
Follow the Serving Client API to query your model.
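As an illustrative sketch (the input tensor shape and values below are placeholders that depend on your model's serving signature), a REST prediction request against the HTTP port looks like this:

# Adjust "instances" to match your model's input signature.
curl -d '{"instances": [[1.0, 2.0, 3.0]]}' \
  -X POST http://localhost:8501/v1/models/${MODEL_NAME}:predict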
E2E TensorFlow Model Serving Example:
Prerequisite: Make sure you already set up the Docker credentials and pulled the Inference Converter and TensorFlow Serving Docker image. If not, run the following commands:
sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Download the Demo code:
gsutil -m cp -r "gs://cloud-tpu-inference-public/demo" .
Install the TensorFlow demo dependencies:
pip install -r ./demo/tf/requirements.txt
Run TensorFlow ResNet-50 E2E Serving demo
Export a TPU-compatible TF2 saved model from the Keras ResNet-50 model.
cd demo/tf/resnet-50
python3 export_resnet_model.py
Launch the TensorFlow model server container for the model.
docker run -t --rm --privileged -d \
  -p 8500:8500 -p 8501:8501 \
  --mount type=bind,source=/tmp/tf/resnet_tpu,target=/models/resnet \
  -e MODEL_NAME=resnet \
  us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
Check the model server container log and make sure the gRPC and HTTP servers are up:

CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker logs ${CONTAINER_ID}
If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 30 seconds.
2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
Send the request to the model server.
The request image is a banana from https://i.imgur.com/j9xCCzn.jpeg.
python3 resnet_request.py
The output will be similar to the following:
Predict result: [[('n07753592', 'banana', 0.94921875), ('n03532672', 'hook', 0.022338867), ('n07749582', 'lemon', 0.005126953)]]
Clean up.
Make sure to clean up the Docker container before running other demos.
CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
docker stop ${CONTAINER_ID}
Clean up the model artifacts:
sudo rm -rf /tmp/tf/
PyTorch Model Inference and Serving
The following sections walk through the workflow for PyTorch Model Inference:
- Write a Python model handler for loading and inferencing using TorchDynamo and PyTorch/XLA
- Use TorchModelArchiver to create a model archive
- Use TorchServe to serve the model
TorchDynamo and PyTorch/XLA
TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook into. Its biggest feature is to dynamically modify Python bytecode just before execution. In the PyTorch/XLA 2.0 release, an experimental backend for Dynamo is provided for both inference and training.
Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern, and PyTorch/XLA uses a Lazy Tensor approach to compile the FX graph and return the compiled function. For more insight into the technical details of PyTorch/XLA's Dynamo implementation, see the PyTorch dev-discuss post and the TorchDynamo documentation. See this blog for more details.
Here is a small code example of running densenet161 inference with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_densenet161 = torchvision.models.densenet161().to(device)
  xla_densenet161.eval()
  dynamo_densenet161 = torch.compile(
      xla_densenet161, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_densenet161(data)
TorchServe
The Cloud TPU TorchServe Docker Image allows you to serve the PyTorch eager mode model using TorchServe on a Cloud TPU VM.
You can use the provided torchserve-tpu Docker image, which is ready for serving your archived PyTorch model on a Cloud TPU VM.
Set up authentication for Docker:
sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev
Pull the Cloud TPU TorchServe Docker image to your TPU VM:
CLOUD_TPU_TORCHSERVE_IMAGE_URL=us-docker.pkg.dev/cloud-tpu-images/inference/torchserve-tpu:v0.8.2-20230829
docker pull ${CLOUD_TPU_TORCHSERVE_IMAGE_URL}
Collect Model Artifacts
To get started, you need to provide a model handler, which instructs the TorchServe model server worker to load your model, process the input data, and run inference. You can use the TorchServe default inference handlers (source), or develop your own custom model handler based on base_handler.py. You may also need to provide the trained model and the model definition file.
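As an illustrative, minimal sketch of a custom handler (the class and file name are hypothetical; the Densenet 161 example below uses the built-in image_classifier handler and does not need this), a handler subclasses TorchServe's BaseHandler and overrides the steps it wants to customize:

# my_handler.py -- hypothetical custom TorchServe handler.
import torch
from ts.torch_handler.base_handler import BaseHandler

class MyHandler(BaseHandler):
    def preprocess(self, data):
        # Turn the raw request payloads into a batched input tensor.
        rows = [row.get("data") or row.get("body") for row in data]
        return torch.tensor(rows, dtype=torch.float32)

    def inference(self, inputs):
        # self.model and self.device are populated by BaseHandler.initialize().
        with torch.no_grad():
            return self.model(inputs.to(self.device))

    def postprocess(self, outputs):
        # TorchServe expects one entry per request in the batch.
        return outputs.tolist()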
In the following Densenet 161 example, we use model artifacts and the default image classifier handler provided by TorchServe:
The work directory is shown below.
CWD="$(pwd)"
WORKDIR="${CWD}/densenet_161"

mkdir -p ${WORKDIR}/model-store
mkdir -p ${WORKDIR}/logs
Download and copy model artifacts from the TorchServe image classifier example:
git clone https://github.com/pytorch/serve.git

cp ${CWD}/serve/examples/image_classifier/densenet_161/model.py ${WORKDIR}
cp ${CWD}/serve/examples/image_classifier/index_to_name.json ${WORKDIR}
Download the model weights:
wget https://download.pytorch.org/models/densenet161-8d451a50.pth -O densenet161-8d451a50.pth
mv densenet161-8d451a50.pth ${WORKDIR}
Create a TorchServe model config file to use the Dynamo backend:
echo 'pt2: "torchxla_trace_once"' >> ${WORKDIR}/model_config.yaml
You should see the files and directories shown below:
>> ls ${WORKDIR}
model_config.yaml  index_to_name.json  logs  model.py  densenet161-8d451a50.pth  model-store
Generate a model archive file
To serve your PyTorch model with Cloud TPU TorchServe, you need to package your model handler and all your model artifacts into a model archive file (*.mar) using Torch Model Archiver.
Generate a model archive file with torch-model-archiver:
MODEL_NAME=Densenet161

docker run \
  --privileged \
  --shm-size 16G \
  --name torch-model-archiver \
  -it \
  -d \
  --rm \
  --mount type=bind,source=${WORKDIR},target=/home/model-server/ \
  ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
  torch-model-archiver \
    --model-name ${MODEL_NAME} \
    --version 1.0 \
    --model-file model.py \
    --serialized-file densenet161-8d451a50.pth \
    --handler image_classifier \
    --export-path model-store \
    --extra-files index_to_name.json \
    --config-file model_config.yaml
You should see the model archive file generated in the model-store directory:
>> ls ${WORKDIR}/model-store
Densenet161.mar
Serve inference requests
Now that you have the model archive file, you can start the TorchServe model server and serve inference requests.
Start the TorchServe model server:
docker run \
  --privileged \
  --shm-size 16G \
  --name torchserve-tpu \
  -it \
  -d \
  --rm \
  -p 7070:7070 \
  -p 7071:7071 \
  -p 8080:8080 \
  -p 8081:8081 \
  -p 8082:8082 \
  -p 9001:9001 \
  -p 9012:9012 \
  --mount type=bind,source=${WORKDIR}/model-store,target=/home/model-server/model-store \
  --mount type=bind,source=${WORKDIR}/logs,target=/home/model-server/logs \
  ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
  torchserve \
    --start \
    --ncs \
    --models ${MODEL_NAME}.mar \
    --ts-config /home/model-server/config.properties
Query model server health:
curl http://localhost:8080/ping
If the model server is up and running, you will see:
{ "status": "Healthy" }
To query the default versions of the current registered model use:
curl http://localhost:8081/models
You should see the registered model:
{ "models": [ { "modelName": "Densenet161", "modelUrl": "Densenet161.mar" } ] }
To download an image for inference use:
curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
mv kitten_small.jpg ${WORKDIR}
To send an inference request to the model server use:
curl http://localhost:8080/predictions/${MODEL_NAME} -T ${WORKDIR}/kitten_small.jpg
You should see a response similar to the following:
{ "tabby": 0.47878125309944153, "lynx": 0.20393909513950348, "tiger_cat": 0.16572578251361847, "tiger": 0.061157409101724625, "Egyptian_cat": 0.04997897148132324 }
Model server logs
Use the following commands to access the logs:
ls ${WORKDIR}/logs/
cat ${WORKDIR}/logs/model_log.log
You should see the following message in your log:
"Compiled model with backend torchxla_trace_once"
Clean Up
Remove the cloned repository and working directory, and stop the Docker containers:

rm -rf serve
rm -rf ${WORKDIR}

docker stop torch-model-archiver
docker stop torchserve-tpu
Large Language Model Serving
SAX is a serving framework that supports serving large models which may require TPUs on multiple hosts running with GSPMD, such as PAX-based large language models. PAX is a framework on top of JAX for training large-scale models; it allows for advanced and fully configurable experimentation and parallelization.

The SAX cluster section describes the key elements needed to understand how SAX works. The SAX model serving section walks through a single-host model serving example with the GPT-J 6B model. SAX also provides multi-host serving on Cloud TPUs, and users can run models on larger TPU topologies for an experimental multi-host serving preview. The example below with a 175B test model shows how to experiment with this setup.
SAX cluster (SAX cell)
SAX admin server and SAX model server are two essential components that run a SAX cluster.
SAX admin server
The SAX admin server monitors and coordinates all SAX model servers in a SAX cluster. In a SAX cluster, you can launch multiple SAX admin servers, but only one admin server is active at a time (chosen through leader election); the others are standby servers. When the active admin server fails, a standby admin server becomes active. The active SAX admin server assigns model replicas and inference requests to available SAX model servers.
SAX admin storage bucket
Each SAX cluster requires a Cloud Storage bucket to store the configurations and locations of SAX admin servers and SAX model servers in the SAX cluster.
SAX model server
The SAX model server loads a model checkpoint and runs inference with GSPMD. A SAX model server runs on a single TPU VM worker. Single-host TPU model serving requires a single SAX model server on a single-host TPU VM. Multi-host TPU model serving requires a group of SAX model servers on a multi-host TPU slice.
SAX model serving
The following section walks through the workflow for serving language models using SAX. It uses the GPT-J 6B model as an example for single-host model serving, and a 175B test model for multi-host model serving.
Before starting, install the Cloud TPU SAX Docker images on your TPU VM:
sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"
SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${SAX_VERSION}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}
Set some other variables you will use later:
export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"
GPT-J 6B single-host model serving example
Single-host model serving is applicable to single-host TPU slices, that is, v5litepod-1, v5litepod-4, and v5litepod-8.
Create a SAX cluster
Create a Cloud Storage storage bucket for the SAX cluster:
SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}

gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
  --project=${PROJECT_ID}
You might need another Cloud Storage storage bucket to store the checkpoint.
SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
Connect to your TPU VM using SSH in a terminal to launch the SAX admin server:
docker run \
  --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
  -it \
  -d \
  --rm \
  --network host \
  --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
  ${SAX_ADMIN_SERVER_IMAGE_URL}
You can check the Docker log by running:
docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
The output in the log will look similar to the following:
I0829 01:22:31.184198 7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
I0829 01:22:31.347883 7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
I0829 01:22:31.360837 24 admin_server.go:44] Starting the server
I0829 01:22:31.361420 24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
I0829 01:22:31.361455 24 ipaddr.go:39] Skipping non-global IP address ::1/128.
I0829 01:22:31.361462 24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
I0829 01:22:31.361469 24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
I0829 01:22:31.361474 24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
I0829 01:22:31.361482 24 ipaddr.go:56] IPNet address 10.142.15.200
I0829 01:22:31.361488 24 ipaddr.go:56] IPNet address 172.17.0.1
I0829 01:22:31.456952 24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
I0829 01:22:31.609323 24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
I0829 01:22:31.656021 24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
I0829 01:22:31.773245 24 mgr.go:781] Loaded manager state
I0829 01:22:31.773260 24 mgr.go:784] Refreshing manager state every 10s
I0829 01:22:31.773285 24 admin.go:350] Starting the server on port 10000
I0829 01:22:31.773292 24 cloud.go:506] Starting the HTTP server on port 8080
Launch a single-host SAX model server into the SAX cluster:
At this point, the SAX cluster contains only the SAX admin server. You can connect to your TPU VM over SSH in a second terminal to launch a SAX model server in your SAX cluster:
docker run \
  --privileged \
  -it \
  -d \
  --rm \
  --network host \
  --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
  --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
  ${SAX_MODEL_SERVER_IMAGE_URL} \
  --sax_cell=${SAX_CELL} \
  --port=10001 \
  --platform_chip=tpuv4 \
  --platform_topology=1x1
You can check the Docker log by running:
docker logs -f ${SAX_MODEL_SERVER_DOCKER_NAME}
Convert model checkpoint:
You need to install PyTorch and Transformers to download the GPT-J checkpoint from EleutherAI:
pip3 install accelerate
pip3 install torch
pip3 install transformers
To convert the checkpoint to a SAX checkpoint, you need to install paxml:

pip3 install paxml==1.1.0
Then, set the following variable:
PT_CHECKPOINT_PATH=./fine_tuned_pt_checkpoint
Download the fine-tuned PyTorch checkpoint to ${PT_CHECKPOINT_PATH} by following https://github.com/mlcommons/inference/blob/master/language/gpt-j/README.md#download-gpt-j-model, then run the following command:

ls ${PT_CHECKPOINT_PATH}
This should list the following:
added_tokens.json
generation_config.json
pytorch_model.bin.index.json
pytorch_model-00001-of-00003.bin
pytorch_model-00002-of-00003.bin
pytorch_model-00003-of-00003.bin
special_tokens_map.json
trainer_state.json
config.json
merges.txt
tokenizer_config.json
vocab.json
The following commands download the conversion script and convert the GPT-J checkpoint to a SAX checkpoint:

# Download the conversion script.
wget https://raw.githubusercontent.com/google/saxml/main/saxml/tools/convert_gptj_ckpt.py

# Example usage with the original Hugging Face checkpoint:
# python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b

# Convert the fine-tuned checkpoint to a SAX checkpoint in the current directory.
python3 -m convert_gptj_ckpt --base ${PT_CHECKPOINT_PATH} --pax .

This should generate content similar to the following:

transformer.wte.weight (50401, 4096)
transformer.h.0.ln_1.weight (4096,)
transformer.h.0.ln_1.bias (4096,)
transformer.h.0.attn.k_proj.weight (4096, 4096)
.
.
.
transformer.ln_f.weight (4096,)
transformer.ln_f.bias (4096,)
lm_head.weight (50401, 4096)
lm_head.bias (50401,)
Saving the pax model to .
done
After the conversion is done, enter the following command:
ls checkpoint_00000000/
This should list the following:
metadata  state
You need to create a commit_success.txt file and place copies of it in the checkpoint directory and its metadata and state subdirectories:
CHECKPOINT_PATH=gs://${SAX_DATA_STORAGE_BUCKET}/path/to/checkpoint_00000000
gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH}

touch commit_success.txt
gsutil cp commit_success.txt ${CHECKPOINT_PATH}/
gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
Publish the model to SAX cluster
You can now publish GPT-J with the checkpoint converted in the previous step.
MODEL_NAME=gptjtokenizedbf16bs32
MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
REPLICA=1
To publish the GPT-J (and steps afterward), use SSH to connect to your TPU VM in a third terminal:
docker run \
  ${SAX_UTIL_IMAGE_URL} \
  --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
  publish \
  ${SAX_CELL}/${MODEL_NAME} \
  ${MODEL_CONFIG_PATH} \
  ${CHECKPOINT_PATH} \
  ${REPLICA}
You will see a lot of activity from the model server Docker log until you see something like the following to indicate the model has loaded successfully:
I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
Generate inference results
For GPT-J, input and output must be formatted as a comma-separated token ID string, so you will need to tokenize the text input.
TEXT = ("Below is an instruction that describes a task, paired with "
        "an input that provides further context. Write a response that "
        "appropriately completes the request.\n\n### Instruction\:\nSummarize the "
        "following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly "
        "international in scope on Tuesday. We're visiting Italy, Russia, the "
        "United Arab Emirates, and the Himalayan Mountains. Find out who's "
        "attempting to circumnavigate the globe in a plane powered partially by the "
        "sun, and explore the mysterious appearance of craters in northern Asia. "
        "You'll also get a view of Mount Everest that was previously reserved for "
        "climbers. On this page you will find today's show Transcript and a place "
        "for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . "
        "Click here to access the transcript of today's CNN Student News program. "
        "Please note that there may be a delay between the time when the video is "
        "available and when the transcript is published. CNN Student News is "
        "created by a team of journalists who consider the Common Core State "
        "Standards, national standards in different subject areas, and state "
        "standards when producing the show. ROLL CALL . For a chance to be "
        "mentioned on the next CNN Student News, comment on the bottom of this page "
        "with your school name, mascot, city and state. We will be selecting "
        "schools from the comments of the previous show. You must be a teacher or a "
        "student age 13 or older to request a mention on the CNN Student News Roll "
        "Call! Thank you for using CNN Student News!\n\n### Response\:")
You can obtain the token IDs string through the EleutherAI/gpt-j-6b tokenizer:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")
Tokenize the input text:
encoded_example = tokenizer(TEXT)
input_ids = encoded_example.input_ids
INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
You can expect a token ID string similar to the following:
>>> INPUT_STR '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
To generate a summary for your article:
docker run \
  ${SAX_UTIL_IMAGE_URL} \
  --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
  lm.generate \
  ${SAX_CELL}/${MODEL_NAME} \
  ${INPUT_STR}
You can expect something similar to:
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | GENERATE | SCORE | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256 | -0.91842502 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.1726116 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.2472695 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
To detokenize the output token IDs string:
output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
You can expect the detokenized text as:
>>> OUTPUT_TEXT 'This page includes the show Transcript.\nUse the Transcript to help students with reading comprehension and vocabulary.\nAt the bottom of the page, comment for a chance to be mentioned on CNN Student News. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
Clean up your Docker containers and Cloud Storage storage buckets.
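As a sketch (mirroring the cleanup commands used in the multi-host example later in this guide), the cleanup looks like this:

docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}

gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}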
175B multi-host model serving preview
Some large language models require a multi-host TPU slice, that is, v5litepod-16 or larger. In those cases, every host in the multi-host TPU slice needs to run a copy of the SAX model server, and all of the model servers function as a SAX model server group to serve the large model on the multi-host TPU slice.
Create a new SAX cluster
You can follow the same steps as in Create a SAX cluster in the GPT-J walkthrough to create a new SAX cluster and a SAX admin server.
Or, if you already have an existing SAX cluster, you can launch a multi-host model server into your SAX cluster.
Launch a multi-host SAX model server into a SAX cluster
Use the same command to create a multi-host TPU slice as you use for a single-host TPU slice, just specify the appropriate multi-host accelerator type:
ACCELERATOR_TYPE=v5litepod-32
ZONE=us-east1-c

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --service-account ${SERVICE_ACCOUNT} \
  --reserved
To pull the SAX model server image to all TPU hosts/workers and launch them:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command="
    gcloud auth configure-docker us-docker.pkg.dev
    # Pull the SAX model server image.
    docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
    # Run the model server.
    docker run \
      --privileged \
      -it \
      -d \
      --rm \
      --network host \
      --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
      --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
      ${SAX_MODEL_SERVER_IMAGE_URL} \
      --sax_cell=${SAX_CELL} \
      --port=10001 \
      --platform_chip=tpuv4 \
      --platform_topology=1x1"
Publish the model to SAX cluster
This example uses a LmCloudSpmd175B32Test model:
MODEL_NAME=lmcloudspmd175b32test
MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
CHECKPOINT_PATH=None
REPLICA=1
To publish the test model:
docker run \
  ${SAX_UTIL_IMAGE_URL} \
  --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
  publish \
  ${SAX_CELL}/${MODEL_NAME} \
  ${MODEL_CONFIG_PATH} \
  ${CHECKPOINT_PATH} \
  ${REPLICA}
Generate inference results
docker run \
  ${SAX_UTIL_IMAGE_URL} \
  --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
  lm.generate \
  ${SAX_CELL}/${MODEL_NAME} \
  "Q: Who is Harry Porter's mother? A\: "
Note that since this example uses a test model with random weights, the output may not be meaningful.
Clean Up
Stop the docker containers:
docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
Delete your Cloud Storage admin storage bucket and any data storage bucket using gsutil as shown below.

gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}
Profiling
After setting up inference, you can use profilers to analyze performance and TPU utilization. Refer to the Cloud TPU profiling documentation for related guides.
Support and Feedback
We welcome all feedback! To share feedback or request support, reach out to us here or by emailing cloudtpu-preview-support@google.com.
Terms
All information Google has provided to you regarding this Public Preview is Google's confidential information and subject to the confidentiality provisions in the Google Cloud Platform Terms of Service (or other agreement governing your use of Google Cloud Platform).