Cloud TPU v5e Inference introduction

Overview and benefits

Cloud TPU v5e is a Google-developed AI accelerator optimized for transformer-based, text-to-image and CNN-based training, fine-tuning, and serving (inference). TPU v5e slices can contain up to 256 chips.

Serving refers to the process of deploying a trained machine learning model to a production environment, where it can be used for inference. Meeting latency service level objectives (SLOs) is a priority for serving.

This document discusses serving a model on a single-host TPU. TPU slices with 8 or fewer chips have one TPU VM (host) and are called single-host TPUs.

Get started

You will need quota for v5e TPUs. On-demand TPUs require tpu-v5s-litepod-serving quota. Reserved TPUs require tpu-v5s-litepod-serving-reserved quota. For more information, contact Cloud Sales.

You will need a Google Cloud account and project to use Cloud TPU. For more information, see Set up a Cloud TPU environment.

You provision v5e TPUs using Queued resources. For more information on available v5e configurations for serving, see Cloud TPU v5e types for serving.

Cloud TPU model inference and serving

How you serve a model for inference depends on the ML framework your model was written with. TPU v5e supports serving models written in JAX, TensorFlow, and PyTorch.

TensorFlow and JAX model inference and serving

To serve a TensorFlow or JAX model on a TPU VM, you need to:

  1. Serialize your model in the TensorFlow SavedModel format
  2. Use the Inference Converter to prepare the saved model for serving
  3. Use TensorFlow Serving to serve the model

SavedModel format

A SavedModel contains a complete TensorFlow program, including trained parameters and computation. It does not require the original model building code to run.

If your model was written in JAX, you will need to use jax2tf to serialize your model in the SavedModel format.

Inference Converter

Cloud TPU Inference Converter prepares and optimizes a model exported in SavedModel format for TPU inference. You can run the Inference Converter in a local shell or on your TPU VM. We recommend using your TPU VM shell because it has all the command-line tools needed to run the converter. For more information about the Inference Converter, see the Inference Converter User Guide.

Inference Converter requirements

  1. Your model must be exported from TensorFlow or JAX in the SavedModel format.

  2. You must define a function alias for the TPU function. For more information, see the Inference Converter User Guide. The examples in this guide use tpu_func as the TPU function alias.

  3. Make sure your machine's CPU supports Advanced Vector Extensions (AVX) instructions, because the TensorFlow library (a dependency of the Cloud TPU Inference Converter) is compiled to use them. Most modern x86 CPUs support AVX.
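On Linux, one way to check for AVX support is to look for the avx flag in /proc/cpuinfo. The following Python sketch assumes a Linux host that exposes that file. The helper name cpu_supports_avx is hypothetical.

```python
from pathlib import Path


def cpu_supports_avx(cpuinfo_text: str) -> bool:
    """Return True if any 'flags' line in /proc/cpuinfo-style text lists 'avx'."""
    for line in cpuinfo_text.splitlines():
        # x86 Linux kernels list CPU features on lines like "flags : fpu vme ... avx ..."
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            # Exact token match, so "avx2" or "avx512f" alone does not count as "avx".
            if "avx" in value.split():
                return True
    return False


if __name__ == "__main__":
    cpuinfo = Path("/proc/cpuinfo")  # Assumes a Linux host.
    if cpuinfo.exists():
        print("AVX supported:", cpu_supports_avx(cpuinfo.read_text()))
    else:
        print("/proc/cpuinfo not found; check AVX support another way.")
```

On ARM hosts, /proc/cpuinfo reports a "Features" line instead of "flags", so the function returns False there, which matches the fact that AVX is an x86 extension.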

JAX model inference and serving

This section describes how to serve JAX models using jax2tf and TensorFlow Serving.

  1. Use jax2tf to serialize your model into the SavedModel format
  2. Use the Inference Converter to prepare your saved model for serving
  3. Use TensorFlow Serving to serve the model

Use jax2tf to serialize a JAX model to the SavedModel format

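The following Python code shows how to use jax2tf to export a simple JAX model. The parameter values, input shape, and export directory are placeholders. Replace them with your own:

import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

# Placeholder parameters, input signature, and export directory.
params = (np.float32(1.0), np.float32(2.0))
# A fixed input shape keeps the example simple; see the jax2tf
# polymorphic_shapes option if you need a dynamic batch size.
input_spec = tf.TensorSpec(shape=[8], dtype=tf.float32)
model_dir = "/tmp/jax/linear_model"

# Inference function
def model_jax(params, inputs):
  return params[0] + params[1] * inputs

# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)

# Build the prediction function by closing over the `params_vars`. If you
# instead were to close over `params`, your SavedModel would have no variables
# and the parameters would be included in the function graph.
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)

# Save with a serving signature, and register the function under the
# tpu_func alias that the Inference Converter examples in this guide use.
tf.saved_model.save(
    my_model,
    model_dir,
    signatures=my_model.f.get_concrete_function(input_spec),
    options=tf.saved_model.SaveOptions(function_aliases={"tpu_func": my_model.f}))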

For more information about jax2tf, see JAX and Cloud TPU interoperation.

Use the Inference Converter to prepare the saved model for serving

Instructions for using the Inference Converter are described in the Inference Converter User Guide.

Use TensorFlow Serving

Instructions for using TensorFlow Serving are described in TensorFlow Serving.

JAX model serving examples

Prerequisites

  1. Set up your Docker credentials and pull the Inference Converter and Cloud TPU Serving Docker images:

    sudo usermod -a -G docker ${USER}
    newgrp docker
    gcloud auth configure-docker \
       us-docker.pkg.dev
    docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
    docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    
  2. Connect to your TPU VM with SSH and download the inference demo code:

    gsutil -m cp -r \
    "gs://cloud-tpu-inference-public/demo" \
    .
    
  3. Install the JAX demo dependencies:

    pip install -r ./demo/jax/requirements.txt
    

Serve the JAX BERT model for inference

You can download the pretrained BERT model from Hugging Face.

  1. Export a TPU-compatible TensorFlow saved model from a Flax BERT model:

    cd demo/jax/bert
    python3 export_bert_model.py
    
  2. Start the Cloud TPU model server container:

    docker run -t --rm --privileged -d \
      -p 8500:8500 -p 8501:8501 \
      --mount type=bind,source=/tmp/jax/bert_tpu,target=/models/bert \
      -e MODEL_NAME=bert \
      us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    About 30 seconds after the container is started, check the model server container log and make sure the gRPC and HTTP servers are up:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see a log entry ending with the following information, the server is ready to serve requests.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send an inference request to the model server.

    python3 bert_request.py
    

    The output will be similar to the following:

    For input "The capital of France is [MASK].", the result is ". the capital of france is paris.."
    For input "Hello my name [MASK] Jhon, how can I [MASK] you?", the result is ". hello my name is jhon, how can i help you?."
    
  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts:

    sudo rm -rf /tmp/jax/
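You can also check server readiness programmatically instead of reading the log by eye. The following Python sketch assumes the log contains the two server.cc lines shown in the sample output above. The marker strings are copied from that sample rather than from a documented interface, and the helper names are hypothetical.

```python
import subprocess

# Marker strings copied from the sample TensorFlow Serving log in this guide.
GRPC_MARKER = "Running gRPC ModelServer at"
HTTP_MARKER = "Exporting HTTP/REST API at"


def server_is_ready(log_text: str) -> bool:
    """Return True once both the gRPC and HTTP/REST servers have reported startup."""
    return GRPC_MARKER in log_text and HTTP_MARKER in log_text


def read_container_logs(container_id: str) -> str:
    """Return the output of `docker logs` for a container."""
    result = subprocess.run(
        ["docker", "logs", container_id],
        capture_output=True, text=True, check=True,
    )
    # docker logs replays the container's stdout and stderr separately, so combine both.
    return result.stdout + result.stderr
```

For example, server_is_ready(read_container_logs(container_id)) returns True once the log looks like the sample above. Here container_id is the value that the docker ps command prints.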
    

Serve the JAX Stable Diffusion model for inference

You can download a pretrained Stable Diffusion model from Hugging Face.

  1. Download the Stable Diffusion model in a TPU-compatible TF2 saved model format:

    cd demo/jax/stable_diffusion
    python3 export_stable_diffusion_model.py
    
  2. Start the Cloud TPU model server container for the model:

    docker run -t --rm --privileged -d \
      -p 8500:8500 -p 8501:8501 \
      --mount type=bind,source=/tmp/jax/stable_diffusion_tpu,target=/models/stable_diffusion \
      -e MODEL_NAME=stable_diffusion \
      us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    After about two minutes, check the model server container log to make sure the gRPC and HTTP servers are running:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see the log ending with the following information, it means the servers are ready to serve requests.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send a request to the model server.

    python3 stable_diffusion_request.py
    

    This script sends "Painting of a squirrel skating in New York" as the prompt. The output image will be saved as stable_diffusion_images.jpg in your current directory.

  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts:

    sudo rm -rf /tmp/jax/
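The Stable Diffusion server can take about two minutes to start. Instead of waiting a fixed time in step 2, you might prefer to poll. The following Python sketch polls the TensorFlow Serving REST model status endpoint (GET /v1/models/<name> on port 8501) until a model version reports the AVAILABLE state. The function names and the timeout and polling intervals are illustrative assumptions.

```python
import json
import time
import urllib.request


def model_is_available(status: dict) -> bool:
    """Return True if any version in a TF Serving model status response is AVAILABLE."""
    return any(v.get("state") == "AVAILABLE"
               for v in status.get("model_version_status", []))


def wait_for_model(model_name: str, host: str = "localhost", port: int = 8501,
                   timeout_s: float = 300.0, interval_s: float = 10.0) -> bool:
    """Poll the model status endpoint until the model is AVAILABLE or the timeout expires."""
    url = f"http://{host}:{port}/v1/models/{model_name}"
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if model_is_available(json.load(resp)):
                    return True
        except (OSError, ValueError):
            # Connection refused, HTTP errors (urllib errors are OSErrors),
            # timeouts, or a malformed body: the server is not ready yet.
            pass
        time.sleep(interval_s)
    return False
```

For example, wait_for_model("stable_diffusion") returns True as soon as the server reports the model as loaded. If the server never reports it, the call returns False after five minutes.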
    

TensorFlow Serving

The following instructions demonstrate how you can serve your TensorFlow model on TPU VMs.

TensorFlow Serving workflow

  1. Download the TensorFlow Serving Docker image for your TPU VM.

    Set sample environment variables:

    export YOUR_LOCAL_MODEL_PATH=model-path
    export MODEL_NAME=model-name
    # Note: this image name may change later.
    export IMAGE_NAME=us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Download the Docker image:

    docker pull ${IMAGE_NAME}
    
  2. Set up the Docker credentials and pull the Inference Converter and TensorFlow Serving Docker images.

    sudo usermod -a -G docker ${USER}
    newgrp docker
    gcloud auth configure-docker \
       us-docker.pkg.dev
    docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
    docker pull us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    
  3. Download the demo code:

    gsutil -m cp -r \
    "gs://cloud-tpu-inference-public/demo" \
    .
    
  4. Install the TensorFlow demo dependencies:

    pip install -r ./demo/tf/requirements.txt
    
  5. Serve your TensorFlow model using the TensorFlow Serving Docker image on your TPU VM.

    # PORT 8500 is for gRPC model server and 8501 is for HTTP/REST model server.
    docker run -t --rm --privileged -d \
      -p 8500:8500 -p 8501:8501 \
      --mount type=bind,source=${YOUR_LOCAL_MODEL_PATH},target=/models/${MODEL_NAME} \
      -e MODEL_NAME=${MODEL_NAME} \
      ${IMAGE_NAME}
    
  6. Use the Serving Client API to query your model.
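In addition to the gRPC API on port 8500, TensorFlow Serving exposes a REST API on port 8501. The following Python sketch uses only the standard library to send a row-format predict request: a POST to /v1/models/<name>:predict with an "instances" list. The instance values you pass are placeholders. The shape each instance needs depends on your model's serving signature, and the helper names are hypothetical.

```python
import json
import urllib.request
from typing import Optional


def build_predict_request(model_name: str, instances: list,
                          host: str = "localhost", port: int = 8501,
                          signature_name: Optional[str] = None) -> urllib.request.Request:
    """Build a TensorFlow Serving REST predict request in the row ("instances") format."""
    body = {"instances": instances}
    if signature_name is not None:
        body["signature_name"] = signature_name
    return urllib.request.Request(
        url=f"http://{host}:{port}/v1/models/{model_name}:predict",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(model_name: str, instances: list, **kwargs) -> list:
    """Send a predict request and return the "predictions" field of the JSON response."""
    request = build_predict_request(model_name, instances, **kwargs)
    with urllib.request.urlopen(request) as resp:
        return json.load(resp)["predictions"]
```

For example, predict("resnet", batch) queries the ResNet-50 demo server described below. Here batch is a list of inputs that you must shape to match the exported model's signature.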

Run TensorFlow ResNet-50 Serving demo

  1. Export a TPU-compatible TF2 saved model from the Keras ResNet-50 model.

    cd demo/tf/resnet-50
    python3 export_resnet_model.py
    
  2. Launch the TensorFlow model server container for the model.

    docker run -t --rm --privileged -d \
      -p 8500:8500 -p 8501:8501 \
      --mount type=bind,source=/tmp/tf/resnet_tpu,target=/models/resnet \
      -e MODEL_NAME=resnet \
      us-docker.pkg.dev/cloud-tpu-images/inference/tf-serving-tpu:2.13.0
    

    Check the model server container log and make sure the gRPC and HTTP servers are up:

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker logs ${CONTAINER_ID}
    

    If you see the log ending with the following information, the server is ready to serve requests. This takes about 30 seconds.

    2023-04-08 00:43:10.481682: I tensorflow_serving/model_servers/server.cc:409] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [warn] getaddrinfo: address family for nodename not supported
    2023-04-08 00:43:10.520578: I tensorflow_serving/model_servers/server.cc:430] Exporting HTTP/REST API at:localhost:8501 ...
    [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
    
  3. Send the request to the model server.

    The request image is a banana from https://i.imgur.com/j9xCCzn.jpeg.

    python3 resnet_request.py
    

    The output will be similar to the following:

    Predict result: [[('n07753592', 'banana', 0.94921875), ('n03532672', 'hook', 0.022338867), ('n07749582', 'lemon', 0.005126953)]]
    
  4. Clean up.

    Make sure to clean up the Docker container before running other demos.

    CONTAINER_ID=$(docker ps | grep "tf-serving-tpu" | awk '{print $1}')
    docker stop ${CONTAINER_ID}
    

    Clean up the model artifacts:

    sudo rm -rf /tmp/tf/
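Each entry in the Predict result above is a (class ID, label, probability) triple for one of the three most likely classes. The following Python sketch shows that top-k selection step on a plain probability list. The function is illustrative and is not the code that resnet_request.py uses.

```python
import heapq
from typing import List, Sequence, Tuple


def top_k(probabilities: Sequence[float], labels: Sequence[Tuple[str, str]],
          k: int = 3) -> List[Tuple[str, str, float]]:
    """Return the k (class_id, label, probability) triples with the highest probability.

    labels[i] is the (class_id, label) pair for probabilities[i].
    """
    # nlargest returns indices ordered from highest to lowest probability.
    best = heapq.nlargest(k, range(len(probabilities)), key=probabilities.__getitem__)
    return [(labels[i][0], labels[i][1], probabilities[i]) for i in best]
```

For example, with the made-up probabilities [0.02, 0.95, 0.01] and labels [("n03532672", "hook"), ("n07753592", "banana"), ("n07749582", "lemon")], top_k(probabilities, labels, k=2) returns [("n07753592", "banana", 0.95), ("n03532672", "hook", 0.02)].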
    

PyTorch model inference and serving

For models written with PyTorch, the workflow is:

  1. Write a Python model handler that loads the model and runs inference using TorchDynamo and PyTorch/XLA
  2. Use TorchModelArchiver to create a model archive
  3. Use TorchServe to serve the model

TorchDynamo and PyTorch/XLA

TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make PyTorch programs faster. It provides a clean API for compiler backends to hook into. It dynamically modifies Python bytecode just before execution. In the PyTorch/XLA 2.0 release, there is an experimental backend for inference and training using Dynamo.

Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern, and PyTorch/XLA uses a lazy tensor approach to compile the FX graph and return the compiled function. For more information about Dynamo, see:

Here is a small code example of running densenet161 inference with torch.compile.

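import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_densenet161 = torchvision.models.densenet161().to(device)
  xla_densenet161.eval()
  dynamo_densenet161 = torch.compile(
      xla_densenet161, backend='torchxla_trace_once')
  outputs = []
  # Disable gradient tracking for inference.
  with torch.no_grad():
    for data, _ in loader:
      # Move each input batch to the XLA device before running the model.
      output = dynamo_densenet161(data.to(device))
      outputs.append(output)
  return outputs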

TorchServe

You can use the provided torchserve-tpu Docker image to serve your archived PyTorch model on a TPU VM.

Set up authentication for Docker:

sudo usermod -a -G docker ${USER}
newgrp docker
gcloud auth configure-docker \
    us-docker.pkg.dev

Pull the Cloud TPU TorchServe Docker image to your TPU VM:

CLOUD_TPU_TORCHSERVE_IMAGE_URL=us-docker.pkg.dev/cloud-tpu-images/inference/torchserve-tpu:v0.9.0-2.1
docker pull ${CLOUD_TPU_TORCHSERVE_IMAGE_URL}

Collect model artifacts

To get started, you need to provide a model handler, which instructs the TorchServe model server worker to load your model, process the input data, and run inference. You can use the TorchServe default inference handlers, or develop your own custom model handler based on base_handler.py. You might also need to provide the trained model and the model definition file.

In the following Densenet 161 example, we use model artifacts and the default image classifier handler provided by TorchServe:

  1. Configure some environment variables:

    CWD="$(pwd)"
    
    WORKDIR="${CWD}/densenet_161"
    
    mkdir -p ${WORKDIR}/model-store
    mkdir -p ${WORKDIR}/logs
    
  2. Download and copy model artifacts from the TorchServe image classifier example:

    git clone https://github.com/pytorch/serve.git
    
    cp ${CWD}/serve/examples/image_classifier/densenet_161/model.py ${WORKDIR}
    cp ${CWD}/serve/examples/image_classifier/index_to_name.json ${WORKDIR}
    
  3. Download the model weights:

    wget https://download.pytorch.org/models/densenet161-8d451a50.pth -O densenet161-8d451a50.pth
    
    mv densenet161-8d451a50.pth ${WORKDIR}
    
  4. Create a TorchServe model config file to use the Dynamo backend:

    echo 'pt2: "torchxla_trace_once"' >> ${WORKDIR}/model_config.yaml
    

    You should see the following files and directories:

    >> ls ${WORKDIR}
    model_config.yaml
    index_to_name.json
    logs
    model.py
    densenet161-8d451a50.pth
    model-store
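Before building the archive, you can confirm that the working directory contains everything this example needs. The following Python sketch checks for the entries listed above. The expected names apply only to this Densenet 161 example, and the helper name missing_artifacts is hypothetical.

```python
import os
from pathlib import Path

# Entries this Densenet 161 example expects in WORKDIR.
EXPECTED = {
    "model_config.yaml",
    "index_to_name.json",
    "logs",
    "model.py",
    "densenet161-8d451a50.pth",
    "model-store",
}


def missing_artifacts(workdir) -> set:
    """Return the expected entries that are not present in workdir."""
    present = set(os.listdir(workdir)) if Path(workdir).is_dir() else set()
    return EXPECTED - present


if __name__ == "__main__":
    # Assumes WORKDIR was exported; otherwise falls back to ./densenet_161.
    workdir = os.environ.get("WORKDIR", "densenet_161")
    missing = missing_artifacts(workdir)
    print("All artifacts present." if not missing else f"Missing: {sorted(missing)}")
```

The check only compares names. It does not verify that the weights file is complete or that model.py matches the weights.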
    

Generate a model archive file

To serve your PyTorch model with Cloud TPU TorchServe, you need to package your model handler and all your model artifacts into a model archive file (*.mar) using Torch Model Archiver.

Generate a model archive file with torch-model-archiver:

MODEL_NAME=Densenet161

docker run \
    --privileged  \
    --shm-size 16G \
    --name torch-model-archiver \
    -it \
    -d \
    --rm \
    --mount type=bind,source=${WORKDIR},target=/home/model-server/ \
    ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
    torch-model-archiver \
        --model-name ${MODEL_NAME} \
        --version 1.0 \
        --model-file model.py \
        --serialized-file densenet161-8d451a50.pth \
        --handler image_classifier \
        --export-path model-store \
        --extra-files index_to_name.json \
        --config-file model_config.yaml

You should see the model archive file generated in the model-store directory:

>> ls ${WORKDIR}/model-store
Densenet161.mar
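To double-check what was packaged, you can list the archive contents. With the default archive format, a .mar file is a ZIP archive that includes a MAR-INF/MANIFEST.json manifest, so Python's zipfile module can read it. Treat this as a sketch: the manifest fields may differ between torch-model-archiver versions, and the helper name describe_mar is hypothetical.

```python
import json
import sys
import zipfile


def describe_mar(path: str):
    """Return (sorted member names, parsed manifest or None) for a .mar archive."""
    with zipfile.ZipFile(path) as mar:
        names = sorted(mar.namelist())
        manifest = None
        if "MAR-INF/MANIFEST.json" in names:
            manifest = json.loads(mar.read("MAR-INF/MANIFEST.json"))
    return names, manifest


if __name__ == "__main__" and len(sys.argv) > 1:
    members, manifest = describe_mar(sys.argv[1])
    print("\n".join(members))
    if manifest is not None:
        print(json.dumps(manifest, indent=2))
```

For example, save the sketch as describe_mar.py (a hypothetical file name) and run python3 describe_mar.py ${WORKDIR}/model-store/Densenet161.mar. The listing should include model.py, the weights file, index_to_name.json, and model_config.yaml.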

Serve inference requests

Now that you have the model archive file, you can start the TorchServe model server and serve inference requests.

  1. Start the TorchServe model server:

    docker run \
        --privileged  \
        --shm-size 16G \
        --name torchserve-tpu \
        -it \
        -d \
        --rm \
        -p 7070:7070 \
        -p 7071:7071 \
        -p 8080:8080 \
        -p 8081:8081 \
        -p 8082:8082 \
        -p 9001:9001 \
        -p 9012:9012 \
        --mount type=bind,source=${WORKDIR}/model-store,target=/home/model-server/model-store \
        --mount type=bind,source=${WORKDIR}/logs,target=/home/model-server/logs \
        ${CLOUD_TPU_TORCHSERVE_IMAGE_URL} \
        torchserve \
            --start \
            --ncs \
            --models ${MODEL_NAME}.mar \
            --ts-config /home/model-server/config.properties
    
  2. Query model server health:

    curl http://localhost:8080/ping
    

    If the model server is up and running, you will see:

    {
      "status": "Healthy"
    }
    

    To list the currently registered models, use:

    curl http://localhost:8081/models
    

    You should see the registered model:

    {
      "models": [
        {
          "modelName": "Densenet161",
          "modelUrl": "Densenet161.mar"
        }
      ]
    }
    

    To download a sample image for inference, use:

    curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
    
    mv kitten_small.jpg ${WORKDIR}
    

    To send an inference request to the model server, use:

    curl http://localhost:8080/predictions/${MODEL_NAME} -T ${WORKDIR}/kitten_small.jpg
    

    You should see a response similar to the following:

    {
      "tabby": 0.47878125309944153,
      "lynx": 0.20393909513950348,
      "tiger_cat": 0.16572578251361847,
      "tiger": 0.061157409101724625,
      "Egyptian_cat": 0.04997897148132324
    }
    
  3. Model server logs

    Use the following commands to access the logs:

    ls ${WORKDIR}/logs/
    cat ${WORKDIR}/logs/model_log.log
    

    You should see the following message in your log:

    "Compiled model with backend torchxla_trace_once"
    
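You can also script the same health check and prediction. The following Python sketch uses only the standard library against the TorchServe inference API on port 8080. Like the curl -T command above, it uploads the image with an HTTP PUT. The helper names are hypothetical, and the model name and image path are the ones used in this example.

```python
import json
import urllib.request


def ping(host: str = "localhost", port: int = 8080) -> bool:
    """Return True if the TorchServe /ping endpoint reports "Healthy"."""
    with urllib.request.urlopen(f"http://{host}:{port}/ping") as resp:
        return json.load(resp).get("status") == "Healthy"


def build_prediction_request(model_name: str, image_bytes: bytes,
                             host: str = "localhost",
                             port: int = 8080) -> urllib.request.Request:
    """Build a PUT request that uploads raw image bytes, like `curl -T`."""
    return urllib.request.Request(
        url=f"http://{host}:{port}/predictions/{model_name}",
        data=image_bytes,
        # Set an explicit binary content type; otherwise urllib defaults to
        # application/x-www-form-urlencoded for requests with a body.
        headers={"Content-Type": "application/octet-stream"},
        method="PUT",
    )


def top_prediction(response_body: bytes) -> tuple:
    """Return the (label, score) pair with the highest score from a label->score JSON object."""
    scores = json.loads(response_body)
    label = max(scores, key=scores.get)
    return label, scores[label]


def classify(model_name: str, image_path: str) -> tuple:
    """Upload an image file and return the top (label, score) prediction."""
    with open(image_path, "rb") as f:
        request = build_prediction_request(model_name, f.read())
    with urllib.request.urlopen(request) as resp:
        return top_prediction(resp.read())
```

For example, after ping() returns True, classify("Densenet161", "densenet_161/kitten_small.jpg") returns the highest-scoring label. With a response like the sample above, that label is "tabby".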

Clean up

Stop the Docker containers and remove the demo files:

docker stop torch-model-archiver
docker stop torchserve-tpu

rm -rf serve
rm -rf ${WORKDIR}

Profiling

After setting up inference, you can use profilers to analyze the performance and TPU utilization. For more information about profiling, see:
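Alongside profiling, a quick client-side measurement can tell you whether you meet a latency SLO. The following Python sketch times repeated calls to a request function and reports p50, p99, and maximum latency using the standard statistics module. The send_request argument is a placeholder for any function that sends one request to your model server, and the warm-up count is an assumption.

```python
import statistics
import time
from typing import Callable, Dict


def measure_latency(send_request: Callable[[], object], num_requests: int = 100,
                    warmup: int = 5) -> Dict[str, float]:
    """Call send_request() repeatedly and return p50, p99, and max latency in milliseconds."""
    if num_requests < 2:
        raise ValueError("num_requests must be at least 2 to compute percentiles")
    # Warm-up calls are not timed; early requests can be slower, for example
    # because of one-time compilation on the server.
    for _ in range(warmup):
        send_request()
    latencies_ms = []
    for _ in range(num_requests):
        start = time.perf_counter()
        send_request()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    # quantiles(n=100) returns the 99 cut points p1..p99; the inclusive method
    # keeps every estimate within the observed minimum and maximum.
    cuts = statistics.quantiles(latencies_ms, n=100, method="inclusive")
    return {"p50_ms": cuts[49], "p99_ms": cuts[98], "max_ms": max(latencies_ms)}
```

Pass a function that sends one request to your model server, then compare p99_ms with your latency SLO. These numbers include client and network overhead, so they complement, rather than replace, the on-device view a profiler gives you.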