Configure KubeRay with TPU Trillium

This tutorial shows you how to configure KubeRay with TPU Trillium on Google Kubernetes Engine (GKE). Learn to set up both single-host and multi-host TPU configurations, including necessary environment variables and Pod specifications for TPU Trillium.

This tutorial is for Platform admins and operators and Data and AI specialists who want to learn how to configure TPU Trillium initialization with KubeRay for single-host and multi-host node pools. This tutorial demonstrates how to run a script with JAX that verifies successful TPU initialization. This tutorial doesn't deploy a model.

Before you configure KubeRay in GKE, ensure that you are familiar with Ray definitions and terminology in GKE.

Overview

This tutorial shows how to run a Python script with JAX that verifies that TPU Trillium initialization with KubeRay was successful. JAX is a high-performance numerical computation library that supports machine learning workloads. KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes.

Trillium TPUs (v6e) require specific environment variables and Pod specifications that differ from previous TPU generations. This tutorial provides the necessary configurations to successfully deploy a workload with KubeRay on Trillium TPUs.

Before you begin

Before you start, make sure that you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
  • Ensure that you have the Ray CLI (version 2.37.0) installed.

Activate Cloud Shell

Cloud Shell comes preinstalled with the gcloud, helm, and kubectl command-line tools that are used in this tutorial.

  1. Go to the Google Cloud console.
  2. At the top of the Google Cloud console window, click the Activate Cloud Shell button.

    A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.

Create a GKE cluster and node pool

You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.

Autopilot

  1. In Cloud Shell, run the following command:

    gcloud container clusters create-auto CLUSTER_NAME \
        --enable-ray-operator \
        --release-channel=rapid \
        --location=LOCATION
    

    Replace the following:

    • CLUSTER_NAME: the name of the new cluster.
    • LOCATION: the region where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.

    GKE creates an Autopilot cluster with the Ray operator addon enabled. The addon automatically installs the Ray TPU webhook in the cluster control plane.

  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME --location=LOCATION
    

Standard

  1. In Cloud Shell, create a Standard cluster that enables the Ray operator addon by running the following command:

    gcloud container clusters create CLUSTER_NAME \
      --location LOCATION \
      --addons=RayOperator \
      --cluster-version=1.33 \
      --machine-type=n1-standard-16
    

    Replace the following:

    • CLUSTER_NAME: the name of the new cluster.
    • LOCATION: the region where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.

    The cluster creation might take several minutes.

  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME --location=LOCATION
    
  3. You can create a single-host or a multi-host TPU slice node pool:

Single-host

In Cloud Shell, run the following command:

gcloud container node-pools create v6e-4 \
    --location=us-central2-b \
    --cluster=CLUSTER_NAME \
    --machine-type=ct6e-standard-4t \
    --num-nodes=1 \
    --threads-per-core=1 \
    --tpu-topology=2x2

Multi-host

In Cloud Shell, run the following command:

gcloud container node-pools create v6e-16 \
    --location=us-central2-b \
    --cluster=CLUSTER_NAME \
    --machine-type=ct6e-standard-4t \
    --num-nodes=4 \
    --threads-per-core=1 \
    --tpu-topology=4x4
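
The --num-nodes values in these commands follow from the slice topology: a ct6e-standard-4t machine exposes 4 TPU chips per host, so a 2x2 topology (4 chips) fits on a single host, while a 4x4 topology (16 chips) spans four hosts. A small sketch of that arithmetic (the helper function is hypothetical, not part of any Google Cloud tooling):

```python
from math import prod

# ct6e-standard-4t machines expose 4 TPU chips per host (per this tutorial).
CHIPS_PER_HOST = 4

def num_nodes(tpu_topology: str) -> int:
    """Hypothetical helper: total chips in the topology divided by chips per host."""
    chips = prod(int(dim) for dim in tpu_topology.split("x"))
    return chips // CHIPS_PER_HOST

print(num_nodes("2x2"))  # 2x2 = 4 chips -> 1 host (single-host slice)
print(num_nodes("4x4"))  # 4x4 = 16 chips -> 4 hosts (multi-host slice)
```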

Run a RayJob custom resource

By defining a RayJob manifest, you instruct KubeRay to do the following:

  • Create a RayCluster: the RayJob spec includes a rayClusterSpec which defines the Ray cluster configuration (head and worker groups) that you want.
  • Run a specific Job: the entrypoint field within the RayJob specifies the command or script to execute within the created Ray cluster. In this tutorial, the entrypoint is a Python script (tpu_list_devices.py) designed to verify the TPU Trillium initialization.
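
The actual tpu_list_devices.py script lives in the kubernetes-engine-samples repository. As a rough, simplified sketch of the kind of check such a script performs (this is illustrative code, not the tutorial's script):

```python
# Illustrative only: a simplified stand-in for tpu_list_devices.py.
# On a TPU worker, jax.devices() reports the slice's TPU cores; on a
# machine without TPUs, JAX falls back to CPU devices.
import jax

def list_devices() -> list[str]:
    # Each entry names the platform backing one JAX device (e.g. "tpu").
    return [d.platform for d in jax.devices()]

# On a 2x2 Trillium slice this reports 4 devices, matching the
# 'TPU cores:4' line in the job output shown later in this tutorial.
print(f"TPU cores:{jax.device_count()}")
```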

To create a RayJob custom resource, complete the following steps:

Single-host

  1. Create the following ray-job.tpu-v6e-singlehost.yaml manifest:

    apiVersion: ray.io/v1
    kind: RayJob
    metadata:
      name: v6e-4-job
    spec:
      entrypoint: python ai-ml/gke-ray/tpu/tpu_list_devices.py
      runtimeEnvYAML: |
        working_dir: "https://github.com/GoogleCloudPlatform/kubernetes-engine-samples/archive/refs/heads/main.zip"
        pip:
          - jax[tpu]==0.4.33
          - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      rayClusterSpec:
        rayVersion: '2.43.0'
        headGroupSpec:
          rayStartParams: {}
          template:
            spec:
              containers:
              -   name: ray-head
                  image: rayproject/ray:2.43.0-py310
                  ports:
                    - containerPort: 6379
                      name: gcs-server
                    - containerPort: 8265
                      name: dashboard
                    - containerPort: 10001
                      name: client
                  resources:
                    limits:
                      cpu: "8"
                      memory: 40G
                    requests:
                      cpu: "8"
                      memory: 40G
        workerGroupSpecs:
        -   replicas: 1
            minReplicas: 1
            maxReplicas: 1
            numOfHosts: 1
            groupName: tpu-group
            rayStartParams: {}
            template:
              spec:
                containers:
                -   name: ray-worker
                    image: rayproject/ray:2.43.0-py310
                    resources:
                      limits:
                        cpu: "24"
                        google.com/tpu: "4"
                        memory: 200G
                      requests:
                        cpu: "24"
                        google.com/tpu: "4"
                        memory: 200G
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 2x2
  2. Apply the manifest:

    kubectl apply -f ray-job.tpu-v6e-singlehost.yaml
    
  3. Verify that the RayJob is created and running:

    kubectl get rayjobs v6e-4-job
    

    The output is similar to the following:

    NAME      JOB STATUS   DEPLOYMENT STATUS   RAY CLUSTER NAME       START TIME  END TIME   AGE
    v6e-4-job PENDING      Running             v6e-4-job-raycluster   2024-10-15T23:15:22Z  20s
    
  4. Print the output of the RayJob:

    kubectl logs -l=job-name=v6e-4-job
    

    The output is similar to the following:

    2024-10-15 16:15:40,222 INFO cli.py:300 -- ray job stop v6e-4-job-hzq5q
    2024-10-15 16:15:40,246 INFO cli.py:307 -- Tailing logs until the job exits (disable with --no-wait):
    2024-10-15 16:15:40,112 INFO job_manager.py:528 -- Runtime env is setting up.
    2024-10-15 16:15:50,181 INFO worker.py:1461 -- Using address 10.84.1.25:6379 set in the environment variable RAY_ADDRESS
    2024-10-15 16:15:50,181 INFO worker.py:1601 -- Connecting to existing Ray cluster at address: 10.84.1.25:6379...
    2024-10-15 16:15:50,186 INFO worker.py:1777 -- Connected to Ray cluster. View the dashboard at 10.84.1.25:8265
    ['TPU cores:4']
    2024-10-15 16:16:12,349 SUCC cli.py:63 -- -------------------------------------
    2024-10-15 16:16:12,349 SUCC cli.py:64 -- Job 'v6e-4-job-hzq5q' succeeded
    2024-10-15 16:16:12,349 SUCC cli.py:65 -- -------------------------------------
    

Multi-host

  1. Create the following ray-job.tpu-v6e-multihost.yaml manifest:

    apiVersion: ray.io/v1
    kind: RayJob
    metadata:
      name: v6e-16-job
    spec:
      entrypoint: python ai-ml/gke-ray/tpu/tpu_list_devices.py
      runtimeEnvYAML: |
        working_dir: "https://github.com/GoogleCloudPlatform/kubernetes-engine-samples/archive/refs/heads/main.zip"
        pip:
          - jax[tpu]==0.4.33
          - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      rayClusterSpec:
        rayVersion: '2.43.0'
        headGroupSpec:
          rayStartParams: {}
          template:
            spec:
              containers:
              -   name: ray-head
                  image: rayproject/ray:2.43.0-py310
                  ports:
                    - containerPort: 6379
                      name: gcs-server
                    - containerPort: 8265
                      name: dashboard
                    - containerPort: 10001
                      name: client
                  resources:
                    limits:
                      cpu: "8"
                      memory: 40G
                    requests:
                      cpu: "8"
                      memory: 40G
        workerGroupSpecs:
          - replicas: 1
            minReplicas: 1
            maxReplicas: 1
            numOfHosts: 4
            groupName: tpu-group
            rayStartParams: {}
            template:
              spec:
                containers:
                  - name: ray-worker
                    image: rayproject/ray:2.43.0-py310
                    resources:
                      limits:
                        cpu: "24"
                        google.com/tpu: "4"
                        memory: 200G
                      requests:
                        cpu: "24"
                        google.com/tpu: "4"
                        memory: 200G
                    env:
                    - name: NODE_IP
                      valueFrom:
                        fieldRef:
                          fieldPath: status.hostIP
                    - name: VBAR_CONTROL_SERVICE_URL
                      value: $(NODE_IP):8353
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    ports:
                    - containerPort: 8081
                      name: mxla
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 4x4
  2. Apply the manifest:

    kubectl apply -f ray-job.tpu-v6e-multihost.yaml
    
  3. Verify that the v6e-16 RayJob is created and running:

    kubectl get rayjobs v6e-16-job
    

    The output is similar to the following:

    NAME         JOB STATUS   DEPLOYMENT STATUS   RAY CLUSTER NAME              START TIME             END TIME   AGE
    v6e-16-job                Running             v6e-16-job-raycluster-qr6vk   2024-10-16T19:28:19Z              66s
    
  4. Print the output of the v6e-16 RayJob:

    kubectl logs -l=job-name=v6e-16-job
    

    The output is similar to the following:

    2024-10-16 12:21:33,986 INFO cli.py:300 -- ray job stop v6e-16-job-z44s7
    2024-10-16 12:21:34,011 INFO cli.py:307 -- Tailing logs until the job exits (disable with --no-wait):
    2024-10-16 12:21:33,826 INFO job_manager.py:528 -- Runtime env is setting up.
    2024-10-16 12:21:46,327 INFO worker.py:1461 -- Using address 10.84.1.61:6379 set in the environment variable RAY_ADDRESS
    2024-10-16 12:21:46,327 INFO worker.py:1601 -- Connecting to existing Ray cluster at address: 10.84.1.61:6379...
    2024-10-16 12:21:46,333 INFO worker.py:1777 -- Connected to Ray cluster. View the dashboard at 10.84.1.61:8265
    ['TPU cores:16', 'TPU cores:16', 'TPU cores:16', 'TPU cores:16']
    2024-10-16 12:22:12,156 SUCC cli.py:63 -- ---------------------------------
    2024-10-16 12:22:12,156 SUCC cli.py:64 -- Job 'v6e-16-job-z44s7' succeeded
    2024-10-16 12:22:12,156 SUCC cli.py:65 -- ---------------------------------
    

View the RayJob in the Ray Dashboard

Verify that GKE created the RayCluster service, and connect to the RayCluster instance.

Single-host

  1. Retrieve the name of the generated RayCluster for the RayJob:

    export RAYCLUSTER_NAME=$(kubectl get rayjob v6e-4-job -o jsonpath='{.status.rayClusterName}')
    
  2. Retrieve the name of the RayCluster head service:

    export HEAD_SVC=$(kubectl get svc -l ray.io/cluster=$RAYCLUSTER_NAME,ray.io/node-type=head -o jsonpath='{.items[0].metadata.name}')
    
  3. Connect to the Ray Dashboard by port-forwarding the head service:

    kubectl port-forward svc/$HEAD_SVC 8265:8265 >/dev/null 2>&1 &
    
  4. Open a web browser and enter the following URL:

    http://localhost:8265/#/jobs
    
  5. View the RayJob status and relevant logs.

Multi-host

  1. Retrieve the name of the generated RayCluster for the RayJob:

    export RAYCLUSTER_NAME=$(kubectl get rayjob v6e-16-job -o jsonpath='{.status.rayClusterName}')
    
  2. Retrieve the name of the RayCluster head service:

    export HEAD_SVC=$(kubectl get svc -l ray.io/cluster=$RAYCLUSTER_NAME,ray.io/node-type=head -o jsonpath='{.items[0].metadata.name}')
    
  3. Connect to the Ray Dashboard by port-forwarding the head service:

    kubectl port-forward svc/$HEAD_SVC 8265:8265 >/dev/null 2>&1 &
    
  4. Open a web browser and enter the following URL:

    http://localhost:8265/#/jobs
    
  5. View the RayJob status and relevant logs.

Ray sets a TPU-{accelerator}-Head resource to identify the Ray worker node that corresponds to the TPU_WORKER_ID=0 value. In the multi-host TPU group, the Ray node with TPU_WORKER_ID=0 has TPU-v6e-16-head: 1.0 set in its resources. This TPU_WORKER_ID environment variable is set by a mutating GKE webhook for KubeRay.
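
For example, a workload could branch on that injected variable. A stdlib-only sketch (the function name and surrounding logic are illustrative, not part of KubeRay):

```python
# Illustrative sketch (not tutorial code): logic a multi-host workload
# could use to identify the TPU head worker.
import os

def is_tpu_head_worker() -> bool:
    # TPU_WORKER_ID is injected by the GKE mutating webhook for KubeRay;
    # worker 0 is the Ray node that carries the TPU-{accelerator}-Head
    # resource (for example, TPU-v6e-16-head: 1.0).
    return os.environ.get("TPU_WORKER_ID") == "0"

os.environ["TPU_WORKER_ID"] = "0"  # simulate the webhook's injection
print(is_tpu_head_worker())        # True
```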

Clean up

After you complete the tutorial, delete the RayJob to avoid incurring unwanted charges to your account:

Single-host

kubectl delete rayjobs v6e-4-job

Multi-host

kubectl delete rayjobs v6e-16-job

What's next