This tutorial shows you how to serve a Gemma large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE). You deploy a pre-built container with JetStream and MaxText to GKE. You also configure GKE to load the Gemma 7B weights from Cloud Storage at runtime.
This tutorial is intended for Machine learning (ML) engineers, Platform admins and operators, and Data and AI specialists who are interested in using Kubernetes container orchestration capabilities to serve LLMs. To learn more about common roles and example tasks that we reference in Google Cloud content, see Common GKE user roles and tasks.
Before reading this page, ensure that you're familiar with the following:
- Autopilot mode and Standard mode
- Current TPU version availability with the Cloud TPU system architecture
- TPUs in GKE
Background
This section describes the key technologies used in this tutorial.
Gemma
Gemma is a set of openly available, lightweight, generative artificial intelligence (AI) models released under an open license. These AI models are available to run in your applications, hardware, mobile devices, or hosted services. You can use the Gemma models for text generation, and you can also tune these models for specialized tasks.
To learn more, see the Gemma documentation.
TPUs
TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models built using frameworks such as TensorFlow, PyTorch, and JAX.
This tutorial covers serving the Gemma 7B model. GKE deploys the model on single-host TPU v5e nodes, with TPU topologies configured based on the model's requirements for serving prompts with low latency.
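To picture how this scheduling works in practice, the following is a minimal sketch (using the Kubernetes Python client rather than a YAML manifest) of a Deployment that requests a single-host TPU v5e slice for the JetStream container. The image name, `MODEL_BUCKET` environment variable, port, and exact topology are illustrative assumptions; the tutorial's actual manifests define these values.

```python
# Minimal sketch using the Kubernetes Python client. The image name, bucket
# environment variable, and port are illustrative placeholders; the node
# selector keys are the labels GKE uses to schedule Pods onto TPU slices.
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() inside a Pod

deployment = client.V1Deployment(
    metadata=client.V1ObjectMeta(name="jetstream-gemma-7b"),
    spec=client.V1DeploymentSpec(
        replicas=1,
        selector=client.V1LabelSelector(match_labels={"app": "jetstream-gemma-7b"}),
        template=client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(labels={"app": "jetstream-gemma-7b"}),
            spec=client.V1PodSpec(
                # Schedule onto a single-host TPU v5e slice; a 2x4 topology is 8 chips.
                node_selector={
                    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
                    "cloud.google.com/gke-tpu-topology": "2x4",
                },
                containers=[
                    client.V1Container(
                        name="jetstream",
                        image="JETSTREAM_MAXTEXT_IMAGE",  # placeholder: pre-built serving image
                        env=[
                            # Placeholder: bucket that holds the Gemma 7B checkpoint.
                            client.V1EnvVar(name="MODEL_BUCKET", value="gs://YOUR_BUCKET/gemma-7b"),
                        ],
                        resources=client.V1ResourceRequirements(
                            requests={"google.com/tpu": "8"},
                            limits={"google.com/tpu": "8"},
                        ),
                        ports=[client.V1ContainerPort(container_port=9000)],
                    )
                ],
            ),
        ),
    ),
)

client.AppsV1Api().create_namespaced_deployment(namespace="default", body=deployment)
```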
JetStream
JetStream is an open source inference serving framework developed by Google. It enables high-performance, high-throughput, and memory-optimized inference on XLA devices such as TPUs. It provides advanced performance optimizations, including continuous batching and quantization techniques, to facilitate LLM deployment. JetStream supports both PyTorch/XLA and JAX TPU serving to achieve optimal performance.
To learn more about these optimizations, refer to the JetStream PyTorch and JetStream MaxText project repositories.
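As a concrete example of the serving interface, the following sketch sends a prompt to a JetStream server that has been port-forwarded to localhost. The endpoint path and JSON fields (`/generate`, `prompt`, `max_tokens`) are assumptions made for illustration; consult the JetStream repositories for the server's actual request format.

```python
# Client-side sketch: send a prompt to a locally port-forwarded JetStream
# server. The URL, endpoint path, and payload keys below are assumptions for
# illustration, not a documented JetStream API.
import requests

resp = requests.post(
    "http://localhost:8000/generate",  # e.g. after `kubectl port-forward`
    json={"prompt": "What are the top 5 programming languages?", "max_tokens": 200},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())
```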
MaxText
MaxText is a performant, scalable, and adaptable JAX LLM implementation built on open source JAX libraries such as Flax, Orbax, and Optax. Its decoder-only LLM implementation is written in Python and relies heavily on the XLA compiler to achieve high performance without needing custom kernels.
To learn more about the latest models and parameter sizes that MaxText supports, see the MaxText project repository.
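The core idea behind MaxText, writing the model in plain Python and JAX and letting XLA compile it, can be illustrated with a toy decode step like the one below. This is not MaxText code; it is only a sketch of the JAX/XLA programming model that MaxText relies on.

```python
# Toy illustration of the JAX/XLA approach: the model step is plain Python and
# jax.numpy, and jax.jit hands the whole computation to the XLA compiler, with
# no hand-written kernels. This is not MaxText code.
import jax
import jax.numpy as jnp

def decode_step(params, token_ids):
    # Tiny decoder-like computation: embed tokens, apply one dense layer,
    # and return next-token logits over the vocabulary.
    x = params["embed"][token_ids]        # (seq, d_model)
    h = jnp.tanh(x @ params["dense"])     # (seq, d_model)
    return h @ params["embed"].T          # (seq, vocab) logits

decode_step_jit = jax.jit(decode_step)    # compile the whole step with XLA

key = jax.random.PRNGKey(0)
params = {
    "embed": jax.random.normal(key, (1000, 64)),  # vocab=1000, d_model=64
    "dense": jax.random.normal(key, (64, 64)),
}
logits = decode_step_jit(params, jnp.array([1, 2, 3]))
print(logits.shape)  # (3, 1000)
```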