Getting started with JAX multi-node applications with NVIDIA GPUs on Google Kubernetes Engine
Software Engineer, NVIDIA
Cloud Solutions Architect
JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in large language models, drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen rapid adoption in the past few years. JAX offers numerous benefits for developers and researchers, including an easy-to-use NumPy-style API, automatic differentiation, and optimization. JAX also supports distributed processing across multi-node, multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.
This post shows how to run multi-node, multi-GPU JAX applications on Google Kubernetes Engine (GKE) using the A2 ultra machine series, powered by NVIDIA A100 80GB Tensor Core GPUs. The example runs a simple Hello World application on 4 nodes, each with 8 GPUs and 8 processes (one process per GPU).
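The process layout can be sketched in plain Python. This is a minimal illustration that assumes global ranks are assigned node by node. The names NUM_NODES, GPUS_PER_NODE, and placement are hypothetical and are not taken from the example repository.

```python
# Hypothetical layout helper: one JAX process per GPU.
# Assumption: global ranks are assigned node by node.
NUM_NODES = 4        # A2 ultra nodes in the compute pool
GPUS_PER_NODE = 8    # NVIDIA A100 80GB GPUs per node


def placement(rank: int) -> tuple[int, int]:
    """Return (node index, local GPU index) for a global process rank."""
    if not 0 <= rank < NUM_NODES * GPUS_PER_NODE:
        raise ValueError(f"rank {rank} is out of range")
    return divmod(rank, GPUS_PER_NODE)


total_processes = NUM_NODES * GPUS_PER_NODE
print(total_processes)   # 32
print(placement(0))      # (0, 0)
print(placement(31))     # (3, 7)
```

With this numbering, ranks 0 to 7 land on the first node, ranks 8 to 15 on the second, and so on, for 32 processes in total.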
Install gcloud and set up your environment by running gcloud init and following the prompts.
Set up a GKE cluster
Clone the repository
Enable the required APIs
Create a default VPC (if it doesn’t already exist)
Create a cluster (the control nodes). Replace us-central1-c with your preferred zone.
Create a node pool (the compute nodes). The --enable-fast-socket and --enable-gvnic flags are required for multi-node performance. The --preemptible flag removes the need for quotas but makes the nodes preemptible; remove the flag if this is not desirable. Replace us-central1-c with your preferred zone. This might take a few minutes.
Install the NVIDIA CUDA driver on the compute nodes
Build and push the container to your registry. This will push a container to gcr.io/<your project>/jax/hello:latest. This might take a few minutes.
Replace <<PROJECT>> with your GCP project name.
Run the JAX application on the compute nodes. This will create 32 pods (8 per node), each running one JAX process on one NVIDIA GPU.
Use kubectl to check the status of the pods. The status will change to Pending (after a few minutes), then Running, and finally Completed.
Once the job has completed, use kubectl logs to see the output from one pod.
The application creates an array of length 1 equal to [1.0] on each process and then sums the arrays across all processes. The output, on 32 processes, should be [32.0] on each process.
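The expected result can be checked with a plain-Python simulation of a sum all-reduce. This is only a sketch of the arithmetic, not the JAX program that the job runs. The function name all_reduce_sum is hypothetical.

```python
# Plain-Python simulation of a sum all-reduce.
# Assumption: this mirrors only the arithmetic, not the JAX code the job runs.
def all_reduce_sum(per_process: list[list[float]]) -> list[list[float]]:
    """Sum the arrays element-wise, then give every process a copy of the result."""
    total = [sum(values) for values in zip(*per_process)]
    return [list(total) for _ in per_process]


num_processes = 32
inputs = [[1.0] for _ in range(num_processes)]  # each process holds [1.0]
outputs = all_reduce_sum(inputs)
print(outputs[0])  # [32.0]
```

In JAX, a reduction of this kind is typically expressed with a collective such as jax.lax.psum. Whether the example application uses that exact call is not stated here.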
Congratulations! You just ran JAX on 32 NVIDIA A100 GPUs in GKE. Next, learn how to run inference at scale with TensorRT on NVIDIA T4 GPUs.
Special thanks to Jarek Kazmierczak, Google Machine Learning Solution Architect, and Iris Liu, NVIDIA System Software Engineer, for their expertise and guidance on this blog post.