Orchestrate Multislice workloads using JobSet and Kueue


This tutorial shows you how to run a JAX workload using TPU Multislice in Google Kubernetes Engine (GKE) and Kueue. Kueue implements Job queueing: it decides when Jobs should wait and when they should start, based on quotas and a hierarchy for sharing resources fairly among teams.

Specifically, it shows how to orchestrate several Multislice workloads that request TPU resources at the same time.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPU Multislice in GKE.

Objectives

This tutorial is intended for GKE administrators who have existing GKE clusters and want to run Multislice workloads for the first time.

This tutorial covers the following steps:

  1. Prepare your environment with a GKE cluster that has three v5e TPU slices. Each TPU slice has a 2x4 topology (eight chips) with four chips per host, for a total of 24 TPU v5e chips.
  2. Create the Kueue resources to ensure that quotas are shared fairly between the workloads.
  3. Run your Multislice workload.
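
The chip total in step 1 follows directly from the slice shape. As a quick sanity check, the arithmetic can be sketched in shell. The variable names are illustrative only and are not part of any GKE tooling.

```shell
#!/usr/bin/env bash
# Derive per-slice and cluster-wide chip counts from the slice shape.
topology="2x4"     # TPU topology of each slice
chips_per_host=4   # ct5lp-hightpu-4t exposes four chips per host
slices=3           # one slice per node pool

# Split "2x4" into its two dimensions and multiply them.
chips_per_slice=$(( ${topology%x*} * ${topology#*x} ))
hosts_per_slice=$(( chips_per_slice / chips_per_host ))
total_chips=$(( slices * chips_per_slice ))

echo "chips per slice: ${chips_per_slice}"
echo "hosts per slice: ${hosts_per_slice}"
echo "total chips:     ${total_chips}"
```

The two hosts per slice correspond to the `--num-nodes=2` setting used for each node pool, and the total matches the 24-chip quota configured in Kueue.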

Before you begin

Before you start, make sure you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.

Prepare the environment

  1. In the Google Cloud console, start a Cloud Shell instance.

  2. Set the default environment variables:

    gcloud config set project PROJECT_ID
    gcloud config set compute/region COMPUTE_REGION
    

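    Replace the following values:

      • PROJECT_ID: your Google Cloud project ID.
      • COMPUTE_REGION: the Compute Engine region to use by default.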

Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default. TPUs on Autopilot clusters are configured in the workload specification. For more information, see the Define your Multislice workloads with JobSets section.

Create a GKE cluster

In Cloud Shell, create a GKE cluster:

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

Standard

gcloud container clusters create multislice-cluster \
    --location=LOCATION

Replace LOCATION with the location in which you want to create your cluster. Ensure it has capacity for the ct5lp-hightpu-4t machine type. Cluster creation might take several minutes.

If you use GKE Autopilot mode, skip to the Create the Kueue resources section. Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default.

Create three Standard mode TPU node pools

  1. Create the first node pool named nodepool1:

    gcloud beta container node-pools create nodepool1 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

    Replace NODE_LOCATION with one or more zones in the cluster region in which you want to create the nodes.

  2. Create the second node pool named nodepool2:

    gcloud beta container node-pools create nodepool2 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    
  3. Create the third node pool named nodepool3:

    gcloud beta container node-pools create nodepool3 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

GKE creates three node pools. Each node pool is a separate TPU slice.
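
The three commands above differ only in the node pool name, so they can also be generated in a loop. The following is a minimal dry-run sketch that prints the commands instead of executing them. The `nodepool_cmd` helper is hypothetical, and the uppercase values are the same placeholders used above.

```shell
#!/usr/bin/env bash
# Placeholders: substitute real values before running the printed commands.
LOCATION="LOCATION"
NODE_LOCATION="NODE_LOCATION"
PROJECT_ID="PROJECT_ID"

# Build the create command for one node pool; only the pool name varies.
nodepool_cmd() {
  printf 'gcloud beta container node-pools create %s' "$1"
  printf ' --location=%s --cluster=multislice-cluster' "$LOCATION"
  printf ' --node-locations=%s --machine-type=ct5lp-hightpu-4t' "$NODE_LOCATION"
  printf ' --tpu-topology=2x4 --num-nodes=2 --project=%s\n' "$PROJECT_ID"
}

# Dry run: collect and print one command per node pool.
commands=$(for i in 1 2 3; do nodepool_cmd "nodepool${i}"; done)
echo "${commands}"
```

Printing first makes it easy to review the flags before creating billable resources.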

Create the Kueue resources

  1. Create the following kueue.yaml manifest:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. Apply the kueue.yaml manifest:

    kubectl apply -f kueue.yaml
    

    GKE creates the following Kueue resources:

  • ResourceFlavor: An abstraction of the resources in a cluster. In this example, it represents three TPU slices, each with a 2x4 topology and four chips per host, for 24 TPU chips in total.
  • ClusterQueue: A global queue managing workloads and cluster resources.
  • LocalQueue: Groups closely related workloads that are typically run by a single tenant (user). Each LocalQueue points to a ClusterQueue from which resources are allocated to run its workloads. A Kueue Workload is an abstraction that represents a batch workload. In this case, each workload is a JobSet.

Define your Multislice workloads with JobSets

In this section, you create three JobSets. These JobSets run a JAX workload that prints the global number of TPU chips in the slice, sleeps for 60 seconds to simulate model training time, and then exits.

  1. Create the following jobsets-multislice.yaml manifest:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice  
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue  
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice  
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue  
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
  2. Apply the jobsets-multislice.yaml manifest:

    kubectl apply -f jobsets-multislice.yaml
    

GKE creates the Jobs with the following resource requests:

  • The multislice-1slice JobSet creates one Job that requires one TPU slice in total.
  • The multislice-2slice JobSet creates two Jobs that require two TPU slices in total.
  • The multislice-3slice JobSet creates three Jobs that require three TPU slices in total.

Because the cluster has only three TPU slices, not all JobSets can run at once. When Kueue admits the multislice-3slice JobSet first, its Jobs occupy all three slices and run alone to completion. The multislice-1slice and multislice-2slice JobSets wait and then run together.
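
The reasoning above comes down to quota arithmetic. The sketch below reproduces only that arithmetic, using the values from the manifests in this tutorial. It is not Kueue's admission algorithm, and the `demand` helper is hypothetical.

```shell
#!/usr/bin/env bash
quota=24          # nominalQuota for google.com/tpu in cluster-queue
pods_per_job=2    # parallelism of each replicated Job
chips_per_pod=4   # google.com/tpu limit per Pod

# TPU chips a JobSet requests: replicas * Pods per Job * chips per Pod.
demand() { echo $(( $1 * pods_per_job * chips_per_pod )); }

d1=$(demand 1)   # multislice-1slice
d2=$(demand 2)   # multislice-2slice
d3=$(demand 3)   # multislice-3slice

# While multislice-3slice is admitted, compute what is left for the others.
remaining=$(( quota - d3 ))
echo "demands: ${d1} ${d2} ${d3}; left while 3slice runs: ${remaining}"

# After it finishes, check whether the other two fit side by side.
if (( d1 + d2 <= quota )); then
  fits_together=yes
else
  fits_together=no
fi
echo "1slice and 2slice fit together: ${fits_together}"
```

With these numbers, multislice-3slice consumes the entire quota, and the two smaller JobSets together also consume exactly the full quota.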

Verify Kueue admitted the workloads

  1. Check the enqueued workloads in Kueue:

    kubectl get workloads
    

    The output is similar to the following:

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s
    

Kueue admits one or more workloads at a time, depending on the TPU resources they require. The remaining workloads stay queued until quota becomes available.

Monitor the workloads

  1. Monitor which pods are running:

    kubectl get pods
    

    The output is similar to the following:

    NAME                                READY   STATUS      RESTARTS   AGE
    multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
    multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
    multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
    multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
    multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
    multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
    multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
    multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
    multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
    multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s
    

    The output shows that GKE scheduled, created, and ran the Pods for multislice-3slice first. GKE then ran the Pods from the multislice-1slice and multislice-2slice JobSets.
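
With many Pods, a per-JobSet summary is easier to read than the raw listing. The following sketch counts Pods by JobSet and status. It assumes the Pod naming pattern visible in the sample output (JobSet name, then `slice`, Job index, Pod index, and a random suffix). The `summarize_pods` helper is hypothetical, and the sample rows are copied from the output above.

```shell
#!/usr/bin/env bash
# Count Pods per JobSet and status from `kubectl get pods` style rows on stdin.
summarize_pods() {
  awk '{
    name = $1
    # Strip "-slice-<job index>-<pod index>-<suffix>" to recover the JobSet name.
    sub(/-slice-[0-9]+-[0-9]+-[a-z0-9]+$/, "", name)
    count[name " " $3]++
  }
  END { for (key in count) print key, count[key] }' | sort
}

sample='multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s'

summary=$(printf '%s\n' "${sample}" | summarize_pods)
echo "${summary}"
```

Against a live cluster, the same function can read from `kubectl get pods --no-headers | summarize_pods`.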

Enable Kueue workload priorities and preemption

Optionally, you can assign priorities to Kueue workloads. Priorities determine the order in which Kueue admits enqueued workloads.

  1. Update your ClusterQueue to have a preemption policy:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
      preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. Create a PriorityClass for each distinct priority level you want to assign to workloads:

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
    
  3. Assign the priorityClassName to your JobSet:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
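
Step 2 calls for one PriorityClass per priority level, but shows only `low-priority`. As a sketch, the manifests for several levels can be generated from name and value pairs. The `priority_classes` helper, the `high-priority` name, and the value 1000 are hypothetical. Only `low-priority` with value 100 comes from this tutorial.

```shell
#!/usr/bin/env bash
# Emit one PriorityClass manifest per "name:value" argument,
# separated by YAML document markers.
priority_classes() {
  local first=1 pair
  for pair in "$@"; do
    if (( first )); then first=0; else echo "---"; fi
    cat <<EOF
apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: ${pair%%:*}
value: ${pair##*:}
globalDefault: false
EOF
  done
}

# high-priority:1000 is an illustrative assumption, not part of the tutorial.
manifest=$(priority_classes low-priority:100 high-priority:1000)
echo "${manifest}"
```

The output can be reviewed and then piped to `kubectl apply -f -`. A higher `value` means a higher priority, so a JobSet that references `high-priority` would be eligible to preempt one that references `low-priority` under the `withinClusterQueue: LowerPriority` policy.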
    

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the project

  1. In the Google Cloud console, go to the Manage resources page.

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

Delete the individual resources

  1. Delete the Kueue quota system:

    kubectl delete -n default localqueue multislice-queue
    kubectl delete clusterqueue cluster-queue
    kubectl delete resourceflavor vlp-24
    
  2. Delete the Kueue manifest:

    VERSION=KUEUE_VERSION
    kubectl delete -f \
        https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml

    Replace KUEUE_VERSION with the Kueue release version that is installed in your cluster.
    
  3. Delete the cluster:

    gcloud container clusters delete multislice-cluster --location=LOCATION
    

What's next