Containers & Kubernetes

Getting started with JAX multi-node applications with NVIDIA GPUs on Google Kubernetes Engine

March 22, 2023
Leopold Cambier

Software Engineer, NVIDIA

Roberto Barbero

Cloud Solutions Architect

JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in large language models, drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen incredible adoption in the past few years. JAX offers numerous benefits for developers and researchers, including an easy-to-use NumPy API, automatic differentiation, and optimization. JAX also includes support for distributed processing across multi-node and multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.

We show how to run JAX multi-node, multi-GPU applications on GKE (Google Kubernetes Engine) using the A2 ultra machine series, powered by NVIDIA A100 80GB Tensor Core GPUs. The example runs a simple Hello World application on 4 nodes, with 8 processes and 8 GPUs each.


Prerequisites

  1. Install gcloud and set up your environment by running gcloud init and following the prompts

  2. Install docker and log in to the Google Container Registry using the gcloud credential helper

  3. Install kubectl and the kubectl authentication plugin for GCP

Set up a GKE cluster

Clone the repository
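A sketch of this step; the repository URL below is a placeholder, so substitute the example repository linked from this post:

```shell
# Placeholder URL -- replace <org>/<repo> with the actual example repository.
git clone https://github.com/<org>/<repo>.git
cd <repo>
```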


Enable the required APIs
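Assuming the services in play are Kubernetes Engine and Container Registry, this is roughly:

```shell
# Enable the Kubernetes Engine and Container Registry APIs for the current project.
gcloud services enable container.googleapis.com containerregistry.googleapis.com
```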


Create a default VPC (if it doesn’t already exist)
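For example, to create an auto-mode VPC named default:

```shell
# Create an auto-mode VPC network called "default" (skip if one already exists).
gcloud compute networks create default --subnet-mode=auto
```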


Create a cluster (the control nodes). Replace us-central1-c with your preferred zone.
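A minimal sketch; the cluster name jax-hello is an assumption, and a production setup would likely add flags for networking and release channel:

```shell
# Create the GKE cluster (control plane plus a small default node pool).
gcloud container clusters create jax-hello \
    --zone us-central1-c \
    --num-nodes 1
```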


Create a pool (the compute nodes). The --enable-fast-socket and --enable-gvnic flags are required for multi-node performance. --preemptible removes the need for GPU quota but makes the nodes preemptible; remove the flag if this is not desirable. Replace us-central1-c with your preferred zone. This might take a few minutes.
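A sketch of the node-pool command, assuming the cluster name jax-hello and the 8-GPU A2 ultra machine type a2-ultragpu-8g; check the gcloud reference for any additional accelerator flags your gcloud version requires:

```shell
# Create a pool of 4 compute nodes, each with 8 NVIDIA A100 80GB GPUs.
gcloud container node-pools create jax-pool \
    --cluster jax-hello \
    --zone us-central1-c \
    --machine-type a2-ultragpu-8g \
    --num-nodes 4 \
    --enable-fast-socket \
    --enable-gvnic \
    --preemptible
```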


Install the NVIDIA CUDA driver on the compute nodes
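On GKE this is typically done by applying NVIDIA's driver-installer DaemonSet; the manifest below is the one documented for Container-Optimized OS nodes:

```shell
# Install the NVIDIA driver on all GPU nodes via a DaemonSet.
kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml
```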


Build and push the container to your registry. This pushes a container to gcr.io/&lt;your project&gt;/jax/hello:latest and might take a few minutes.
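Assuming a Dockerfile at the repository root, this is roughly:

```shell
# Build the image and push it to the project's Container Registry.
PROJECT=$(gcloud config get-value project)
docker build -t gcr.io/${PROJECT}/jax/hello:latest .
docker push gcr.io/${PROJECT}/jax/hello:latest
```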


In kubernetes/job.yaml and kubernetes/kustomization.yaml, replace &lt;&lt;PROJECT&gt;&gt; with your GCP project name.
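The substitution can also be scripted, for example:

```shell
# Replace the <<PROJECT>> placeholder in both manifests in place.
PROJECT=$(gcloud config get-value project)
sed -i "s/<<PROJECT>>/${PROJECT}/g" kubernetes/job.yaml kubernetes/kustomization.yaml
```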


Run the JAX application on the compute nodes. This creates 32 pods (8 per node), each running one JAX process on one NVIDIA GPU.
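Since the repository ships a kustomization.yaml, the job can presumably be launched with the kustomize support built into kubectl:

```shell
# Create the job (and its 32 pods) from the kustomization in kubernetes/.
kubectl apply -k kubernetes/
```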




Use kubectl get pods to check the status. The pods will change from Pending to ContainerCreating (after a few minutes), then Running, and finally Completed.
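For example:

```shell
# List the pods created by the job; add --watch to follow status changes.
kubectl get pods
```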

Once the job has completed, use kubectl logs to see the output from one pod.
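One way to pick a pod and dump its logs (pod names are generated, so we grab the first one):

```shell
# Print the output of the first pod belonging to the job.
POD=$(kubectl get pods -o name | head -n 1)
kubectl logs "${POD}"
```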


The application creates an array of length 1 equal to [1.0] on each process and then sums it across all processes. With 32 processes, the output should be [32.0] on each process.
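The application itself can be sketched in a few lines of JAX. This is a hypothetical reconstruction, not the repository's exact code; in practice the coordinator address, process count, and process id come from environment variables set in the job manifest:

```python
import jax
import jax.numpy as jnp

# Join the other 31 processes (coordinator address, process count, and
# process id are read from the environment configured by the Kubernetes job).
jax.distributed.initialize()

# Each process contributes a length-1 array equal to [1.0] ...
x = jnp.ones(jax.local_device_count())

# ... and an all-reduce (psum) across all processes/devices sums them:
# with 32 processes, every process prints [32.0].
y = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
print(y)
```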

Congratulations! You just ran JAX on 32 NVIDIA A100 GPUs in GKE. Next, learn how to run inference at scale with TensorRT on NVIDIA T4 GPUs.

Special thanks to Jarek Kazmierczak, Google Machine Learning Solution Architect, and Iris Liu, NVIDIA System Software Engineer, for their expertise and guidance on this blog post.
