Serve Stable Diffusion XL (SDXL) using TPUs on GKE with MaxDiffusion


This tutorial shows you how to serve a Stable Diffusion XL (SDXL) image generation model using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with MaxDiffusion. In this tutorial, you download the model from Hugging Face and deploy it on an Autopilot or Standard cluster using a container that runs MaxDiffusion.

This guide is a good starting point if you need the granular control, customization, scalability, resilience, portability, and cost-effectiveness of managed Kubernetes when deploying and serving your AI/ML workloads. If you need a unified managed AI platform to rapidly build and serve ML models cost effectively, we recommend that you try our Vertex AI deployment solution.

Background

By serving SDXL using TPUs on GKE with MaxDiffusion, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost-efficiency, scalability, and higher availability. This section describes the key technologies used in this tutorial.

Stable Diffusion XL (SDXL)

Stable Diffusion XL (SDXL) is a type of latent diffusion model (LDM) supported by MaxDiffusion for inference. For generative AI, you can use LDMs to generate high-quality images from text descriptions. LDMs are useful for applications such as image search and image captioning.

SDXL supports single-host or multi-host inference with sharding annotations. This lets SDXL be trained and run across multiple machines, which can improve efficiency.

To learn more, see the Generative Models by Stability AI repository and the SDXL paper.

TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

This tutorial covers serving the SDXL model. GKE deploys the model on single-host TPU v5e nodes with TPU topologies configured based on the model requirements for serving prompts with low latency. In this guide, the model uses a TPU v5e chip with a 1x1 topology.
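
If you want to confirm where the single-host TPU v5e machine types used in this tutorial are available before you create resources, you can query Compute Engine. The following is a minimal sketch; the us-west4-a zone is only an example:

# List single-host TPU v5e machine types (ct5lp-*) in a candidate zone.
# ct5lp-hightpu-1t corresponds to the 1x1 topology used in this tutorial.
gcloud compute machine-types list \
    --filter="name~ct5lp" \
    --zones=us-west4-a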

MaxDiffusion

MaxDiffusion is a collection of reference implementations, written in Python and JAX, of various latent diffusion models that run on XLA devices, including TPUs and GPUs. MaxDiffusion is a starting point for diffusion projects for both research and production.

To learn more, refer to the MaxDiffusion repository.

Objectives

This tutorial is intended for generative AI customers who use JAX, new or existing users of SDXL, and any ML engineers, MLOps (DevOps) engineers, or platform administrators who are interested in using Kubernetes container orchestration capabilities for serving AI/ML workloads.

This tutorial covers the following steps:

  1. Create a GKE Autopilot or Standard cluster with the recommended TPU topology, based on the model characteristics.
  2. Build an SDXL inference container image.
  3. Deploy the SDXL inference server on GKE.
  4. Serve and interact with the model through a web app.

Architecture

This section describes the GKE architecture used in this tutorial. The architecture consists of a GKE Autopilot or Standard cluster that provisions TPUs and hosts MaxDiffusion components. GKE uses these components to deploy and serve the models.

The following diagram shows you the components of this architecture:

Example architecture of serving MaxDiffusion with TPU v5e on GKE.

This architecture includes the following components:

  • A GKE Autopilot or Standard regional cluster.
  • One single-host TPU slice node pool that hosts the SDXL model on the MaxDiffusion deployment.
  • A Service of type ClusterIP. This Service distributes inbound traffic to all MaxDiffusion HTTP replicas.
  • The WebApp HTTP server with an external LoadBalancer Service that distributes inbound traffic and redirects model serving traffic to the ClusterIP Service.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. Click Grant access.
    4. In the New principals field, enter your user identifier. This is typically the email address for a Google Account.

    5. In the Select a role list, select a role.
    6. To grant additional roles, click Add another role and add each additional role.
    7. Click Save.
  • Ensure that you have sufficient quota for TPU v5e PodSlice Lite chips. In this tutorial, you use on-demand instances.

Prepare the environment

In this tutorial, you use Cloud Shell to manage resources hosted on Google Cloud. Cloud Shell comes preinstalled with the software that you need for this tutorial, including kubectl and the gcloud CLI.

To set up your environment with Cloud Shell, follow these steps:

  1. In the Google Cloud console, launch a Cloud Shell session by clicking Activate Cloud Shell. This launches a session in the bottom pane of the Google Cloud console.

  2. Set the default environment variables:

    gcloud config set project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION_NAME
    export ZONE=ZONE
    

    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • CLUSTER_NAME: the name of your GKE cluster.
    • REGION_NAME: the region where your GKE cluster, Cloud Storage bucket, and TPU nodes are located. The region contains zones where TPU v5e machine types are available (for example, us-west1, us-west4, us-central1, us-east1, us-east5, or europe-west4).
    • (Standard cluster only) ZONE: the zone where the TPU resources are available (for example, us-west4-a). For Autopilot clusters, you don't need to specify the zone, only the region.
  3. Clone the example repository and open the tutorial directory:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples
    cd kubernetes-engine-samples/ai-ml/maxdiffusion-tpu 
    WORK_DIR=$(pwd)
    gcloud artifacts repositories create gke-llm --repository-format=docker --location=$REGION
    gcloud auth configure-docker $REGION-docker.pkg.dev
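
Optionally, before you continue, you can verify that the environment variables and the Artifact Registry repository are set up as expected; a quick sanity check:

# Confirm the variables used throughout this tutorial are set.
echo "Project: ${PROJECT_ID}, Cluster: ${CLUSTER_NAME}, Region: ${REGION}"

# Confirm that the gke-llm repository was created.
gcloud artifacts repositories describe gke-llm --location=${REGION}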
    

Create and configure Google Cloud resources

Follow these instructions to create the required resources.

Create a GKE cluster

You can serve SDXL on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see Choose a GKE mode of operation.

Autopilot

  1. In Cloud Shell, run the following command:

    gcloud container clusters create-auto ${CLUSTER_NAME} \
      --project=${PROJECT_ID} \
      --region=${REGION} \
      --release-channel=rapid \
      --cluster-version=1.29
    

    GKE creates an Autopilot cluster with CPU and TPU nodes as requested by the deployed workloads.

  2. Configure kubectl to communicate with your cluster:

      gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    

Standard

  1. Create a regional GKE Standard cluster that uses Workload Identity Federation for GKE.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    The cluster creation might take several minutes.

  2. Run the following command to create a node pool for your cluster:

    gcloud container node-pools create maxdiffusion-tpu-nodepool \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct5lp-hightpu-1t \
      --num-nodes=1 \
      --region=${REGION} \
      --node-locations=${ZONE}
    

    GKE creates a TPU v5e node pool with a 1x1 topology and one node.

    To create node pools with different topologies, learn how to Plan your TPU configuration. Make sure that you update the sample values in this tutorial, such as cloud.google.com/gke-tpu-topology and google.com/tpu.

  3. Configure kubectl to communicate with your cluster:

      gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
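
Whichever mode you chose, you can optionally confirm that kubectl points at the new cluster before you continue:

# Verify the connection and list the cluster nodes. In a Standard cluster,
# the TPU node pool uses the ct5lp-hightpu-1t machine type; in an Autopilot
# cluster, TPU nodes appear only after workloads request them.
kubectl cluster-info
kubectl get nodes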
    

Build the SDXL inference container

Follow these instructions to build a container image for the SDXL inference server.

  1. Open the build/server/cloudbuild.yaml manifest:

    steps:
    - name: 'gcr.io/cloud-builders/docker'
      args: [ 'build', '-t', '$LOCATION-docker.pkg.dev/$PROJECT_ID/gke-llm/max-diffusion:latest', '.' ]
    images:
    - '$LOCATION-docker.pkg.dev/$PROJECT_ID/gke-llm/max-diffusion:latest'
  2. Execute the build and create the inference container image:

    cd $WORK_DIR/build/server
    gcloud builds submit . --region=$REGION
    

    The output contains the path of the container image.
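
To confirm that the image was pushed successfully, you can list the images in the repository:

# List the images in the gke-llm repository; the max-diffusion image
# built in the previous step should appear in the output.
gcloud artifacts docker images list ${REGION}-docker.pkg.dev/${PROJECT_ID}/gke-llm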

Deploy the SDXL inference server

  1. Explore the serve_sdxl_v5e.yaml manifest:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: stable-diffusion-deployment
    spec:
      selector:
        matchLabels:
          app: max-diffusion-server
      replicas: 1  # number of nodes in node-pool
      template:
        metadata:
          labels:
            app: max-diffusion-server
        spec:
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 1x1 #  target topology
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
            #cloud.google.com/gke-spot: "true"
          volumes:
          - name: dshm
            emptyDir:
              medium: Memory
          containers:
          - name: serve-stable-diffusion
            image: REGION-docker.pkg.dev/PROJECT_ID/gke-llm/max-diffusion:latest
            env:
            - name: MODEL_NAME
              value: 'stable_diffusion'
            ports:
            - containerPort: 8000
            resources:
              requests:
                google.com/tpu: 1  # TPU chip request
              limits:
                google.com/tpu: 1  # TPU chip request
            volumeMounts:
            - mountPath: /dev/shm
              name: dshm
    
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: max-diffusion-server
      labels:
        app: max-diffusion-server
    spec:
      type: ClusterIP
      ports:
        - port: 8000
          targetPort: 8000
          name: http-max-diffusion-server
          protocol: TCP
      selector:
        app: max-diffusion-server
  2. Update the project ID and region in the manifest:

    cd $WORK_DIR
    sed -i "s|PROJECT_ID|$PROJECT_ID|g" serve_sdxl_v5e.yaml
    sed -i "s|REGION|$REGION|g" serve_sdxl_v5e.yaml
    
  3. Apply the manifest:

    kubectl apply -f serve_sdxl_v5e.yaml
    

    The output is similar to the following:

    deployment.apps/stable-diffusion-deployment created
    service/max-diffusion-server created
    
  4. Verify the status of the model:

    watch kubectl get deploy
    

    The output is similar to the following:

    NAME                          READY   UP-TO-DATE   AVAILABLE   AGE
    stable-diffusion-deployment   1/1     1            1           8m21s
    
  5. Retrieve the ClusterIP address:

    kubectl get service max-diffusion-server
    

    The output contains a CLUSTER-IP field. Make a note of the CLUSTER-IP value.

  6. Validate the Deployment:

     export ClusterIP=CLUSTER_IP
     kubectl run curl --image=curlimages/curl \
        -it --rm --restart=Never \
        -- "$ClusterIP:8000"
    

    Replace CLUSTER_IP with the CLUSTER-IP value that you noted previously. (A port-forward alternative is sketched after this list.) The output is similar to the following:

    {"message":"Hello world! From FastAPI running on Uvicorn with Gunicorn."}
    pod "curl" deleted
    
  7. View the logs from the Deployment:

    kubectl logs -l app=max-diffusion-server
    

    When the Deployment finishes, the output is similar to the following:

    2024-06-12 15:45:45,459 [INFO] __main__: replicate params:
    2024-06-12 15:45:46,175 [INFO] __main__: start initialized compiling
    2024-06-12 15:45:46,175 [INFO] __main__: Compiling ...
    2024-06-12 15:45:46,175 [INFO] __main__: aot compiling:
    2024-06-12 15:45:46,176 [INFO] __main__: tokenize prompts:
    2024-06-12 15:48:49,093 [INFO] __main__: Compiled in 182.91802048683167
    INFO:     Started server process [1]
    INFO:     Waiting for application startup.
    INFO:     Application startup complete.
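
As an alternative to the in-cluster curl Pod from step 6, you can port-forward the ClusterIP Service to your workstation and test the server locally; a minimal sketch:

# Forward local port 8000 to the max-diffusion-server Service
# (run in the background here; use a second terminal if you prefer).
kubectl port-forward service/max-diffusion-server 8000:8000 &

# Hit the server's root endpoint; expect the FastAPI hello message.
curl http://localhost:8000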
    

Deploy the webapp client

In this section, you deploy the webapp client to serve the SDXL model.

  1. Explore the build/webapp/cloudbuild.yaml manifest:

    steps:
    - name: 'gcr.io/cloud-builders/docker'
      args: [ 'build', '-t', '$LOCATION-docker.pkg.dev/$PROJECT_ID/gke-llm/max-diffusion-web:latest', '.' ]
    images:
    - '$LOCATION-docker.pkg.dev/$PROJECT_ID/gke-llm/max-diffusion-web:latest'
  2. Execute the build and create the client container image from the build/webapp directory:

    cd $WORK_DIR/build/webapp
    gcloud builds submit . --region=$REGION
    

    The output contains the path of the container image.

  3. Open the serve_sdxl_client.yaml manifest:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: max-diffusion-client
    spec:
      selector:
        matchLabels:
          app: max-diffusion-client
      template:
        metadata:
          labels:
            app: max-diffusion-client
        spec:
          containers:
          - name: webclient
            image: REGION-docker.pkg.dev/PROJECT_ID/gke-llm/max-diffusion-web:latest
            env:
              - name: SERVER_URL
                value: "http://ClusterIP:8000"
            resources:
              requests:
                memory: "128Mi"
                cpu: "250m"
              limits:
                memory: "256Mi"
                cpu: "500m"
            ports:
            - containerPort: 5000
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: max-diffusion-client-service
    spec:
      type: LoadBalancer
      selector:
        app: max-diffusion-client
      ports:
      - port: 8080
        targetPort: 5000
  4. Update the project ID, region, and ClusterIP address in the manifest:

    cd $WORK_DIR
    sed -i "s|PROJECT_ID|$PROJECT_ID|g" serve_sdxl_client.yaml
    sed -i "s|ClusterIP|$ClusterIP|g" serve_sdxl_client.yaml
    sed -i "s|REGION|$REGION|g" serve_sdxl_client.yaml
    
  5. Apply the manifest:

    kubectl apply -f serve_sdxl_client.yaml
    
  6. Retrieve the LoadBalancer IP address:

    kubectl get service max-diffusion-client-service
    

    The output contains an EXTERNAL-IP field. Make a note of the EXTERNAL-IP value.
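
Provisioning the external load balancer can take a few minutes. If the EXTERNAL-IP column shows <pending>, you can watch the Service until an address is assigned:

# Watch the Service until an external IP is assigned (Ctrl+C to stop).
kubectl get service max-diffusion-client-service --watch

# Alternatively, extract just the IP address once it is assigned.
kubectl get service max-diffusion-client-service \
    -o jsonpath='{.status.loadBalancer.ingress[0].ip}'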

Interact with the model by using the web page

  1. Access the following URL from a web browser:

    http://EXTERNAL_IP:8080
    

    Replace EXTERNAL_IP with the EXTERNAL-IP value that you noted previously.

  2. Interact with SDXL using the chat interface. Add a prompt and click Submit. For example:

    Create a detailed image of a fictional historical site, capturing its unique architecture and cultural significance
    

The output is a model-generated image similar to the following example:

SDXL-generated image
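
If the page doesn't load, you can check from Cloud Shell that the web app is reachable; expect an HTTP 200 status code:

# Replace EXTERNAL_IP with the EXTERNAL-IP value that you noted previously.
curl -s -o /dev/null -w "%{http_code}\n" http://EXTERNAL_IP:8080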

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the project

  1. In the Google Cloud console, go to the Manage resources page.

    Go to Manage resources

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

Delete the individual resources

If you want to keep the project, delete the individual resources. Run the following command and follow the prompts:

gcloud container clusters delete ${CLUSTER_NAME} --region=${REGION}
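
Deleting the cluster doesn't delete the container images that you built. If you also want to remove the Artifact Registry repository created in this tutorial, run:

# Delete the gke-llm repository and all images it contains.
gcloud artifacts repositories delete gke-llm --location=${REGION}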

What's next