This guide shows you how to serve state-of-the-art large language models (LLMs) such as Llama 3.1 405B on Google Kubernetes Engine (GKE) using tensor processing units (TPUs) across multiple nodes.
This guide demonstrates how to use portable open-source technologies—Kubernetes, JetStream, Pathways on Cloud, and the LeaderWorkerSet (LWS) API—to deploy and serve AI/ML workloads on GKE, by taking advantage of GKE's granular control, scalability, resilience, portability, and cost-effectiveness.
Background
Large language models (LLMs) have grown so large that they no longer fit on a single-host TPU slice. For ML inference, you can use Pathways on Cloud to run large-scale, multi-host inference on GKE across multiple interconnected TPU nodes. In this guide, you walk through how to provision a GKE cluster with multi-host TPU slices, use the Pathways on Cloud binaries, launch the JetStream server with the MaxText framework, and make multi-host inference requests.
By serving an LLM using TPUs on GKE with JetStream, MaxText, and Pathways, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost-efficiency, scalability, and higher availability. This section describes the key technologies used in this tutorial.
About TPUs
TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models that are built using frameworks such as TensorFlow, PyTorch, and JAX.
Before you use TPUs in GKE, we recommend that you complete the following learning path:
- Learn about current TPU version availability with the Cloud TPU system architecture.
- Learn about TPUs in GKE.
This tutorial covers serving the Llama 3.1 405B model. GKE deploys the model on multi-host TPU v6e nodes, with TPU topologies configured to meet the model's requirements for serving prompts with low latency.
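To see why multi-host slices are required, a rough memory estimate helps. The following Python sketch is illustrative only: the bfloat16 weight size and the 32 GB of HBM per TPU v6e chip are assumptions used for sizing, not the exact topology this tutorial provisions.

```python
# Rough sizing sketch (assumptions, not the tutorial's exact configuration):
# estimate how many TPU v6e chips are needed just to hold the model weights.

params = 405e9            # Llama 3.1 405B parameters
bytes_per_param = 2       # bfloat16 weights
hbm_per_chip_gb = 32      # assumed HBM per TPU v6e chip

weight_gb = params * bytes_per_param / 1e9      # ~810 GB of weights
min_chips = weight_gb / hbm_per_chip_gb         # ~25 chips before the KV cache

print(f"weights ~ {weight_gb:.0f} GB, at least {min_chips:.1f} chips needed")
# A single v6e host exposes at most 8 chips, so the model can only be served
# from a multi-host slice; extra chips also leave headroom for the KV cache
# and activations.
```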
Pathways on Cloud
Pathways is a large-scale orchestration layer for accelerators. Pathways is explicitly designed to enable exploration of new systems and ML research ideas, while retaining state-of-the-art performance for current models. Pathways lets a single JAX client process coordinate computation across one or more large TPU slices, streamlining ML computations that span hundreds or thousands of TPU chips.
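The following is a minimal sketch of that single-controller model, written in plain JAX and not taken from this tutorial: one client process lists every visible chip and shards a computation across them. Under Pathways on Cloud the same pattern applies, except that the devices the client sees are reached through the IFRT proxy and can span an entire multi-host slice.

```python
# Minimal single-controller sketch in plain JAX (illustrative only).
# One client process sees all devices and shards a computation across them;
# with Pathways, the visible devices can span a whole multi-host TPU slice.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()                       # every chip visible to this client
mesh = Mesh(mesh_utils.create_device_mesh((len(devices),)), ("data",))

@jax.jit
def forward(x, w):
    # A toy "model": a single matmul, compiled once by XLA.
    return jnp.dot(x, w)

# Shard the batch dimension across all devices; replicate the weights.
x = jax.device_put(jnp.ones((len(devices) * 8, 1024)),
                   NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((1024, 1024)), NamedSharding(mesh, P()))

y = forward(x, w)
print(y.shape, y.sharding)                    # sharded result, no per-host scripting
```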
JetStream
JetStream is an open source inference serving framework developed by Google. JetStream enables high-performance, high-throughput, and memory-optimized inference on TPUs and GPUs. JetStream provides advanced performance optimizations, including continuous batching, KV cache optimizations, and quantization techniques, to facilitate LLM deployment. JetStream supports PyTorch/XLA and JAX TPU serving to optimize performance.
MaxText
MaxText is a performant, scalable, and adaptable JAX LLM implementation, built on open source JAX libraries such as Flax, Orbax, and Optax. MaxText's decoder-only LLM implementation is written in Python. It leverages the XLA compiler heavily to achieve high performance without needing to build custom kernels.
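As an illustration of that approach (a generic sketch, not MaxText source code), the following shows how a feed-forward block written with plain jnp operations is traced by jax.jit and lowered to a single fused XLA program, with no hand-written kernels:

```python
# Generic sketch (not MaxText source): plain jnp ops compiled by XLA.
import jax
import jax.numpy as jnp

def decoder_mlp(x, w_in, w_out):
    # A typical transformer feed-forward block written as ordinary array ops.
    return jnp.dot(jax.nn.gelu(jnp.dot(x, w_in)), w_out)

compiled = jax.jit(decoder_mlp)        # traced once, fused and optimized by XLA

x = jnp.ones((8, 1024))
w_in = jnp.ones((1024, 4096))
w_out = jnp.ones((4096, 1024))
print(compiled(x, w_in, w_out).shape)  # (8, 1024)
```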
For more information about the latest models and parameter sizes that MaxText supports, see the MaxText project repository.
Llama 3.1 405B
Llama 3.1 405B is a large language model by Meta that's designed for a range of natural language processing tasks, including text generation, translation, and question answering. GKE offers the infrastructure required to support the distributed training and serving needs of models of this scale.
For more information, see the Llama documentation.
Architecture
This section describes the GKE architecture used in this tutorial. The architecture includes a GKE Standard cluster that provisions TPUs and hosts JetStream and Pathways components to deploy and serve the model.
The following diagram shows you the components of this architecture:
This architecture includes the following components:
- A GKE Standard regional cluster.
- A multi-host TPU slice node pool that hosts the JetStream deployment and the Pathways components.
- The Pathways resource manager manages accelerator resources and coordinates the allocation of accelerators for user jobs.
- The Pathways client coordinates with the Pathways resource manager to determine where the compiled programs are placed for execution.
- The Pathways worker runs computations on accelerator machines and sends data back to your workload through the IFRT proxy server.
- The IFRT proxy client implements the open source Interim Framework Runtime (IFRT) API and acts as the communication bridge between your workload and the Pathways components.
- The IFRT proxy server receives requests from the IFRT proxy client and forwards them to the Pathways client, distributing the work.
- The JetStream-Pathways container provides a JAX-based inference server that receives inference requests and delegates their execution to the Pathways workers.
- The Service component spreads inbound traffic to all JetStream HTTP replicas.
- JetStream HTTP is an HTTP server that accepts requests, wraps them into JetStream's required format, and sends them to JetStream's gRPC client.
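After the Service is reachable (for example, through kubectl port-forward), you can exercise this path end to end. The snippet below is a hypothetical client call: the port, the /generate path, and the JSON field names are assumptions chosen to illustrate the flow, so check your deployed JetStream HTTP server for the exact request format.

```python
# Hypothetical client call (port, path, and field names are assumptions).
# The JetStream HTTP server wraps the request into JetStream's gRPC format,
# the Pathways components run the model, and the decoded text comes back.

import requests

resp = requests.post(
    "http://localhost:8000/generate",   # e.g. after port-forwarding the Service
    json={"prompt": "What is a Tensor Processing Unit?", "max_tokens": 128},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```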