Serve Stable Diffusion XL (SDXL) using TPUs on GKE with MaxDiffusion

This tutorial shows you how to serve an SDXL image generation model using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with MaxDiffusion. In this tutorial, you download the model from Hugging Face and deploy it on an Autopilot or Standard cluster using a container that runs MaxDiffusion.

This guide is a good starting point if you need the granular control, customization, scalability, resilience, portability, and cost-effectiveness of managed Kubernetes when deploying and serving your AI/ML workloads. If you need a unified managed AI platform to rapidly build and serve ML models cost effectively, we recommend that you try our Vertex AI deployment solution.

Background

By serving SDXL using TPUs on GKE with MaxDiffusion, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost efficiency, scalability, and high availability. This section describes the key technologies used in this tutorial.

Stable Diffusion XL (SDXL)

Stable Diffusion XL (SDXL) is a type of latent diffusion model (LDM) supported by MaxDiffusion for inference. For generative AI, you can use LDMs to generate high-quality images from text descriptions. LDMs are useful for applications such as image search and image captioning.

SDXL supports single or multi-host inference with sharding annotations. This lets SDXL be trained and run across multiple machines, which can improve efficiency.
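As a rough illustration (plain Python, not the MaxDiffusion sharding API), multi-host inference amounts to partitioning work across devices. The toy helper below sketches data-parallel splitting of a prompt batch; real sharding in MaxDiffusion is expressed with JAX sharding annotations over model and activation tensors.

```python
def shard_batch(prompts, num_devices):
    """Toy sketch: round-robin a batch of prompts across devices
    for data-parallel inference. This is an illustration only,
    not how MaxDiffusion shards tensors internally."""
    shards = [[] for _ in range(num_devices)]
    for i, prompt in enumerate(prompts):
        shards[i % num_devices].append(prompt)
    return shards

# Four prompts split across two single-host workers:
print(shard_batch(["p0", "p1", "p2", "p3"], 2))
# → [['p0', 'p2'], ['p1', 'p3']]
```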

To learn more, see the Generative Models by Stability AI repository and the SDXL paper.

TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

This tutorial covers serving the SDXL model. GKE deploys the model on single-host TPU v5e nodes with TPU topologies configured based on the model requirements for serving prompts with low latency. In this guide, the model uses a TPU v5e chip with a 1x1 topology.
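As a sketch, the accelerator type and topology are typically requested through node selectors in the Deployment manifest. The fragment below uses GKE's documented TPU node labels; the container name is a placeholder, and you should verify the label values against the TPU availability in your cluster's region:

```yaml
# Pod spec fragment (sketch): request a single-host TPU v5e slice
# with a 1x1 topology and one TPU chip.
nodeSelector:
  cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
  cloud.google.com/gke-tpu-topology: 1x1
containers:
  - name: sdxl-server  # placeholder name
    resources:
      limits:
        google.com/tpu: "1"
```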

MaxDiffusion

MaxDiffusion is a collection of reference implementations, written in Python and JAX, of various latent diffusion models that run on XLA devices, including TPUs and GPUs. MaxDiffusion is a starting point for diffusion projects, for both research and production.

To learn more, refer to the MaxDiffusion repository.