Getting started with JAX multi-node applications with NVIDIA GPUs on Google Kubernetes Engine
Leopold Cambier
Software Engineer, NVIDIA
Roberto Barbero
Cloud Solutions Architect
JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in large language models, drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen incredible adoption in the past few years. JAX offers numerous benefits for developers and researchers, including an easy-to-use NumPy API, automatic differentiation, and optimization. JAX also supports distributed processing across multi-node and multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.
This post shows how to run multi-node, multi-GPU JAX applications on Google Kubernetes Engine (GKE) using the A2 Ultra machine series, powered by NVIDIA A100 80GB Tensor Core GPUs. The example runs a simple Hello World application on 4 nodes, each with 8 processes and 8 GPUs.
Prerequisites
Install gcloud and set up your environment by running
gcloud init
and following the prompts.
Install Docker and log in to the Google Container Registry using the gcloud credential helper.
Install kubectl and the kubectl authentication plugin for GCP
Set up a GKE cluster
Clone the repository
Enable the required APIs
Create a default VPC (if it doesn’t already exist)
Create a cluster (the control nodes). Replace us-central1-c with your preferred zone.
Create a node pool (the compute nodes). The --enable-fast-socket and --enable-gvnic flags are required for multi-node performance. The --preemptible flag removes the need for quotas but makes the nodes preemptible; remove it if this is not desirable. Replace us-central1-c with your preferred zone. This might take a few minutes.
Install the NVIDIA CUDA driver on the compute nodes
Build and push the container to your registry. This will push a container to gcr.io/<your project>/jax/hello:latest. This might take a few minutes.
In kubernetes/job.yaml and kubernetes/kustomization.yaml, replace <<PROJECT>> with your GCP project ID.
Run JAX
Run the JAX application on the compute nodes. This will create 32 pods (8 per node), each running one JAX process on one NVIDIA GPU.
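Each of the 32 processes needs to know the total process count, its own rank, and where the coordinator (process 0) listens before it can join the job. The sketch below shows one way to derive those values. It assumes the Kubernetes Job is an Indexed Job, so that Kubernetes sets JOB_COMPLETION_INDEX in each pod, and it uses a hypothetical coordinator hostname and port; the repository's job.yaml may wire this up differently.

```python
import os
from dataclasses import dataclass

# Values from this walkthrough: 4 nodes x 8 GPUs, one process (pod) per GPU.
NUM_NODES = 4
PROCS_PER_NODE = 8
NUM_PROCESSES = NUM_NODES * PROCS_PER_NODE  # 32

# Hypothetical: a headless Service name that resolves to the pod with index 0.
COORDINATOR_HOST = "jax-hello-0.jax-hello"  # assumption, not from the repository
COORDINATOR_PORT = 1234                     # assumption, any free port would do


@dataclass(frozen=True)
class DistributedConfig:
    coordinator_address: str
    num_processes: int
    process_id: int


def config_for_index(index: int) -> DistributedConfig:
    """Build jax.distributed.initialize() arguments for the pod with this index."""
    if not 0 <= index < NUM_PROCESSES:
        raise ValueError(f"index {index} outside [0, {NUM_PROCESSES})")
    return DistributedConfig(
        coordinator_address=f"{COORDINATOR_HOST}:{COORDINATOR_PORT}",
        num_processes=NUM_PROCESSES,
        process_id=index,  # pod index doubles as the JAX process rank
    )


def config_from_env() -> DistributedConfig:
    # Indexed Jobs expose each pod's completion index in JOB_COMPLETION_INDEX;
    # default to 0 so the function also works outside Kubernetes.
    return config_for_index(int(os.environ.get("JOB_COMPLETION_INDEX", "0")))


if __name__ == "__main__":
    cfg = config_from_env()
    print(cfg)
    # Inside the container, the JAX program would then call:
    # jax.distributed.initialize(coordinator_address=cfg.coordinator_address,
    #                            num_processes=cfg.num_processes,
    #                            process_id=cfg.process_id)
```

Using the pod index directly as the process rank keeps the mapping trivial: pod 0 is always the coordinator, and every other pod only needs to reach it by a stable DNS name.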
Use kubectl to check the status of the pods. Their status will change from ContainerCreating to Pending (after a few minutes), Running, and finally Completed.
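The STATUS column that kubectl prints shows display reasons such as ContainerCreating and Completed, while each pod's status.phase field reports Pending, Running, Succeeded, or Failed. If you prefer to script the wait, here is a minimal sketch that tallies phases from the JSON produced by `kubectl get pods -o json`. It assumes the namespace contains only this job's pods (otherwise filter with a label selector first), and the function names are illustrative.

```python
import json
from collections import Counter

EXPECTED_PODS = 32  # 4 nodes x 8 pods in this walkthrough


def phase_counts(pods_json: str) -> Counter:
    """Count pods by status.phase in the output of `kubectl get pods -o json`."""
    items = json.loads(pods_json).get("items", [])
    return Counter(pod.get("status", {}).get("phase", "Unknown") for pod in items)


def job_finished(counts: Counter, expected: int = EXPECTED_PODS) -> bool:
    """True once every expected pod has reached the Succeeded phase."""
    return counts["Succeeded"] == expected


def job_failed(counts: Counter) -> bool:
    """True if any pod ended in the Failed phase."""
    return counts["Failed"] > 0
```

For example, capture the kubectl output in a shell loop or with Python's subprocess module, pass it to phase_counts, and stop polling once job_finished or job_failed returns True.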
Once the job has completed, use kubectl logs to see the output from one pod.
The application creates an array of length 1 equal to [1.0] on each process and then sums these arrays across all processes (an all-reduce). With 32 processes, the output should be [32.0] on every process.
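The all-reduce described above can be sketched with JAX's multi-process API. This is not the repository's code: it is a minimal sketch in which the coordinator settings come from hypothetical environment variables (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) and initialization is skipped entirely when they are absent, so the same file also runs as a single-process smoke test. Inside jax.pmap, jax.lax.psum sums over every device in every process, so with 32 processes holding one GPU each, every process should see the value 32.

```python
import os

import jax
import jax.numpy as jnp


def maybe_initialize() -> None:
    """Join the multi-process job if (hypothetical) env vars describe one."""
    addr = os.environ.get("COORDINATOR_ADDRESS")  # hypothetical variable name
    if addr is None:
        return  # single-process run, e.g. a local CPU smoke test
    jax.distributed.initialize(
        coordinator_address=addr,
        num_processes=int(os.environ["NUM_PROCESSES"]),  # hypothetical
        process_id=int(os.environ["PROCESS_ID"]),        # hypothetical
    )


def all_reduce_ones():
    # One element equal to 1.0 per local device; in the GKE example each
    # process (pod) owns a single GPU, so this is an array of length 1.
    xs = jnp.ones(jax.local_device_count())
    # psum over the pmapped axis "i" spans all devices in all processes.
    return jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)


if __name__ == "__main__":
    maybe_initialize()
    result = all_reduce_ones()
    print(f"process {jax.process_index()}/{jax.process_count()}: {result}")
```

Run locally without the environment variables, the program stays single-process and each element equals the number of local devices; on the cluster, the same code yields the global device count.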
Congratulations! You just ran JAX on 32 NVIDIA A100 GPUs in GKE. Next, learn how to run inference at scale with TensorRT on NVIDIA T4 GPUs.
Special thanks to Jarek Kazmierczak, Google Machine Learning Solution Architect, and Iris Liu, NVIDIA System Software Engineer, for their expertise and guidance on this blog post.