Serve LLMs using multi-host TPUs on GKE with JetStream and Pathways

This guide shows you how to serve state-of-the-art large language models (LLMs) such as Llama 3.1 405B on Google Kubernetes Engine (GKE) using tensor processing units (TPUs) across multiple nodes.

This guide demonstrates how to use portable open-source technologies (Kubernetes, JetStream, Pathways on Cloud, and the LeaderWorkerSet (LWS) API) to deploy and serve AI/ML workloads on GKE, taking advantage of GKE's granular control, scalability, resilience, portability, and cost-effectiveness.

Background

Large language models have grown to the point where they no longer fit on a single-host TPU slice. For ML inference, you can use Pathways on Cloud to run large-scale multi-host inference on GKE across multiple interconnected TPU nodes. This guide walks you through provisioning a GKE cluster with multi-host TPU slices, using the Pathways on Cloud binaries, launching the JetStream server with the MaxText framework, and making multi-host inference requests.

By serving an LLM using TPUs on GKE with JetStream, MaxText, and Pathways, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost-efficiency, scalability, and higher availability. This section describes the key technologies used in this tutorial.

About TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models that are built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

This tutorial covers serving the Llama 3.1 405B model. GKE deploys the model on multi-host TPU v6e nodes, with TPU topologies configured based on the model's requirements for serving prompts with low latency.
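
To make the topology arithmetic concrete, the following sketch computes how many GKE nodes a multi-host slice needs. The values are illustrative assumptions (a v6e 4x8 topology and 4 chips per ct6e-standard-4t VM), not necessarily the exact shapes this tutorial provisions:

    from math import prod

    def nodes_for_topology(topology: str, chips_per_vm: int = 4) -> int:
        """Number of GKE nodes needed for a multi-host TPU slice."""
        chips = prod(int(dim) for dim in topology.split("x"))
        return chips // chips_per_vm

    print(nodes_for_topology("4x8"))  # 4*8 = 32 chips / 4 chips per VM = 8 nodes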

Pathways on Cloud

Pathways is a large-scale orchestration layer for accelerators. Pathways is explicitly designed to enable exploration of new systems and ML research ideas, while retaining state-of-the-art performance for current models. Pathways enables a single JAX client process to coordinate computation across one or more large TPU slices, streamlining ML computations that span hundreds or thousands of TPU chips.
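
As a minimal sketch of that single-controller model, the following assumes the pathwaysutils helper package from the Pathways on Cloud tooling is installed, and that the JAX_PLATFORMS and JAX_BACKEND_TARGET environment variables point at a running IFRT proxy server; the package name and setup details are assumptions, not verbatim from this guide:

    import jax
    import pathwaysutils  # assumed helper package from Pathways on Cloud tooling

    # Route this one JAX client through the IFRT proxy so it can drive
    # every chip in the attached TPU slices (assumes JAX_PLATFORMS=proxy
    # and JAX_BACKEND_TARGET=grpc://<proxy-address> are already set).
    pathwaysutils.initialize()

    # A single controller process now sees all devices across slices.
    print(jax.device_count())  # for example, 32 on a v6e 4x8 slice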

JetStream

JetStream is an open-source inference serving framework developed by Google. JetStream enables high-performance, high-throughput, and memory-optimized inference on TPUs and GPUs. JetStream provides advanced performance optimizations, including continuous batching, KV cache optimizations, and quantization techniques, to facilitate LLM deployment. JetStream supports PyTorch/XLA and JAX TPU serving to optimize performance.
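
Once the server is running, inference requests are plain HTTP calls. The following is a hedged example that assumes the JetStream HTTP server has been port-forwarded to localhost:8000; the /generate endpoint and payload fields follow the JetStream HTTP wrapper and may vary between releases:

    import requests

    # Post a generate request to the JetStream HTTP server (assumed to be
    # port-forwarded to localhost:8000, for example with kubectl).
    response = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": "What are the top 5 programming languages?",
            "priority": 0,      # scheduling priority hint
            "max_tokens": 200,  # upper bound on generated tokens
        },
        timeout=120,
    )
    print(response.json())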

MaxText

MaxText is a performant, scalable, and adaptable JAX LLM implementation, built on open source JAX libraries such as Flax, Orbax, and Optax. MaxText's decoder-only LLM implementation is written in Python. It leverages the XLA compiler heavily to achieve high performance without needing to build custom kernels.
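
The following toy sketch (not MaxText code) illustrates that pattern: the model step is written as plain JAX, and jax.jit plus a named sharding lets the XLA compiler fuse and partition it across devices. The shapes and mesh axes here are illustrative assumptions:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # One-dimensional device mesh over whatever accelerators are visible.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    @jax.jit
    def decode_step(x, w):
        # XLA compiles, fuses, and partitions this; no custom kernels.
        return jnp.dot(x, w)

    # Shard the batch dimension across the mesh; replicate the weights.
    batch = 8 * mesh.size  # keep the batch evenly divisible across devices
    x = jax.device_put(jnp.ones((batch, 128)), NamedSharding(mesh, P("data", None)))
    w = jax.device_put(jnp.ones((128, 128)), NamedSharding(mesh, P(None, None)))
    print(decode_step(x, w).shape)  # (batch, 128)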

For more information about the latest models and parameter sizes that MaxText supports, see the MaxText project repository.

Llama 3.1 405B

Llama 3.1 405B is a large language model by Meta that's designed for a range of natural language processing tasks, including text generation, translation, and question answering. GKE offers the infrastructure required to support the distributed training and serving needs of models of this scale.

For more information, see the Llama documentation.

Architecture

This section describes the GKE architecture used in this tutorial. The architecture includes a GKE Standard cluster that provisions TPUs and hosts JetStream and Pathways components to deploy and serve the model.

The following diagram shows you the components of this architecture:

Architecture of GKE cluster with multi-host TPU node pool containing the JetStream and Pathways components.

This architecture includes the following components:

  • A GKE Standard regional cluster.
  • A multi-host TPU slice node pool that hosts the JetStream deployment and Pathways components.
  • The Pathways resource manager manages accelerator resources and coordinates allocation of accelerators for user jobs.
  • The Pathways client coordinates with the Pathways resource manager to determine where the compiled programs are placed for execution.
  • The Pathways worker runs computations on the accelerator machines and sends data back to your workload through the IFRT proxy server.
  • The IFRT proxy client implements the OSS Interim Framework Runtime (IFRT) API and acts as the communication bridge between your workload and Pathways components.
  • The IFRT proxy server receives requests from the IFRT proxy client and forwards them to the Pathways client, distributing the work.
  • The JetStream-Pathways container provides a JAX-based inference server that receives inference requests and delegates their execution to the Pathways workers.
  • The Service component spreads inbound traffic to all JetStream HTTP replicas.
  • JetStream HTTP is an HTTP server that accepts requests, wraps them in JetStream's required format, and sends them to JetStream's gRPC client.
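
To complement the port-forwarded example earlier, the sketch below shows how in-cluster traffic would reach the JetStream HTTP replicas through the Service. The Service name, namespace, and port are hypothetical placeholders; substitute whatever your manifests define:

    import requests

    # Hypothetical in-cluster call: "jetstream-svc", "default", and port
    # 8000 are placeholders for the Service your deployment creates. The
    # Service spreads this request across the JetStream HTTP replicas.
    url = "http://jetstream-svc.default.svc.cluster.local:8000/generate"
    payload = {
        "prompt": "Summarize Pathways on Cloud in one sentence.",
        "priority": 0,
        "max_tokens": 128,
    }
    print(requests.post(url, json=payload, timeout=120).json())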