Serve an LLM using multi-host TPUs on GKE with Saxml


This tutorial shows you how to deploy and serve a large language model (LLM) using multi-host TPU slice node pools on Google Kubernetes Engine (GKE) with Saxml, for an efficient and scalable architecture.

Background

Saxml is an experimental system that serves models built with the Paxml, JAX, and PyTorch frameworks. You can use TPUs to accelerate data processing with these frameworks. To demonstrate the deployment of TPUs in GKE, this tutorial serves the 175B LmCloudSpmd175B32Test test model. GKE deploys this test model on two v5e TPU slice node pools, each with a 4x8 topology.

To properly deploy the test model, the TPU topology has been defined based on the size of the model. A model with N billion 16-bit parameters requires approximately 2 × N GB of memory (2 bytes per parameter), so the 175B LmCloudSpmd175B32Test model requires about 350 GB of memory. A single TPU v5e chip has 16 GB of high-bandwidth memory. To hold 350 GB, GKE needs at least 22 v5e TPU chips (350 / 16 ≈ 21.9, rounded up). Based on the mapping of TPU configurations, the smallest v5e topology that provides at least 22 chips is 4x8. The TPU configuration for this tutorial is therefore:

  • Machine type: ct5lp-hightpu-4t
  • Topology: 4x8 (32 TPU chips)
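You can script the sizing rule above as a quick sanity check. The following is a minimal sketch under these assumptions:

- Each parameter takes 2 bytes.
- Each v5e chip has 16 GB of memory.
- The v5e slice topologies are the ones hard-coded in the list below. Confirm this list against Plan your TPU configuration.

The sketch also ignores the extra memory needed for activations and the KV cache. The function names are hypothetical helpers, not part of any tool.

```shell
# Rough TPU v5e sizing for serving a model with 16-bit (2-byte) weights.
# Assumptions: 16 GB of memory per v5e chip; weights only, with no headroom
# for activations or the KV cache.
HBM_PER_CHIP_GB=16

# Memory for the weights, in GB, of a model with $1 billion parameters.
model_memory_gb() {
  echo $(( $1 * 2 ))
}

# Minimum number of v5e chips that can hold the weights (ceiling division).
min_v5e_chips() {
  local mem
  mem=$(model_memory_gb "$1")
  echo $(( (mem + HBM_PER_CHIP_GB - 1) / HBM_PER_CHIP_GB ))
}

# Smallest topology, from an assumed list of v5e slice shapes, with >= $1 chips.
smallest_v5e_topology() {
  local topology
  for topology in 1x1 2x2 2x4 4x4 4x8 8x8 8x16 16x16; do
    # "4x8" -> 4 * 8 = 32 chips
    if [ $(( ${topology%x*} * ${topology#*x} )) -ge "$1" ]; then
      echo "$topology"
      return 0
    fi
  done
  return 1
}

chips=$(min_v5e_chips 175)
echo "175B model: $(model_memory_gb 175) GB, at least ${chips} chips, topology $(smallest_v5e_topology "$chips")"
```

Running the sketch prints `175B model: 350 GB, at least 22 chips, topology 4x8`, which matches the configuration above.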

Selecting the right TPU topology for serving a model is important when deploying TPUs in GKE. To learn more, see Plan your TPU configuration.

Objectives

This tutorial is intended for MLOps or DevOps engineers or platform administrators who want to use GKE orchestration capabilities for serving machine learning models.

This tutorial covers the following steps:

  1. Prepare your environment with a GKE Standard cluster. The cluster has two v5e TPU slice node pools with 4x8 topology.
  2. Deploy Saxml. Saxml needs an administrator server, a group of Pods that work as the model server, a prebuilt HTTP server, and a load balancer.
  3. Use Saxml to serve the LLM.

The following diagram shows the architecture that this tutorial implements:

Figure: Example architecture of a multi-host TPU on GKE.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the Google Kubernetes Engine API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. Click Grant access.
    4. In the New principals field, enter your user identifier. This is typically the email address for a Google Account.

    5. In the Select a role list, select a role.
    6. To grant additional roles, click Add another role and add each additional role.
    7. Click Save.

Prepare the environment

  1. In the Google Cloud console, start a Cloud Shell instance:
    Open Cloud Shell

  2. Set the default environment variables:

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export REGION=COMPUTE_REGION
      export ZONE=COMPUTE_ZONE
      export GSBUCKET=${PROJECT_ID}-gke-bucket
    

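    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • COMPUTE_REGION: the Compute Engine region.
    • COMPUTE_ZONE: a zone where the ct5lp-hightpu-4t machine type is available.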
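Later commands interpolate these variables. An empty value, or a placeholder left unreplaced, makes those commands fail in confusing ways. The following is a minimal sketch of a fail-fast check. `require_vars` is a hypothetical helper, not a gcloud command.

```shell
# Fail fast if any named environment variable is unset, empty, or still set
# to its placeholder value. require_vars is a hypothetical helper.
require_vars() {
  local name value status=0
  for name in "$@"; do
    # Indirect lookup of the variable whose name is stored in $name.
    eval "value=\${$name:-}"
    case "$value" in
      ""|PROJECT_ID|COMPUTE_REGION|COMPUTE_ZONE)
        echo "error: ${name} is not set to a real value" >&2
        status=1
        ;;
    esac
  done
  return "$status"
}

# Usage, after the exports above:
#   require_vars PROJECT_ID REGION ZONE GSBUCKET || echo "fix the variables first"
```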

Create a GKE Standard cluster

Use Cloud Shell to do the following:

  1. Create a Standard cluster that uses Workload Identity Federation for GKE:

    gcloud container clusters create saxml \
        --zone=${ZONE} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    Replace VERSION with the GKE version number. GKE supports TPU v5e in version 1.27.2-gke.2100 and later. For more information, see TPU availability in GKE.

    Cluster creation might take several minutes.

  2. Create the first node pool named tpu1:

    gcloud container node-pools create tpu1 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    
  3. Create the second node pool named tpu2:

    gcloud container node-pools create tpu2 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
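GKE version strings such as 1.27.2-gke.2100 can't be compared as plain text, but GNU `sort -V` orders them correctly. The following is a minimal sketch that checks a version against the TPU v5e minimum mentioned above. `version_at_least` is a hypothetical helper, and the check assumes GNU coreutils, as available in Cloud Shell.

```shell
# Succeeds if GKE version $1 is greater than or equal to version $2.
# GNU sort -V compares the numeric parts of version strings numerically.
version_at_least() {
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n 1)" = "$2" ]
}

# Minimum GKE version for TPU v5e, as stated in this tutorial.
TPU_V5E_MIN_GKE_VERSION="1.27.2-gke.2100"

# Example: check the control plane version of the saxml cluster.
#   current=$(gcloud container clusters describe saxml --zone=${ZONE} \
#       --format='value(currentMasterVersion)')
#   version_at_least "$current" "$TPU_V5E_MIN_GKE_VERSION" && echo "TPU v5e supported"
```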
    

You have created the following resources:

  • A Standard cluster with four CPU nodes.
  • Two v5e TPU slice node pools with 4x8 topology. Each node pool contains eight TPU slice nodes with four TPU chips each.

The 175B model must be served on a multi-host v5e TPU slice with at least a 4x8 topology (32 v5e TPU chips).
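The node counts used when you created the node pools follow from the topology. The following short sketch shows the arithmetic. It assumes 4 TPU chips per ct5lp-hightpu-4t node and two slices, one per node pool. The helper names are hypothetical.

```shell
# Number of TPU chips in a 2D topology string such as "4x8".
chips_in_topology() {
  echo $(( ${1%x*} * ${1#*x} ))
}

# Nodes needed for topology $1 when each node has $2 TPU chips.
nodes_for_topology() {
  echo $(( $(chips_in_topology "$1") / $2 ))
}

CHIPS_PER_NODE=4   # ct5lp-hightpu-4t
SLICES=2           # node pools tpu1 and tpu2

# Value for --num-nodes in each node pool, and for the JobSet
# parallelism and completions used later in this tutorial.
NODES_PER_SLICE=$(nodes_for_topology 4x8 "$CHIPS_PER_NODE")
echo "nodes per node pool: ${NODES_PER_SLICE}"
echo "model server Pods in total: $(( NODES_PER_SLICE * SLICES ))"
```

The sketch prints `nodes per node pool: 8` and `model server Pods in total: 16`.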

Create a Cloud Storage bucket

Create a Cloud Storage bucket to store Saxml administrator server configurations. A running administrator server periodically saves its state and the details of the published models.

In Cloud Shell, run the following:

gcloud storage buckets create gs://${GSBUCKET}

Configure workload access using Workload Identity Federation for GKE

Assign a Kubernetes ServiceAccount to the application and configure that Kubernetes ServiceAccount to act as an IAM service account.

  1. Configure kubectl to communicate with your cluster:

    gcloud container clusters get-credentials saxml --zone=${ZONE}
    
  2. Create a Kubernetes ServiceAccount for your application to use:

    kubectl create serviceaccount sax-sa --namespace default
    
  3. Create an IAM service account for your application:

    gcloud iam service-accounts create sax-iam-sa
    
  4. Add an IAM policy binding for your IAM service account to read and write to Cloud Storage:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. Allow the Kubernetes ServiceAccount to impersonate the IAM service account by adding an IAM policy binding between the two service accounts. This binding allows the Kubernetes ServiceAccount to act as the IAM service account, so that the Kubernetes ServiceAccount can read and write to Cloud Storage.

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. Annotate the Kubernetes ServiceAccount with the email address of the IAM service account. This annotation lets your app know which IAM service account to use to access Google Cloud services. When the app uses any standard Google API client library, it authenticates as that IAM service account.

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
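The steps above connect three identities:

- the Kubernetes ServiceAccount default/sax-sa
- its Workload Identity Federation principal
- the IAM service account email

The following small sketch builds these strings from their parts, assuming the naming used in this tutorial. The function names are hypothetical. The commented verification commands use standard kubectl jsonpath output and `gcloud iam service-accounts get-iam-policy`.

```shell
# Email address of an IAM service account: $1 = account name, $2 = project ID.
iam_sa_email() {
  echo "$1@$2.iam.gserviceaccount.com"
}

# Principal for a Kubernetes ServiceAccount in the workload identity pool
# PROJECT_ID.svc.id.goog: $1 = project ID, $2 = namespace, $3 = KSA name.
wi_member() {
  echo "serviceAccount:$1.svc.id.goog[$2/$3]"
}

# Verify the setup (requires access to the cluster and project):
#   kubectl get serviceaccount sax-sa \
#     -o jsonpath='{.metadata.annotations.iam\.gke\.io/gcp-service-account}'
#   gcloud iam service-accounts get-iam-policy "$(iam_sa_email sax-iam-sa "${PROJECT_ID}")"
```

The first verification command should print the IAM service account email. The policy should list roles/iam.workloadIdentityUser for the member that `wi_member "${PROJECT_ID}" default sax-sa` prints.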
    

Deploy Saxml

In this section, you deploy the Saxml administrator server and the Saxml model server.

Deploy the Saxml administrator server

  1. Create the following sax-admin-server.yaml manifest:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME

    Replace BUCKET_NAME with the name of your Cloud Storage bucket.

  2. Apply the manifest:

    kubectl apply -f sax-admin-server.yaml
    
  3. Verify that the administrator server Pod is up and running:

    kubectl get deployment
    

    The output is similar to the following:

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
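The remaining manifests in this tutorial also contain the BUCKET_NAME placeholder. Instead of editing each file by hand, you can substitute the bucket name when you apply the manifest. The following is a minimal sketch. It assumes that the GSBUCKET variable holds only the bucket name, without the gs:// prefix. `render_manifest` is a hypothetical helper.

```shell
# Print manifest $1 with every BUCKET_NAME placeholder replaced by bucket $2.
# The original file is not modified. "|" is safe as the sed delimiter because
# Cloud Storage bucket names can't contain it.
render_manifest() {
  sed "s|BUCKET_NAME|$2|g" "$1"
}

# Usage: render a manifest and pipe it to kubectl ("-f -" reads stdin).
#   render_manifest sax-model-server-set.yaml "${GSBUCKET}" | kubectl apply -f -
```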
    

Deploy the Saxml model server

Workloads running in multi-host TPU slices require a stable network identifier for each Pod to discover peers in the same TPU slice. To define these identifiers, use one of the following:

  • an Indexed Job
  • a StatefulSet with a headless Service
  • a JobSet, which automatically creates a headless Service for all the Jobs that belong to it

The following section shows how to manage multiple groups of model server Pods with JobSet.

  1. Install JobSet v0.2.3 or later.

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Replace JOBSET_VERSION with the JobSet version, for example, v0.2.3.

  2. Verify that the JobSet controller is running in the jobset-system namespace:

    kubectl get pod -n jobset-system
    

    The output is similar to the following:

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. Deploy two model servers in the two TPU slice node pools. Save the following sax-model-server-set.yaml manifest:

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4

    Replace BUCKET_NAME with the name of your Cloud Storage bucket.

    In this manifest:

    • replicas: 2 is the number of Job replicas. Each Job represents one model server, which is a group of 8 Pods.
    • parallelism: 8 and completions: 8 are equal to the number of nodes in each node pool.
    • backoffLimit: 0 must be zero so that the Job is marked as failed if any Pod fails.
    • ports.containerPort: 8471 is the default port for communication between the TPU VMs.
    • name: MEGASCALE_NUM_SLICES is set to an empty value to unset the environment variable, because GKE isn't running Multislice training.

  4. Apply the manifest:

    kubectl apply -f sax-model-server-set.yaml
    
  5. Verify the status of the Saxml Admin Server and Model Server Pods:

    kubectl get pods
    

    The output is similar to the following:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

In this example, there are 16 model server containers. sax-model-server-set-sax-model-server-0-0-nj4sm and sax-model-server-set-sax-model-server-1-0-8mlf5 are the primary model servers of the two groups.

Your Saxml cluster has two model servers deployed on two v5e TPU slice node pools, each with a 4x8 topology.
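Pod names end in a random suffix, so the primary Pod names differ in your cluster. The following is a minimal sketch that picks out the primary Pod of each group. It assumes the Pod name pattern visible in the kubectl get pods output above: JobSet name, replicated Job name, Job index, Pod index, then a random suffix. `primary_model_server_pods` is a hypothetical helper.

```shell
# Read Pod names on stdin and print those with Pod index 0, that is, the
# primary model server of each Job (group) in sax-model-server-set.
primary_model_server_pods() {
  grep -E '^sax-model-server-set-sax-model-server-[0-9]+-0-[a-z0-9]+$'
}

# Usage: list Pod names only, then filter them.
#   kubectl get pods --no-headers -o custom-columns=NAME:.metadata.name \
#     | primary_model_server_pods
```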

Deploy the Saxml HTTP server and load balancer

  1. Use the following prebuilt HTTP server image. Save the following sax-http.yaml manifest:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer

    Replace BUCKET_NAME with the name of your Cloud Storage bucket.

  2. Apply the sax-http.yaml manifest:

    kubectl apply -f sax-http.yaml
    
  3. Wait for the HTTP server Pod to be running:

    kubectl get pods
    

    The output is similar to the following:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  4. Wait for the Service to have an external IP address assigned:

    kubectl get svc
    

    The output is similar to the following:

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
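The external IP address can take a few minutes to appear. Instead of rerunning kubectl get svc by hand, you can poll for it. The following is a minimal polling sketch. `wait_for_output` is a hypothetical helper, and the usage line reads the first load balancer ingress IP address of the sax-http-lb Service.

```shell
# Run a command repeatedly until it prints a non-empty value, then print that
# value. Arguments: maximum attempts, seconds between attempts, command.
wait_for_output() {
  local max_attempts=$1 sleep_seconds=$2 attempt=1 value
  shift 2
  while [ "$attempt" -le "$max_attempts" ]; do
    # Ignore errors from the command (for example, transient API errors).
    value=$("$@" 2>/dev/null) || value=""
    if [ -n "$value" ]; then
      printf '%s\n' "$value"
      return 0
    fi
    sleep "$sleep_seconds"
    attempt=$((attempt + 1))
  done
  echo "timed out waiting for output from: $*" >&2
  return 1
}

# Usage: wait up to about 5 minutes for the load balancer IP address.
#   LB_IP=$(wait_for_output 30 10 kubectl get svc sax-http-lb \
#       -o jsonpath='{.status.loadBalancer.ingress[0].ip}')
```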
    

Use Saxml

Load, deploy, and serve the model with Saxml on the multi-host v5e TPU slices:

Load the model

  1. Retrieve the load balancer IP address for Saxml.

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. Load the LmCloudSpmd175B32Test test model in the two v5e TPU slice node pools:

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    The test model doesn't have a fine-tuned checkpoint; its weights are randomly generated. Loading the model can take up to 10 minutes.

    The output is similar to the following:

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
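  3. Check the model readiness by viewing the logs of a primary model server Pod. Use the name of a primary Pod from your own kubectl get pods output: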

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    The output is similar to the following:

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    The model is fully loaded.

  4. Get information about the model:

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    The output is similar to the following:

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
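Loading can take several minutes, so you might want to poll instead of rerunning these commands by hand. The following is a minimal sketch with two hypothetical helpers:

- `is_model_loaded` looks for the log line shown in step 3.
- `all_replicas_active` compares active_replicas with max_replicas in a listcell response. It uses sed and assumes the flat JSON shape shown in step 4.

If jq is installed, `jq -e '.active_replicas == .max_replicas'` is a sturdier check.

```shell
# Succeeds if the log text on stdin reports that model key $1 finished loading.
is_model_loaded() {
  grep -qF "Successfully loaded model for key: $1"
}

# Print the integer value of top-level field $1 from flat JSON on stdin.
json_int_field() {
  sed -n "s/.*\"$1\": *\([0-9][0-9]*\).*/\1/p" | head -n 1
}

# Succeeds if a listcell response on stdin reports all replicas as active.
all_replicas_active() {
  local json active max
  json=$(cat)
  active=$(printf '%s\n' "$json" | json_int_field active_replicas)
  max=$(printf '%s\n' "$json" | json_int_field max_replicas)
  [ -n "$active" ] && [ "$active" = "$max" ]
}

# Usage (replace POD with your primary model server Pod name):
#   until kubectl logs "${POD}" | is_model_loaded /sax/test/spmd; do sleep 30; done
#   curl --request GET --header "Content-type: application/json" \
#     -s "${LB_IP}:${PORT}/listcell" --data '{"model": "/sax/test/spmd"}' \
#     | all_replicas_active && echo "all replicas active"
```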
    

Serve the model

Serve a prompt request:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

The output contains the model response. This response might not be meaningful because the test model has random weights.
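The inline JSON in the curl command above breaks if a prompt contains double quotes. The following is a minimal sketch that builds the /generate request body from a model key and a prompt, escaping backslashes and double quotes. `generate_body` is a hypothetical helper, and it assumes a single-line prompt without other control characters.

```shell
# Print a /generate request body for model key $1 and prompt $2.
# Escapes backslashes first, then double quotes, so the result stays valid
# JSON. Newlines and other control characters are not handled.
generate_body() {
  local escaped
  escaped=$(printf '%s' "$2" | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g')
  printf '{"model": "%s", "query": "%s"}\n' "$1" "$escaped"
}

# Usage:
#   curl --request POST --header "Content-type: application/json" \
#     -s "${LB_IP}:${PORT}/generate" \
#     --data "$(generate_body /sax/test/spmd 'What does "TPU" stand for?')"
```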

Unpublish the model

Run the following command to unpublish the model:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

The output is similar to the following:

{
  "model": "/sax/test/spmd"
}

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the deployed resources

  1. Delete the cluster you created for this tutorial:

    gcloud container clusters delete saxml --zone ${ZONE}
    
  2. Delete the service account:

    gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    
  3. Delete the Cloud Storage bucket:

    gcloud storage rm -r gs://${GSBUCKET}
    

What's next