Serve Gemma open models using TPUs on Vertex AI Prediction with Saxml

This guide shows you how to serve Gemma open model large language models (LLMs) using Tensor Processing Units (TPUs) on Vertex AI Prediction with Saxml. In this guide, you download the 2B and 7B parameter instruction-tuned Gemma models to Cloud Storage and deploy them on Vertex AI Prediction, which runs Saxml on TPUs.

Background

By serving Gemma using TPUs on Vertex AI Prediction with Saxml, you can take advantage of a managed AI solution that takes care of low-level infrastructure and offers a cost-effective way to serve LLMs. This section describes the key technologies used in this tutorial.

Gemma

Gemma is a set of lightweight, openly available generative artificial intelligence (AI) models released under an open license. You can run these AI models in your applications, on your hardware and mobile devices, or on hosted services. You can use the Gemma models for text generation, and you can also tune them for specialized tasks.

To learn more, see the Gemma documentation.

Saxml

Saxml is an experimental system that serves Paxml, JAX, and PyTorch models for inference. This tutorial covers serving Gemma on TPUs, which are more cost-efficient for Saxml; setup for GPUs is similar. Saxml provides scripts to build containers for Vertex AI Prediction, and this tutorial uses those containers.

TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate data processing frameworks such as TensorFlow, PyTorch, and JAX.

This tutorial serves the Gemma 2B and Gemma 7B models. Vertex AI Prediction hosts these models on the following single-host TPU v5e node pools:

  • Gemma 2B: Hosted in a TPU v5e node pool with 1x1 topology that represents one TPU chip. The machine type for the nodes is ct5lp-hightpu-1t.
  • Gemma 7B: Hosted in a TPU v5e node pool with 2x2 topology that represents four TPU chips. The machine type for the nodes is ct5lp-hightpu-4t.

Before you begin

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. Enable the Vertex AI API.

    Enable the API

  5. In the Google Cloud console, activate Cloud Shell.

    Activate Cloud Shell

    At the bottom of the Google Cloud console, a Cloud Shell session starts and displays a command-line prompt. Cloud Shell is a shell environment with the Google Cloud CLI already installed and with values already set for your current project. It can take a few seconds for the session to initialize.

This tutorial assumes that you are using Cloud Shell to interact with Google Cloud. If you want to use a different shell instead of Cloud Shell, then perform the following additional configuration:

  1. Install the Google Cloud CLI.
  2. To initialize the gcloud CLI, run the following command:

    gcloud init
  3. Ensure that you have sufficient quota for TPU v5e chips for Vertex AI Prediction. By default, this quota is 0. For a 1x1 topology, it must be at least 1; for a 2x2 topology, at least 4; and to run both topologies, at least 5.
  4. Create a Kaggle account, if you don't already have one.

Get access to the model

Note that Cloud Shell might not have sufficient resources to download the model weights. If so, you can create a Vertex AI Workbench instance to perform that task.

To get access to the Gemma models for deployment to Vertex AI Prediction, you must sign in to the Kaggle platform, sign the license consent agreement, and get a Kaggle API token. In this tutorial, you upload the Kaggle API token to Cloud Shell and use it to download the model checkpoints.

You must sign the consent agreement to use Gemma. Follow these instructions:

  1. Access the model consent page on Kaggle.com.
  2. Sign in to Kaggle if you haven't done so already.
  3. Click Request Access.
  4. In the Choose Account for Consent section, select Verify via Kaggle Account to use your Kaggle account for consent.
  5. Accept the model Terms and Conditions.

Generate an access token

To access the model through Kaggle, you need a Kaggle API token.

Follow these steps to generate a new token if you don't have one already:

  1. In your browser, go to Kaggle settings.
  2. Under the API section, click Create New Token.

    A file named kaggle.json is downloaded.
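
    This file contains your Kaggle username and API key. As an illustration only (the values shown are placeholders, not real credentials), its contents look similar to the following:

    {"username": "your-kaggle-username", "key": "xxxxxxxxxxxxxxxxxxxx"}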

Upload the access token to Cloud Shell

In Cloud Shell, you can upload the Kaggle API token to your Google Cloud project:

  1. In Cloud Shell, click More > Upload.
  2. Select File and click Choose Files.
  3. Open the kaggle.json file.
  4. Click Upload.
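
The Kaggle CLI reads the API token from ~/.kaggle/kaggle.json by default. Assuming the file was uploaded to your home directory (the default upload location in Cloud Shell), move it into place and restrict its permissions:

mkdir -p ~/.kaggle
mv ~/kaggle.json ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json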

Create the Cloud Storage bucket

Create a Cloud Storage bucket to store the model checkpoints.

In Cloud Shell, run the following:

gcloud storage buckets create gs://CHECKPOINTS_BUCKET_NAME

Replace CHECKPOINTS_BUCKET_NAME with the name of the Cloud Storage bucket that stores the model checkpoints.
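
For example, the following command creates a hypothetical bucket named gemma-checkpoints-example in us-west1, the same region that hosts the TPU-backed endpoint later in this tutorial. The bucket name and the --location value are illustrative choices, not requirements:

gcloud storage buckets create gs://gemma-checkpoints-example --location=us-west1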

Copy the model to the Cloud Storage bucket

In Cloud Shell, run the following:

pip install kaggle --break-system-packages

# For Gemma 2B
mkdir -p /data/gemma_2b-it
kaggle models instances versions download google/gemma/pax/2b-it/1 --untar -p /data/gemma_2b-it
gcloud storage cp /data/gemma_2b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/ --recursive

# For Gemma 7B
mkdir -p /data/gemma_7b-it
kaggle models instances versions download google/gemma/pax/7b-it/1 --untar -p /data/gemma_7b-it
gcloud storage cp /data/gemma_7b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/ --recursive
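
Optionally, verify that the checkpoints were copied by listing the bucket contents. The following command checks the 2B model; use the gemma_7b-it/ prefix for the 7B model:

gcloud storage ls --recursive gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/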

Deploying the model

Upload a model

To upload a Model resource that uses the Saxml container, run the following gcloud ai models upload command:

Gemma 2B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma2BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

Gemma 7B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma7BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

Replace the following:

  • LOCATION: The region where you are using Vertex AI. Note that TPUs are only available in us-west1.
  • DEPLOYED_MODEL_NAME: A name for the DeployedModel. You can use the display name of the Model for the DeployedModel as well.
  • CHECKPOINTS_BUCKET_NAME: The name of the Cloud Storage bucket that stores the model checkpoints.
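
The upload can take a few minutes to complete. As an optional check, you can confirm that the Model resource exists before creating an endpoint:

gcloud ai models list \
  --region=LOCATION \
  --filter=display_name=DEPLOYED_MODEL_NAME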

Create an endpoint

You must deploy the model to an endpoint before the model can be used to serve online predictions. If you are deploying a model to an existing endpoint, you can skip this step. The following example uses the gcloud ai endpoints create command:

gcloud ai endpoints create \
  --region=LOCATION \
  --display-name=ENDPOINT_NAME

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.

The Google Cloud CLI tool might take a few seconds to create the endpoint.

Deploy the model to endpoint

After the endpoint is ready, deploy the model to the endpoint.

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

MODEL_ID=$(gcloud ai models list \
   --region=LOCATION \
   --filter=display_name=DEPLOYED_MODEL_NAME \
   --format="value(name)")

gcloud ai endpoints deploy-model $ENDPOINT_ID \
  --region=LOCATION \
  --model=$MODEL_ID \
  --display-name=DEPLOYED_MODEL_NAME \
  --machine-type=ct5lp-hightpu-4t \
  --traffic-split=0=100

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.
  • DEPLOYED_MODEL_NAME: A name for the DeployedModel. You can use the display name of the Model for the DeployedModel as well.

Gemma 2B can be deployed on the smaller ct5lp-hightpu-1t machine type. In that case, specify --platform_topology=1x1 when you upload the model and --machine-type=ct5lp-hightpu-1t when you deploy it.
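
For example, a sketch of the deploy command for that smaller configuration, assuming the Gemma 2B model was uploaded with --platform_topology=1x1:

gcloud ai endpoints deploy-model $ENDPOINT_ID \
  --region=LOCATION \
  --model=$MODEL_ID \
  --display-name=DEPLOYED_MODEL_NAME \
  --machine-type=ct5lp-hightpu-1t \
  --traffic-split=0=100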

The Google Cloud CLI tool might take a few minutes to deploy the model to the endpoint. When the model is successfully deployed, this command prints the following output:

  Deployed a model to the endpoint xxxxx. Id of the deployed model: xxxxx.

Getting online predictions from the deployed model

To invoke the model through the Vertex AI Prediction endpoint, format the prediction request by using a standard inference request JSON object.

The following example uses the gcloud ai endpoints predict command:

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

gcloud ai endpoints predict $ENDPOINT_ID \
  --region=LOCATION \
  --http-headers=Content-Type=application/json \
  --json-request instances.json

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.
  • instances.json: A local JSON file with the following format: {"instances": [{"text_batch": "<your prompt>"},{...}]}
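
For example, the following command creates a minimal instances.json with a single placeholder prompt (replace the prompt text with your own):

cat <<'EOF' > instances.json
{"instances": [{"text_batch": "What are the top five programming languages for data science?"}]}
EOF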

Cleaning up

To avoid incurring further charges, delete the Google Cloud resources that you created during this tutorial:

  1. To undeploy the model from the endpoint and delete the endpoint, run the following command in your shell:

    ENDPOINT_ID=$(gcloud ai endpoints list \
       --region=LOCATION \
       --filter=display_name=ENDPOINT_NAME \
       --format="value(name)")
    
    DEPLOYED_MODEL_ID=$(gcloud ai endpoints describe $ENDPOINT_ID \
       --region=LOCATION \
       --format="value(deployedModels.id)")
    
    gcloud ai endpoints undeploy-model $ENDPOINT_ID \
      --region=LOCATION \
      --deployed-model-id=$DEPLOYED_MODEL_ID
    
    gcloud ai endpoints delete $ENDPOINT_ID \
       --region=LOCATION \
       --quiet
    

    Replace LOCATION with the region where you created your model in a previous section.

  2. To delete your model, run the following command in your shell:

    MODEL_ID=$(gcloud ai models list \
       --region=LOCATION \
       --filter=display_name=DEPLOYED_MODEL_NAME \
       --format="value(name)")
    
    gcloud ai models delete $MODEL_ID \
       --region=LOCATION \
       --quiet
    

    Replace LOCATION with the region where you created your model in a previous section.
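
  3. If you no longer need the model checkpoints, you can also delete the Cloud Storage bucket that you created for this tutorial, together with its contents:

    gcloud storage rm --recursive gs://CHECKPOINTS_BUCKET_NAME

    Replace CHECKPOINTS_BUCKET_NAME with the name of the bucket that stores the model checkpoints.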

Limitations

  • On Vertex AI Prediction, Cloud TPUs are supported only in us-west1. For more information, see locations.

What's next

  • Learn how to deploy other Saxml models such as Llama 2 and GPT-J.