Cloud TPU v5e Inference introduction

Overview and Benefits

Cloud TPU v5e is Google Cloud's latest generation AI accelerator. With a smaller 256-chip footprint per Pod, v5e Pods are optimized for transformer-based, text-to-image and CNN-based training, fine-tuning, and serving.

Concepts

If you are new to Cloud TPUs, check out the TPU documentation home.

Chips

There are 256 chips in a single v5e Pod, with 8 chips per host. See System architecture for more details.

Cores

TPU chips have one or two TensorCores to run matrix multiplication. Similar to v2 and v3 Pods, v5e has one TensorCore per chip. By contrast, v4 Pods have 2 TensorCores per chip. See System architecture for more details on v5e TensorCores. Additional information about TensorCores can be found in this ACM article.

Host

A host is a physical computer (CPU) that runs VMs. A host can run multiple VMs at once.

Batch Inference

Batch (or offline) inference refers to running inference outside of production pipelines, typically on a large set of inputs. Batch inference is used for offline tasks such as data labeling and for evaluating a trained model. Latency SLOs are not a priority for batch inference.

Serving

Serving refers to the process of deploying a trained machine learning model to a production environment, where it can be used to make predictions or decisions. Latency SLOs are a priority for serving.

Single Host versus Multi Host

Slices with 8 or fewer chips use at most one host. Slices with more than 8 chips have access to more than one host and can run distributed training using multiple hosts.
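
As a rough illustration of this rule, the following Python sketch classifies an accelerator type as single-host or multi-host. It assumes that the numeric suffix of the accelerator type (for example, v5litepod-8) is the chip count and that a host holds at most 8 chips, as stated in the Chips section. The function name is hypothetical, and the accelerator types in the loop are only sample inputs.

```python
CHIPS_PER_HOST = 8  # From the Chips section: 8 chips per v5e host.


def is_single_host(accelerator_type: str) -> bool:
    """Return True if the slice fits on one host.

    Assumes the text after the last '-' is the chip count,
    for example 'v5litepod-4' -> 4 chips.
    """
    chips = int(accelerator_type.rsplit("-", 1)[1])
    if chips < 1:
        raise ValueError(f"invalid chip count in {accelerator_type!r}")
    return chips <= CHIPS_PER_HOST


for name in ("v5litepod-1", "v5litepod-8", "v5litepod-16"):
    kind = "single-host" if is_single_host(name) else "multi-host"
    print(f"{name}: {kind}")
```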

Queued Resource

A representation of TPU resources, used to enqueue and manage a request for a single-slice or multi-slice TPU environment. See the Queued Resources user guide for more information.

TPU VM

A virtual machine running Linux that has access to the underlying TPUs. For v5e TPUs, each TPU VM has direct access to 1, 4, or 8 chips depending on the user-specified accelerator type. A TPU VM is also known as a worker.

Get started

Securing capacity

Contact Cloud Sales to start using Cloud TPU v5e for your AI workloads.

Prepare a Google Cloud Project

  1. Sign in to your Google Account. If you haven't already, sign up for a new account.

  2. In the Cloud Console, select or create a Cloud project from the project selector page.

  3. Billing setup is required for all Google Cloud usage so make sure billing is enabled for your project.

  4. Install gcloud alpha components.

  5. Enable the TPU API using the following gcloud command in Cloud Shell. (You may also enable it from the Google Cloud Console.)

    gcloud services enable tpu.googleapis.com
    
  6. Enable the TPU service account.

    Service accounts allow the Cloud TPU service to access other Google Cloud services. Using a user-managed service account is a recommended Google Cloud practice. Follow these guides to create a service account and grant it the following roles:

    • TPU Admin
    • Storage Admin
    • Logs Writer
    • Monitoring Metric Writer
  7. Configure the project and zone.

    Your project ID is the name of your project shown on the Cloud console. The default zone for Cloud TPU v5e is us-west4-a.

    export PROJECT_ID=project-ID
    export ZONE=us-west4-a
    
    gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    
    gcloud auth login
    gcloud config set project ${PROJECT_ID}
    gcloud config set compute/zone ${ZONE}
    
  8. Provision the Cloud TPU v5e environment.

    A v5e is managed as a Queued Resource. Capacity can be provisioned using the queued-resources create command.

    Create environment variables for project ID, accelerator type, zone, runtime version, and TPU name.

    export PROJECT_ID=project_ID
    export ACCELERATOR_TYPE=v5litepod-1
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=service_account
    export TPU_NAME=tpu-name
    export QUEUED_RESOURCE_ID=queued_resource_id
    

    Variable descriptions

    PROJECT_ID
    Project Name. Use your Google project name.
    ACCELERATOR_TYPE
    See the Accelerator Types section for supported accelerator types.
    ZONE
    All inference capacity is in us-west4-a.
    RUNTIME_VERSION
    v2-alpha-tpuv5-lite
    SERVICE_ACCOUNT
    This is the email address of your service account, which you can find in Google Cloud Console -> IAM -> Service Accounts. For example: tpu-service-account@myprojectID.iam.gserviceaccount.com.
    TPU_NAME
    The user-assigned ID of the TPU which is created when the queued resource request is allocated.
    QUEUED_RESOURCE_ID
    The user-assigned ID of the queued resource request.
  9. Create a TPU resource.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    If you would like to delete the resource you have reserved, you need to delete the resource TPU_NAME first and then also delete the queued resource request.

    gcloud alpha compute tpus delete $TPU_NAME --zone ${ZONE} --project ${PROJECT_ID}
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}
    
  10. Connect to your TPU VM using SSH

    To run code on your TPU VMs, you need to SSH into each TPU VM. In this example, with a v5litepod-1, there is only one TPU VM.

    gcloud compute config-ssh
    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE --project $PROJECT_ID
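
The queued resource request from step 9 can also be assembled in a script. The following Python sketch builds the same gcloud argument list from the variables defined in step 8 and prints it without running it. The helper name and the quota_flag parameter are hypothetical; only flags that appear in the step 9 command are used, and the values are the placeholders from step 8.

```python
import shlex
from typing import List, Mapping, Optional


def build_create_command(env: Mapping[str, str],
                         quota_flag: Optional[str] = None) -> List[str]:
    """Build the argument list for the queued-resources create command."""
    cmd = [
        "gcloud", "alpha", "compute", "tpus", "queued-resources", "create",
        env["QUEUED_RESOURCE_ID"],
        "--node-id", env["TPU_NAME"],
        "--project", env["PROJECT_ID"],
        "--zone", env["ZONE"],
        "--accelerator-type", env["ACCELERATOR_TYPE"],
        "--runtime-version", env["RUNTIME_VERSION"],
        "--service-account", env["SERVICE_ACCOUNT"],
    ]
    if quota_flag:
        # Corresponds to --${QUOTA_TYPE} in step 9; passed through unchanged.
        cmd.append(f"--{quota_flag}")
    return cmd


# Placeholder values copied from step 8; replace them with your own.
example_env = {
    "PROJECT_ID": "project_ID",
    "ACCELERATOR_TYPE": "v5litepod-1",
    "ZONE": "us-west4-a",
    "RUNTIME_VERSION": "v2-alpha-tpuv5-lite",
    "SERVICE_ACCOUNT": "service_account",
    "TPU_NAME": "tpu-name",
    "QUEUED_RESOURCE_ID": "queued_resource_id",
}

print(shlex.join(build_create_command(example_env)))
```

A missing variable raises a KeyError before anything is sent to gcloud. To use the variables exported in your shell, pass os.environ instead of example_env.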
    

Manage your TPU VMs

For all TPU management options for your TPU VMs, see Managing TPUs.

Develop and Run

This section describes the general setup process for custom model inference using JAX or PyTorch on Cloud TPU v5e. TensorFlow support will be enabled soon.

For v5e training instructions, refer to the v5e training guide.

Running Inference on v5e

Inference Software Stack

Details of the inference software stack are covered in the following sections. This document focuses on single-host serving for models trained with JAX, TensorFlow (TF), and PyTorch.

This section assumes that you have already set up your Google Cloud project according to the instructions in Prepare a Google Cloud project.

JAX Model Inference and Serving

The following section walks through the workflow for JAX model Inference. There are two paths for JAX inference as shown in the diagram. This section will cover the production path for JAX models through jax2tf and Cloud TPU Serving.

  1. Use jax2tf to convert the model to TensorFlow 2 and save the model
  2. Use the Inference Converter to convert the saved model
  3. Use Cloud TPU Serving to serve the model

Use jax2tf to convert the model and save it

Refer to JAX and TensorFlow interoperation to convert and save your JAX model to TensorFlow.

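import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

# Inference function
def model_jax(params, inputs):
  return params[0] + params[1] * inputs

# Placeholder parameters for illustration; use your trained parameters instead.
params = (np.float32(0.5), np.float32(2.0))

# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)

# Build the prediction function by closing over the `params_vars`. If you
# instead were to close over `params` your SavedModel would have no variables
# and the parameters will be included in the function graph.
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
# The input signature is a placeholder; match it to your model's inputs.
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
# Placeholder export directory; tf.saved_model.save requires one.
export_dir = "/tmp/my_model"
tf.saved_model.save(my_model, export_dir)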

Use the Inference Converter to convert the saved model

The steps for Inference Converter are described in the Inference converter guide.

Use Cloud TPU Serving

The steps for Cloud TPU Serving are described in Cloud TPU serving.

E2E JAX Model Serving Example

Prerequisite

You need to set up your Docker credentials and pull the Inference Converter and Cloud TPU Serving Docker image. If you have not already done so, run the following commands:

sudo usermod -a -G docker ${USER}
newgrp docker
gcloud auth configure-docker \
    us-docker.pkg.dev
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0

Download the Demo code: SSH to your TPU VM and install the inference Demo code.

gsutil -m cp -r \
  "gs://cloud-tpu-inference-public/demo" \
  .

Install the JAX demo dependencies: On your TPU VM, install the packages listed in requirements.txt.

pip install -r ./demo/jax/requirements.txt

Run JAX BERT E2E Serving demo

The pretrained BERT model is from Hugging Face.

  1. Export a TPU-compatible TF2 saved model from a Flax BERT model:

    cd demo/jax/bert
    
    python3 export_bert_model.py
    
  2. Launch the Cloud TPU model server container for the model:

    docker run -t --rm --privileged -d \
     -p 8500:8500 -p 8501:8501 \
     --mount type=bind,source=/tmp/jax/bert_tpu,target=/models/bert \
     -e MODEL_NAME=bert \
     us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Check the model server container log and make sure the gRPC and HTTP servers are up:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 30 seconds.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send the request to the model server.

    python3 bert_request.py
    

    The output will be similar to the following:

    For input "The capital of France is [MASK].", the result is ". the capital of france is paris.."
    For input "Hello my name [MASK] Jhon, how can I [MASK] you?", the result is ". hello my name is jhon, how can i help you?."
    
  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts:

    sudo rm -rf /tmp/jax/
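
The readiness check in step 2 can be automated. The following Python sketch looks for the two startup messages shown in the sample log above. The function name is hypothetical, and the marker strings are taken from that sample log, so they are an assumption about what other server versions print.

```python
# Substrings taken from the sample server log in step 2.
READY_MARKERS = (
    "Running gRPC ModelServer",
    "Exporting HTTP/REST API",
)


def server_ready(log_text: str) -> bool:
    """Return True once both startup messages appear in the container log."""
    return all(marker in log_text for marker in READY_MARKERS)


sample_log = """\
2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
"""

print(server_ready(sample_log))
```

In practice, log_text would be the captured output of docker logs ${CONTAINER_ID}, checked in a loop with a short sleep until the function returns True.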
    

Run JAX Stable Diffusion E2E Serving demo

The pretrained Stable Diffusion model is from Hugging Face.

  1. Export a TPU-compatible TF2 saved model from the Flax Stable Diffusion model:

    cd demo/jax/stable_diffusion
    
    python3 export_stable_diffusion_model.py
    
  2. Launch the Cloud TPU model server container for the model:

    docker run -t --rm --privileged -d \
     -p 8500:8500 -p 8501:8501 \
     --mount type=bind,source=/tmp/jax/stable_diffusion_tpu,target=/models/stable_diffusion \
     -e MODEL_NAME=stable_diffusion \
     us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Check the model server container log and make sure the gRPC and HTTP servers are up:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 2 minutes.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send the request to the model server.

    python3 stable_diffusion_request.py
    

    The prompt is "Painting of a squirrel skating in New York" and the output image will be saved as stable_diffusion_images.jpg in your current directory.

  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts

    sudo rm -rf /tmp/jax/
    

TensorFlow Model Inference and Serving

The following sections walk through the workflow for TensorFlow model inference.

  1. Use the Inference Converter to convert the model
  2. Use Cloud TPU Serving to serve the model

Inference Converter

Cloud TPU Inference Converter prepares and optimizes a model exported from TensorFlow or JAX for TPU inference. The converter runs in a local shell or in the TPU VM shell. The TPU VM shell is recommended because it comes preinstalled with the command line tools needed for the converter. For more details on the Inference Converter refer to the Inference Converter User Guide.

Prerequisites

  1. The model must be exported from TensorFlow or JAX in the SavedModel format.

  2. The model must have a function alias for the TPU function. See the code examples in Inference Converter User Guide for instructions on how to do this. The following examples use tpu_func as the TPU function alias.

  3. Make sure your machine's CPU supports Advanced Vector eXtensions (AVX) instructions, because the TensorFlow library (a dependency of the Cloud TPU Inference Converter) is compiled to use AVX instructions. Most CPUs support AVX.

    • You can run lscpu | grep avx to check whether the AVX instruction set is supported.
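
The same check can be scripted. The following Python sketch scans text in the format of /proc/cpuinfo for the avx flag. It assumes a Linux x86 machine whose CPU features are listed on lines starting with "flags"; the function name and the sample text are illustrative.

```python
def has_avx(cpuinfo_text: str) -> bool:
    """Return True if any 'flags' line lists the avx feature."""
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(":")
        # Match the exact token 'avx', not only 'avx2' or 'avx512f'.
        if key.strip() == "flags" and "avx" in value.split():
            return True
    return False


# Illustrative excerpt in /proc/cpuinfo format.
sample = "processor\t: 0\nflags\t\t: fpu vme sse sse2 avx avx2\n"
print(has_avx(sample))
```

On a Linux machine, pass the contents of /proc/cpuinfo, for example has_avx(open("/proc/cpuinfo").read()).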

Getting Started

  • Set up your environment. Use the following steps, depending on the shell you are using:

    TPU VM Shell

    • In your TPU VM shell, run the following commands to allow non-root docker usage:
    sudo usermod -a -G docker ${USER}
    newgrp docker
    
    • Initialize your Docker Credential helpers:
    gcloud auth configure-docker \
      us-docker.pkg.dev
    

    Local Shell

    In your local shell, set up the environment using the following steps:

    • Install the Cloud SDK, which includes the gcloud command-line tool.

    • Install Docker.

    • Allow non-root Docker usage:

      sudo usermod -a -G docker ${USER}
      newgrp docker
      
    • Log in to your environment:

      gcloud auth login
      
    • Initialize your Docker Credential helpers:

      gcloud auth configure-docker \
      us-docker.pkg.dev
      
  • Pull the Inference Converter Docker image:

      CONVERTER_IMAGE=us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
      docker pull ${CONVERTER_IMAGE}
      

Converter Image

The Image is for doing one-time model conversions. Set the model paths and adjust the converter options to fit your needs. The Usage Examples section in the Inference Converter User Guide provides several common use cases.

docker run \
--mount type=bind,source=${MODEL_PATH},target=/tmp/input,readonly \
--mount type=bind,source=${CONVERTED_MODEL_PATH},target=/tmp/output \
${CONVERTER_IMAGE} \
--input_model_dir=/tmp/input \
--output_model_dir=/tmp/output \
--converter_options_string='
    tpu_functions {
      function_alias: "tpu_func"
    }
    batch_options {
      num_batch_threads: 2
      max_batch_size: 8
      batch_timeout_micros: 5000
      allowed_batch_sizes: 2
      allowed_batch_sizes: 4
      allowed_batch_sizes: 8
      max_enqueued_batches: 10
    }
'
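
When several models need different batching settings, the options string can be generated instead of edited by hand. The following Python sketch renders the same text as the example above from a few parameters. The helper name is hypothetical, and using the largest allowed batch size as max_batch_size is an assumption that mirrors the example (where both are 8), not a documented converter requirement.

```python
from typing import Sequence


def build_converter_options(function_alias: str,
                            allowed_batch_sizes: Sequence[int],
                            num_batch_threads: int = 2,
                            batch_timeout_micros: int = 5000,
                            max_enqueued_batches: int = 10) -> str:
    """Render the text passed to --converter_options_string."""
    sizes = list(allowed_batch_sizes)
    if not sizes or sizes != sorted(set(sizes)):
        raise ValueError(
            "allowed_batch_sizes must be non-empty and strictly increasing")
    lines = [
        "tpu_functions {",
        f'  function_alias: "{function_alias}"',
        "}",
        "batch_options {",
        f"  num_batch_threads: {num_batch_threads}",
        # Assumption: the largest allowed size is used as max_batch_size.
        f"  max_batch_size: {sizes[-1]}",
        f"  batch_timeout_micros: {batch_timeout_micros}",
    ]
    lines += [f"  allowed_batch_sizes: {size}" for size in sizes]
    lines += [f"  max_enqueued_batches: {max_enqueued_batches}", "}"]
    return "\n".join(lines)


print(build_converter_options("tpu_func", [2, 4, 8]))
```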

The following section shows how to run this model with a TensorFlow model server.

TensorFlow Serving

The following instructions demonstrate how you can serve your TensorFlow model on TPU VMs.

Prerequisites:

Set up your Docker credentials, if you have not already done so.

  1. Download the TensorFlow Serving Docker Image for your TPU VM.

    Set sample environment variables

    export YOUR_LOCAL_MODEL_PATH=model-path
    export MODEL_NAME=model-name
    # Note: this image name may change later.
    export IMAGE_NAME=us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Download the Docker image

    docker pull ${IMAGE_NAME}
    
  2. Serve your TensorFlow model using the TensorFlow Serving Docker Image on your TPU VM.

    # PORT 8500 is for gRPC model server and 8501 is for HTTP/REST model server.
    docker run -t --rm --privileged -d \
     -p 8500:8500 -p 8501:8501 \
     --mount type=bind,source=${YOUR_LOCAL_MODEL_PATH},target=/models/${MODEL_NAME} \
     -e MODEL_NAME=${MODEL_NAME} \
     ${IMAGE_NAME}
    
  3. Follow the Serving Client API to query your model.
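
As one way to query the HTTP/REST port opened above, the following Python sketch builds a prediction request in the TensorFlow Serving REST format (a POST to /v1/models/MODEL_NAME:predict with an "instances" list). The helper name is hypothetical, and the sample instance is a placeholder: the real input shape depends on your model's serving signature. The request is only constructed here, not sent.

```python
import json
import urllib.request


def build_predict_request(model_name: str, instances: list,
                          host: str = "localhost",
                          port: int = 8501) -> urllib.request.Request:
    """Build a REST predict request for a TensorFlow Serving model server."""
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Placeholder model name and input; replace with your own.
request = build_predict_request("model-name", [[1.0, 2.0, 3.0]])
print(request.full_url)

# To send it to a running server:
#   with urllib.request.urlopen(request) as response:
#       print(json.load(response))
```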

E2E TensorFlow Model Serving Example:

Prerequisite: Make sure you already set up the Docker credentials and pulled the Inference Converter and TensorFlow Serving Docker image. If not, run the following commands:

sudo usermod -a -G docker ${USER}
newgrp docker
gcloud auth configure-docker \
    us-docker.pkg.dev
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0

Download the Demo code:

gsutil -m cp -r \
  "gs://cloud-tpu-inference-public/demo" \
  .

Install the TensorFlow demo dependencies:

pip install -r ./demo/tf/requirements.txt

Run TensorFlow ResNet-50 E2E Serving demo

  1. Export a TPU-compatible TF2 saved model from the Keras ResNet-50 model.

    cd demo/tf/resnet-50
    
    python3 export_resnet_model.py
    
  2. Launch the TensorFlow model server container for the model.

    docker run -t --rm --privileged -d \
     -p 8500:8500 -p 8501:8501 \
     --mount type=bind,source=/tmp/tf/resnet_tpu,target=/models/resnet \
     -e MODEL_NAME=resnet \
     us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Check the model server container log and make sure the gRPC and HTTP servers are up:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see the log ending with the following information, it means the server is ready to serve requests. It takes around 30 seconds.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send the request to the model server.

    The request image is a banana from https://i.imgur.com/j9xCCzn.jpeg .

    python3 resnet_request.py
    

    The output will be similar to the following:

    Predict result: [[('n07753592', 'banana', 0.94921875), ('n03532672', 'hook', 0.022338867), ('n07749582', 'lemon', 0.005126953)]]
    
  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts:

    sudo rm -rf /tmp/tf/
    

PyTorch Model Inference and Serving

The following sections walk through the workflow for PyTorch Model Inference:

  1. Write a Python model handler for loading and inferencing using TorchDynamo and PyTorch/XLA
  2. Use TorchModelArchiver to create a model archive
  3. Use TorchServe to serve the model

TorchDynamo and PyTorch/XLA

TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook into. Its biggest feature is to dynamically modify Python bytecode just before execution. In the PyTorch/XLA 2.0 release, an experimental backend for Dynamo is provided for both inference and training.

Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern, and PyTorch/XLA uses a Lazy Tensor approach to compile the FX graph and return the compiled function. For the technical details of PyTorch/XLA's Dynamo implementation, see the PyTorch dev-discuss post and the TorchDynamo documentation. See this blog for more details.

Here is a small code example of running densenet161 inference with torch.compile.

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_densenet161 = torchvision.models.densenet161().to(device)
  xla_densenet161.eval()
  dynamo_densenet161 = torch.compile(
      xla_densenet161, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_densenet161(data)

TorchServe

The Cloud TPU TorchServe Docker image lets you serve a PyTorch eager mode model using TorchServe on a TPU VM.

You can use the provided torchserve-tpu Docker image, which is ready to serve your archived PyTorch model on a TPU VM.

Set up authentication for Docker:

sudo usermod -a -G docker ${USER}
newgrp docker
gcloud auth configure-docker \
    us-docker.pkg.dev

Pull the Cloud TPU TorchServe Docker image to your TPU VM:

CLOUD_TPU_TORCHSERVE_IMAGE_URL=us-docker.pkg.dev/cloud-tpu-images/inference/torchserve-tpu:v0.9.0-2.1
docker pull ${CLOUD_TPU_TORCHSERVE_IMAGE_URL}

Collect Model Artifacts

To get started, you need to provide a model handler, which instructs the TorchServe model server worker to load your model, process the input data and run inference. You can use the TorchServe default inference handlers (source), or develop your own custom model handler following the base_handler.py. You may also need to provide the trained model, and the model definition file.

In the following Densenet 161 example, we use model artifacts and the default image classifier handler provided by TorchServe:

The work directory is shown below.

CWD="$(pwd)"

WORKDIR="${CWD}/densenet_161"

mkdir -p ${WORKDIR}/model-store
mkdir -p ${WORKDIR}/logs

  1. Download and copy model artifacts from the TorchServe image classifier example:

    git clone https://github.com/pytorch/serve.git
    
    cp ${CWD}/serve/examples/image_classifier/densenet_161/model.py ${WORKDIR}
    cp ${CWD}/serve/examples/image_classifier/index_to_name.json ${WORKDIR}
    
  2. Download the model weights:

    wget https://download.pytorch.org/models/densenet161-8d451a50.pth -O densenet161-8d451a50.pth
    
    mv densenet161-8d451a50.pth ${WORKDIR}
    
  3. Create a TorchServe model config file to use the Dynamo backend:

    echo 'pt2: "torchxla_trace_once"' >> ${WORKDIR}/model_config.yaml
    

    You should see the files and directories shown below:

    >> ls ${WORKDIR}
    model_config.yaml
    index_to_name.json
    logs
    model.py
    densenet161-8d451a50.pth
    model-store
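
Before building the archive, it can help to confirm that nothing is missing from the work directory. The following Python sketch compares a directory against the listing above. The helper name is hypothetical, and the demonstration uses a temporary directory with only two files so that it runs anywhere.

```python
import os
import tempfile

# File and directory names from the listing above.
EXPECTED = (
    "model_config.yaml",
    "index_to_name.json",
    "logs",
    "model.py",
    "densenet161-8d451a50.pth",
    "model-store",
)


def missing_artifacts(workdir: str) -> list:
    """Return the expected entries that are not present in workdir."""
    present = set(os.listdir(workdir))
    return [name for name in EXPECTED if name not in present]


# Demonstration on a temporary directory that holds only two of the files.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("model.py", "index_to_name.json"):
        open(os.path.join(demo_dir, name), "w").close()
    print(missing_artifacts(demo_dir))
```

For the real check, call missing_artifacts with the path stored in ${WORKDIR}; an empty list means every entry is in place.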
    

Generate a model archive file

To serve your PyTorch model with Cloud TPU TorchServe, you need to package your model handler and all your model artifacts into a model archive file (*.mar) using Torch Model Archiver.

Generate a model archive file with torch-model-archiver:

MODEL_NAME=Densenet161

docker run \
    --privileged  \
    --shm-size 16G \
    --name torch-model-archiver \
    -it \
    -d \
    --rm \
    --mount type=bind,source=${WORKDIR},target=/home/model-server/ \
    ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
    torch-model-archiver \
        --model-name ${MODEL_NAME} \
        --version 1.0 \
        --model-file model.py \
        --serialized-file densenet161-8d451a50.pth \
        --handler image_classifier \
        --export-path model-store \
        --extra-files index_to_name.json \
        --config-file model_config.yaml

You should see the model archive file generated in the model-store directory:

>> ls ${WORKDIR}/model-store
Densenet161.mar

Serve inference requests

Now that you have the model archive file, you can start the TorchServe model server and serve inference requests.

  1. Start the TorchServe model server:

    docker run \
       --privileged  \
       --shm-size 16G \
       --name torchserve-tpu \
       -it \
       -d \
       --rm \
       -p 7070:7070 \
       -p 7071:7071 \
       -p 8080:8080 \
       -p 8081:8081 \
       -p 8082:8082 \
       -p 9001:9001 \
       -p 9012:9012 \
       --mount type=bind,source=${WORKDIR}/model-store,target=/home/model-server/model-store \
       --mount type=bind,source=${WORKDIR}/logs,target=/home/model-server/logs \
       ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
       torchserve \
           --start \
           --ncs \
           --models ${MODEL_NAME}.mar \
           --ts-config /home/model-server/config.properties
    
  2. Query model server health:

    curl http://localhost:8080/ping
    

    If the model server is up and running, you will see:

    {
     "status": "Healthy"
    }
    

    To query the default versions of the current registered model use:

    curl http://localhost:8081/models
    

    You should see the registered model:

    {
     "models": [
       {
         "modelName": "Densenet161",
         "modelUrl": "Densenet161.mar"
       }
     ]
    }
    

    To download an image for inference use:

    curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
    
    mv kitten_small.jpg ${WORKDIR}
    

    To send an inference request to the model server use:

    curl http://localhost:8080/predictions/${MODEL_NAME} -T ${WORKDIR}/kitten_small.jpg
    

    You should see a response similar to the following:

    {
     "tabby": 0.47878125309944153,
     "lynx": 0.20393909513950348,
     "tiger_cat": 0.16572578251361847,
     "tiger": 0.061157409101724625,
     "Egyptian_cat": 0.04997897148132324
    }
    
  3. Model server logs

    Use the following commands to access the logs:

    ls ${WORKDIR}/logs/
    cat ${WORKDIR}/logs/model_log.log
    

    You should see the following message in your log:

    "Compiled model with backend torchxla_trace_once"
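
The JSON returned by the inference request in step 2 is easy to post-process. The following Python sketch picks the highest-scoring label from a response shaped like the sample above. The function name is hypothetical, and the sample text is copied from that response.

```python
import json


def top_prediction(response_text: str):
    """Return the (label, score) pair with the highest score."""
    scores = json.loads(response_text)
    label = max(scores, key=scores.get)
    return label, scores[label]


# Sample response copied from step 2.
sample_response = """{
 "tabby": 0.47878125309944153,
 "lynx": 0.20393909513950348,
 "tiger_cat": 0.16572578251361847,
 "tiger": 0.061157409101724625,
 "Egyptian_cat": 0.04997897148132324
}"""

label, score = top_prediction(sample_response)
print(f"{label}: {score:.3f}")
```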
    

Clean Up

Remove the downloaded files and stop the Docker containers:

rm -rf serve
rm -rf ${WORKDIR}

docker stop torch-model-archiver
docker stop torchserve-tpu

Large Language Model Serving

SAX is a serving framework for large models that may require TPUs on multiple hosts to run with GSPMD, such as PAX-based large language models. PAX is a framework built on top of JAX for training large-scale models; it allows for advanced and fully configurable experimentation and parallelization.

The SAX cluster section describes the key elements needed to understand how SAX works. The SAX model serving section walks through a single-host model serving example with a GPT-J 6B model. SAX also provides multi-host serving on Cloud TPUs, and users can run models on larger TPU topologies for experimental multi-host serving. The following example with a 175B test model shows how to experiment with this setup.

SAX cluster (SAX cell)

SAX admin server and SAX model server are two essential components that run a SAX cluster.

SAX admin server

The SAX admin server monitors and coordinates all SAX model servers in a SAX cluster. In a SAX cluster, you can launch multiple SAX admin servers, but only one is active at a time, chosen through leader election; the others are standby servers. When the active admin server fails, a standby admin server becomes active. The active SAX admin server assigns model replicas and inference requests to available SAX model servers.

SAX admin storage bucket

Each SAX cluster requires a Cloud Storage bucket to store the configurations and locations of SAX admin servers and SAX model servers in the SAX cluster.

SAX model server

The SAX model server loads a model checkpoint and runs inference with GSPMD. A SAX model server runs on a single TPU VM worker. Single-host TPU model serving requires a single SAX model server on a single-host TPU VM. Multi-host TPU model serving requires a group of SAX model servers on a multi-host TPU slice.

SAX model serving

The following section walks through the workflow for serving language models using SAX. It uses the GPT-J 6B model as an example for single-host model serving, and a 175B test model for multi-host model serving.

Before starting, install the Cloud TPU SAX Docker images on your TPU VM:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.1.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${SAX_VERSION}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Set some other variables you will use later:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

GPT-J 6B single-host model serving example

Single-host model serving applies to single-host TPU slices, that is, v5litepod-1, v5litepod-4, and v5litepod-8.

  1. Create a SAX cluster

    1. Create a Cloud Storage bucket for the SAX cluster:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}
      

      You might need another Cloud Storage bucket to store the checkpoint.

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
      
    2. Connect to your TPU VM using SSH in a terminal to launch the SAX admin server:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}
      

      You can check the Docker log by:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
      

      The output in the log will look similar to the following:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
      
  2. Launch a single-host SAX model server into the SAX cluster:

    At this point, the SAX cluster contains only the SAX admin server. You can connect to your TPU VM over SSH in a second terminal to launch a SAX model server in your SAX cluster:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
    

    You can check the Docker log by running:

    docker logs -f ${SAX_MODEL_SERVER_DOCKER_NAME}
    

  3. Convert model checkpoint:

    You need to install PyTorch and Transformers to download the GPT-J checkpoint from EleutherAI:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers
    

    To convert the checkpoint to a SAX checkpoint, you need to install paxml:

    pip3 install paxml==1.1.0
    

    Then, set the following variable:

    PT_CHECKPOINT_PATH=./fine_tuned_pt_checkpoint
    

    Download the fine-tuned PyTorch checkpoint to ${PT_CHECKPOINT_PATH} by following https://github.com/mlcommons/inference/blob/master/language/gpt-j/README.md#download-gpt-j-model, then list the contents of the directory:

    ls ${PT_CHECKPOINT_PATH}
    

    This should list the following:

    added_tokens.json  generation_config.json
    pytorch_model.bin.index.json
    pytorch_model-00001-of-00003.bin
    pytorch_model-00002-of-00003.bin
    pytorch_model-00003-of-00003.bin
    special_tokens_map.json
    trainer_state.json
    config.json
    merges.txt
    tokenizer_config.json
    vocab.json
    

    The following script converts the GPT-J checkpoint to a SAX checkpoint. It uses ${PT_CHECKPOINT_PATH} as the base model checkpoint and writes the converted checkpoint to ${CONVERTED_CHECKPOINT_PATH}, so set that variable to your chosen output directory first (the sample output below corresponds to the current directory):

    wget https://raw.githubusercontent.com/google/saxml/main/saxml/tools/convert_gptj_ckpt.py
    python3 -m convert_gptj_ckpt --base ${PT_CHECKPOINT_PATH} --pax ${CONVERTED_CHECKPOINT_PATH}
    

    This should print output similar to the following:

    transformer.wte.weight (50401, 4096)
    transformer.h.0.ln_1.weight (4096,)
    transformer.h.0.ln_1.bias (4096,)
    transformer.h.0.attn.k_proj.weight (4096, 4096)
    .
    .
    .
    transformer.ln_f.weight (4096,)
    transformer.ln_f.bias (4096,)
    lm_head.weight (50401, 4096)
    lm_head.bias (50401,)
    Saving the pax model to .
    done
    

    After the conversion is done, enter the following command:

    ls checkpoint_00000000/
    

    This should list the following:

    metadata
    state
    

    You need to create a commit_success.txt file and place it in the checkpoint directory and in each of its subdirectories:

    CHECKPOINT_PATH=gs://${SAX_DATA_STORAGE_BUCKET}/path/to/checkpoint_00000000
    gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH}
    
    touch commit_success.txt
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
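
    The copy commands above follow a simple pattern: the marker file goes into the checkpoint directory itself and into each of its two subdirectories. The following Python sketch only generates the destination paths and the matching gsutil argument lists so that the pattern is explicit. The function names and the bucket name are hypothetical, and nothing is executed against Cloud Storage.

```python
def commit_marker_targets(checkpoint_path, subdirs=("metadata", "state")):
    """Return every destination that should receive commit_success.txt."""
    root = checkpoint_path.rstrip("/")
    return [root + "/"] + [f"{root}/{name}/" for name in subdirs]


def gsutil_copy_commands(checkpoint_path, marker="commit_success.txt"):
    """Return one `gsutil cp` argument list per destination."""
    return [
        ["gsutil", "cp", marker, dest]
        for dest in commit_marker_targets(checkpoint_path)
    ]


# Hypothetical bucket name, used for illustration only.
for cmd in gsutil_copy_commands("gs://my-data-bucket/path/to/checkpoint_00000000"):
    print(" ".join(cmd))
```

    Each argument list could be passed to subprocess.run if you prefer to drive the upload from Python.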
    
  4. Publish the model to the SAX cluster

    You can now publish GPT-J with the checkpoint converted in the previous step.

    MODEL_NAME=gptj4bf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4BF16BS32
    REPLICA=1
    

    To publish GPT-J (and to run the steps that follow), use SSH to connect to your TPU VM in a third terminal:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
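
    If you script the publish step, it can help to assemble the command from its parts so that the positional order (model ID, model config path, checkpoint path, replica count) stays explicit. The following Python sketch mirrors the command above. The function name and the example values are hypothetical placeholders, and the sketch only builds the argument list without running it.

```python
def publish_args(util_image, admin_bucket, cell, model_name,
                 config_path, checkpoint_path, replicas=1):
    """Build the argument list for the publish command shown above."""
    return [
        "docker", "run", util_image,
        f"--sax_root=gs://{admin_bucket}/sax-root",
        "publish",
        f"{cell}/{model_name}",   # model ID: cell followed by model name
        config_path,              # registered model config
        checkpoint_path,          # converted checkpoint location
        str(replicas),            # number of replicas
    ]


# Placeholder values for illustration only.
args = publish_args(
    util_image="example-sax-util-image",
    admin_bucket="example-admin-bucket",
    cell="/sax/test",
    model_name="gptj4bf16bs32",
    config_path="saxml.server.pax.lm.params.gptj.GPTJ4BF16BS32",
    checkpoint_path="gs://example-data-bucket/path/to/checkpoint_00000000",
)
print(" ".join(args))
```

    The resulting list can be handed to subprocess.run(args, check=True) on the TPU VM.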
    

    The model server Docker log shows a lot of activity while the model loads. A line similar to the following indicates that the model has loaded successfully:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    

    To list published models:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       list ${SAX_CELL}
    

    You will see:

    +---+---------------+
    | # |   MODEL ID    |
    +---+---------------+
    | 0 | gptj4bf16bs32 |
    +---+---------------+
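
    When you want to check from a script whether a model has been published, you can parse this table. The following is a minimal Python sketch that assumes the two-column layout shown above. The function name parse_model_ids is hypothetical and not part of the SAX tooling.

```python
def parse_model_ids(table_text):
    """Extract model IDs from the two-column ASCII table shown above."""
    model_ids = []
    for line in table_text.splitlines():
        line = line.strip()
        if not line.startswith("|"):
            continue  # skip border lines such as +---+------+
        cells = [cell.strip() for cell in line.strip("|").split("|")]
        # Data rows have a numeric index in the first column; the header has "#".
        if len(cells) == 2 and cells[0].isdigit():
            model_ids.append(cells[1])
    return model_ids


sample_table = """+---+---------------+
| # |   MODEL ID    |
+---+---------------+
| 0 | gptj4bf16bs32 |
+---+---------------+"""
print(parse_model_ids(sample_table))
```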
    

    If you want to use tokenized input and generate tokenized output, instead of using the above model config, you can publish using:

    MODEL_NAME=gptj4tokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1
    
  5. Generate inference results

    To generate a summary for your article using GPTJ4BF16BS32:

    TEXT=("Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:")
    
    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       lm.generate \
         ${SAX_CELL}/${MODEL_NAME} \
         "${TEXT}"
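
    The prompt above follows a fixed instruction, input, and response layout. If you want to summarize other articles, a small helper can fill in that layout for you. The following Python sketch is an illustration only: PROMPT_TEMPLATE and build_prompt are hypothetical names, and the sketch produces the plain text with unescaped colons, whereas the shell command above writes colons as \:.

```python
# Layout copied from the example prompt above; only the article text and
# the instruction line are parameters.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task, paired with "
    "an input that provides further context. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{article}\n\n"
    "### Response:"
)


def build_prompt(article, instruction="Summarize the following news article:"):
    """Fill the instruction/input/response layout with the given article."""
    return PROMPT_TEMPLATE.format(instruction=instruction, article=article)


prompt = build_prompt("March 10, 2015 . We're truly international in scope on Tuesday.")
print(prompt)
```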
    

    You can expect something similar to:

    +--------------------------------+------------+
    |            GENERATE            |   SCORE    |
    +--------------------------------+------------+
    |  This page includes the        | -1.0517541 |
    | show Transcript.  The daily    |            |
    | transcript is a written        |            |
    | version of each day's CNN      |            |
    | Student News program. Use the  |            |
    | Transcript to help students    |            |
    | with reading comprehension and |            |
    | vocabulary. At the bottom of   |            |
    | the page, comment for a chance |            |
    | to be mentioned on CNN Student |            |
    | News.  You must be a teacher   |            |
    | or a student age 13 or older   |            |
    | to request a mention on the    |            |
    | CNN Student News Roll Call.    |            |
    |  ...                           |            |
    +--------------------------------+------------+
    

    If you are using GPTJ4TokenizedBF16BS32, the input must be formatted as a comma-separated string of token IDs, so you need to tokenize the text input first. In Python, define the input text:

    TEXT=("Below is an instruction that describes a task, paired with "
    "an input that provides further context. Write a response that "
    "appropriately completes the request.\n\n### Instruction:\nSummarize the "
    "following news article:\n\n### Input:\nMarch 10, 2015 . We're truly "
    "international in scope on Tuesday. We're visiting Italy, Russia, the "
    "United Arab Emirates, and the Himalayan Mountains. Find out who's "
    "attempting to circumnavigate the globe in a plane powered partially by the "
    "sun, and explore the mysterious appearance of craters in northern Asia. "
    "You'll also get a view of Mount Everest that was previously reserved for "
    "climbers. On this page you will find today's show Transcript and a place "
    "for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . "
    "Click here to access the transcript of today's CNN Student News program. "
    "Please note that there may be a delay between the time when the video is "
    "available and when the transcript is published. CNN Student News is "
    "created by a team of journalists who consider the Common Core State "
    "Standards, national standards in different subject areas, and state "
    "standards when producing the show. ROLL CALL . For a chance to be "
    "mentioned on the next CNN Student News, comment on the bottom of this page "
    "with your school name, mascot, city and state. We will be selecting "
    "schools from the comments of the previous show. You must be a teacher or a "
    "student age 13 or older to request a mention on the CNN Student News Roll "
    "Call! Thank you for using CNN Student News!\n\n### Response:")
    

    You can obtain the token ID string through the EleutherAI/gpt-j-6b tokenizer:

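    from transformers import GPT2Tokenizer
    # Same directory as ${PT_CHECKPOINT_PATH} set in the checkpoint conversion step.
    PT_CHECKPOINT_PATH = "./fine_tuned_pt_checkpoint"
    tokenizer = GPT2Tokenizer.from_pretrained(PT_CHECKPOINT_PATH)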
    

    Tokenize the input text:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
    

    You can expect a token ID string similar to the following:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
    

    To generate a summary for your article, set INPUT_STR in your shell to the token ID string obtained above, then run:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       lm.generate \
         ${SAX_CELL}/${MODEL_NAME} \
         ${INPUT_STR}
    

    You can expect something similar to:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
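
    Each row of this table is one candidate completion with its score. To pick a candidate programmatically, you can parse the rows and select the one with the highest score. The following Python sketch assumes that a higher (less negative) score indicates a better candidate, which matches the ordering of the rows above but is not stated by the tooling. The function names are hypothetical, and the sample table is shortened for readability.

```python
def parse_candidates(table_text):
    """Return (token_id_string, score) pairs from the generate output table."""
    candidates = []
    for line in table_text.splitlines():
        line = line.strip()
        if not line.startswith("|"):
            continue  # skip border lines
        cells = [cell.strip() for cell in line.strip("|").split("|")]
        if len(cells) != 2:
            continue
        try:
            score = float(cells[1])
        except ValueError:
            continue  # header row or a row without a score
        candidates.append((cells[0], score))
    return candidates


def best_candidate(table_text):
    """Return the token ID string with the highest score (assumed best)."""
    return max(parse_candidates(table_text), key=lambda item: item[1])[0]


# Shortened sample in the same layout as the output above.
sample_output = """+-------------------+--------------+
|     GENERATE      |    SCORE     |
+-------------------+--------------+
| 1212,2443,50256   | -0.023136413 |
| 1212,2443,0,50256 |  -0.91842502 |
+-------------------+--------------+"""
print(best_candidate(sample_output))
```

    The returned string has the same comma-separated format as the input, so it can be detokenized with the tokenizer loaded earlier.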
    

    To detokenize an output token ID string, assign one entry from the GENERATE column to OUTPUT_STR in Python and decode it:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    The detokenized text should look similar to the following:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on
    the CNN Student News Roll Call.'
    
  6. Clean up your Docker containers and Cloud Storage buckets.

175B multi-host model serving

Some large language models require a multi-host TPU slice, that is, v5litepod-16 and above. In those cases, every TPU host in the slice needs to run a copy of the SAX model server, and all of the model servers function together as a SAX model server group to serve the large model on the multi-host TPU slice.

  1. Create a new SAX cluster

    Follow the same steps as in "Create a SAX cluster" in the GPT-J walkthrough to create a new SAX cluster and a SAX admin server.

    Or, if you already have an existing SAX cluster, you can launch a multi-host model server into your SAX cluster.

  2. Launch a multi-host SAX model server into a SAX cluster

    Use the same command to create a multi-host TPU slice as you use for a single-host TPU slice, but specify the appropriate multi-host accelerator type:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
     --node-id ${TPU_NAME} \
     --project ${PROJECT_ID} \
     --zone ${ZONE} \
     --accelerator-type ${ACCELERATOR_TYPE} \
     --runtime-version ${RUNTIME_VERSION} \
     --service-account ${SERVICE_ACCOUNT} \
     --${QUOTA_TYPE}
    
    Note: The QUOTA_TYPE flag can be either reserved or best-effort. See Quota Types for information on the different types of quotas supported by Cloud TPU.

    To pull the SAX model server image to all TPU hosts/workers and launch them:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
     --project ${PROJECT_ID} \
     --zone ${ZONE} \
     --worker=all \
     --command="
       gcloud auth configure-docker \
         us-docker.pkg.dev
       # Pull SAX model server image
       docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
       # Run model server
       docker run \
         --privileged  \
         -it \
         -d \
         --rm \
         --network host \
         --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
         --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
         ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1"
    
  3. Publish the model to the SAX cluster

    This example uses a LmCloudSpmd175B32Test model:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1
    

    To publish the test model:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    
  4. Generate inference results

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       lm.generate \
         ${SAX_CELL}/${MODEL_NAME} \
         "Q:  Who is Harry Potter's mother? A\: "
    

    Note that since this example uses a test model with random weights, the output may not be meaningful.

  5. Clean Up

    Stop the Docker containers:

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
    

    Delete your Cloud Storage admin bucket and any data storage bucket using gsutil:

    gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
    gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}
    

Profiling

After setting up inference, you can use profilers to analyze performance and TPU utilization. For details, refer to the Cloud TPU profiling documentation.

Support and Feedback

We welcome all feedback! To share feedback or request support, reach out to us here or email cloudtpu-support@google.com.

Terms

All information Google has provided to you regarding this software release is Google's confidential information and subject to the confidentiality provisions in the Google Cloud Platform Terms of Service (or other agreement governing your use of Google Cloud Platform).