Serve an LLM using TPUs on GKE with JetStream and PyTorch

This guide shows you how to serve a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JetStream through PyTorch. In this guide, you download model weights to Cloud Storage and deploy them on a GKE Autopilot or Standard cluster using a container that runs JetStream.

If you need the scalability, resilience, and cost-effectiveness offered by Kubernetes features when deploying your model on JetStream, this guide is a good starting point.

This guide is intended for generative AI customers who use PyTorch, new or existing GKE users, ML engineers, MLOps (DevOps) engineers, and platform administrators who are interested in using Kubernetes container orchestration capabilities to serve LLMs.

Background

By serving an LLM using TPUs on GKE with JetStream, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost-efficiency, scalability, and high availability. This section describes the key technologies used in this tutorial.

About TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

This tutorial covers serving various LLMs. GKE deploys the model on single-host TPU v5e nodes, with TPU topologies configured based on each model's requirements for serving prompts with low latency.

About JetStream

JetStream is an open source inference serving framework developed by Google. JetStream enables high-performance, high-throughput, and memory-optimized inference on TPUs and GPUs. JetStream provides advanced performance optimizations, including continuous batching, KV cache optimizations, and quantization techniques, to facilitate LLM deployment. JetStream supports serving both PyTorch/XLA and JAX models on TPUs to achieve optimal performance.

Continuous batching

Continuous batching is a technique that dynamically groups incoming inference requests into batches as accelerator capacity becomes available, instead of waiting for a fixed-size batch to fill, which reduces latency and increases throughput.
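The following toy Python sketch illustrates the scheduling idea behind continuous batching. It is a simplified illustration, not JetStream's actual scheduler; the slot bookkeeping and the random "tokens remaining" counter stand in for the real decode loop.

```python
import random
from collections import deque

# Toy sketch of continuous batching: free decode slots are refilled from the
# request queue on every step, instead of waiting for the whole batch to finish.
MAX_BATCH_SIZE = 4

pending = deque(f"req-{i}" for i in range(8))  # requests waiting to be scheduled
active = {}                                    # slot index -> [request, tokens remaining]

step = 0
while pending or active:
    # Refill any free slots immediately from the queue.
    for slot in range(MAX_BATCH_SIZE):
        if slot not in active and pending:
            active[slot] = [pending.popleft(), random.randint(2, 6)]  # toy output length

    # One decode step across all in-flight requests (one batched forward pass).
    for slot in list(active):
        active[slot][1] -= 1
        if active[slot][1] == 0:               # request finished: free its slot
            print(f"step {step}: {active[slot][0]} done, slot {slot} freed")
            del active[slot]
    step += 1
```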

KV cache quantization

KV cache quantization involves compressing the key-value cache used in attention mechanisms, reducing memory requirements.
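As a rough illustration of what this compression looks like, the following PyTorch snippet stores a key tensor as int8 values plus a per-head scale and dequantizes it when the cache is read. The shapes, names, and symmetric quantization scheme are assumptions for this sketch, not JetStream's exact implementation.

```python
import torch

# Toy sketch of int8 KV cache quantization: store keys/values as int8 plus a
# per-head scale, then dequantize when attention reads the cache.
batch, heads, seq_len, head_dim = 1, 8, 128, 64
keys = torch.randn(batch, heads, seq_len, head_dim)      # bfloat16/float32 in practice

# Symmetric quantization: one scale per head, values mapped into [-127, 127].
scale = keys.abs().amax(dim=(-2, -1), keepdim=True) / 127.0
keys_int8 = torch.clamp(torch.round(keys / scale), -127, 127).to(torch.int8)

# The cache stores (keys_int8, scale): 1 byte per element instead of 2 or 4.
keys_dequant = keys_int8.to(torch.float32) * scale       # used when computing attention

print("max abs error:", (keys - keys_dequant).abs().max().item())
print("cache bytes per element:", keys_int8.element_size())
```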

Int8 weight quantization

Int8 weight quantization reduces the precision of model weights from 32-bit floating point to 8-bit integers, leading to faster computation and reduced memory usage.
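The arithmetic is similar for weights. The sketch below applies symmetric per-channel int8 quantization to a weight matrix and compares memory use; it is illustrative only and does not reflect the exact quantization scheme JetStream applies.

```python
import torch

# Minimal sketch of symmetric int8 weight quantization:
# weight_int8 = round(weight / scale), with one scale per output channel,
# giving roughly 4x smaller weights than 32-bit floats.
weight = torch.randn(4096, 4096)                           # a float32 weight matrix

scale = weight.abs().amax(dim=1, keepdim=True) / 127.0     # per-output-channel scale
weight_int8 = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)

# At inference time the matmul uses the int8 weights plus the scales; here we
# only dequantize to check the approximation error and the memory savings.
weight_dequant = weight_int8.to(torch.float32) * scale
print("mean abs error:", (weight - weight_dequant).abs().mean().item())
print("float32 MiB:", weight.numel() * 4 / 2**20)
print("int8 MiB:   ", weight_int8.numel() * 1 / 2**20)
```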

To learn more about these optimizations, refer to the JetStream PyTorch and JetStream MaxText project repositories.

About PyTorch

PyTorch is an open source machine learning framework developed by Meta and now governed under the Linux Foundation umbrella. PyTorch provides high-level features such as tensor computation with hardware acceleration and deep neural networks built on an automatic differentiation (autograd) system.
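For readers new to PyTorch, the following minimal snippet shows both features mentioned above: plain tensor computation and a small neural network module.

```python
import torch
from torch import nn

# Tensor computation: a batched matrix product on a random input.
x = torch.randn(32, 128)              # a batch of 32 feature vectors
print((x @ x.T).shape)                # torch.Size([32, 32])

# A minimal deep neural network built from standard layers.
model = nn.Sequential(
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
logits = model(x)
print(logits.shape)                   # torch.Size([32, 10])
```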