使用 JobSet 和 Kueue 编排多切片工作负载


本教程介绍了如何在 Google Kubernetes Engine (GKE) 和 Kueue 中使用 TPU 多切片运行 Jax 工作负载。Kueue 会根据配额和层次结构,在团队之间共享资源,从而实现 Job 排队,确定 Job 应等待的时间和应开始的时间。

本教程介绍了如何编排多个需要 TPU 资源并发运行的多切片工作负载

使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:

  1. 通过 Cloud TPU 系统架构了解当前的 TPU 版本可用性。
  2. 了解 GKE 中的 TPU 多切片

目标

本教程适用于已有 GKE 集群并希望首次运行多切片工作负载的 GKE 管理员。

本教程介绍以下步骤:

  1. 使用具有三个 v5e TPU 切片的 GKE 集群准备环境。每个 TPU 切片都有一个 2x4 拓扑,每个主机有四个芯片。因此,总共有 24 个 TPU v5e 芯片。
  2. 创建 Kueue 资源以确保配额在工作负载之间公平共享。
  3. 运行您的多片工作负载。

准备工作

在开始之前,请确保您已执行以下任务:

  • 启用 Google Kubernetes Engine API。
  • 启用 Google Kubernetes Engine API
  • 如果您要使用 Google Cloud CLI 执行此任务,请安装初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行 gcloud components update 以获取最新版本。

准备环境

  1. 在 Google Cloud 控制台中,启动 Cloud Shell 实例:
    打开 Cloud Shell

  2. 设置默认环境变量:

    gcloud config set project PROJECT_ID
    gcloud config set compute/region COMPUTE_REGION
    

    替换以下值:

默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。Autopilot 集群上的 TPU 在工作负载规范中配置。如需了解详情,请参阅使用 JobSet 定义多切片工作负载部分。

创建 GKE 集群

在 Cloud Shell 中创建一个 GKE 集群:

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

标准

gcloud container clusters create multislice-cluster \
    --location=LOCATION

LOCATION 替换为您要在其中创建集群的位置。确保集群有足够的容量用于 ct5lp-hightpu-4t 机器类型。集群创建过程可能需要几分钟时间。

如果您使用 GKE Autopilot 模式,请跳至创建 Kueue 资源部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。

创建三个 Standard 模式 TPU 节点池

  1. 创建第一个节点池,名为 nodepool1

    gcloud beta container node-pools create nodepool1 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

    NODE_LOCATION 替换为您要在其中创建节点的集群区域中的一个或多个可用区。

  2. 创建第二个节点池,名为 nodepool2

    gcloud beta container node-pools create nodepool2 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    
  3. 创建第三个节点池,名为 nodepool3

    gcloud beta container node-pools create nodepool3 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

GKE 会创建三个节点池。每个节点池都是一个单独的 TPU 切片。

创建 Kueue 资源

  1. 创建以下 kueue.yaml 清单:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. 应用 kueue.yaml 清单:

    kubectl apply -f kueue.yaml
    

    GKE 会创建以下 Kueue 资源:

  • ResourceFlavor:集群中资源的抽象。在此示例中,三个 TPU 切片具有 2x4 拓扑,每个主机具有四个芯片,因此 24 个 TPU 芯片。
  • ClusterQueue:管理工作负载和集群资源的全局队列。
  • LocalQueue:通常由单个租户(用户)运行的紧密相关的工作负载。每个 LocalQueue 都指向一个 ClusterQueue,系统会从该集群分配资源以运行其工作负载。Kueue 工作负载是表示批处理工作负载的抽象;在本例中,每个工作负载都是一个 JobSet。

使用 JobSet 定义多切片工作负载

在本部分中,您将创建三个 JobSet。这些 JobSet 会运行一个 Jax 工作负载,用于输出切片中的全球 TPU 芯片数量,然后休眠 60 秒以模拟一些模型训练时间,然后退出。

  1. 创建以下 jobsets-multislice.yaml 清单:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    

    标准

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
  2. 应用 jobsets-multislice.yaml 清单:

    kubectl apply -f jobsets-multislice.yaml
    

GKE 会使用以下资源请求来创建作业:

  • multislice-1slice JobSet 会创建一个总共需要一个 TPU 切片的作业。
  • multislice-2slice JobSet 创建了两个 Job,总共需要两个 TPU 切片。
  • multislice-3slice JobSet 创建了三个作业,总共需要三个 TPU 切片。

由于集群只有三个 TPU 切片,因此并非所有 JobSet 可以同时运行。当 Kueue 将三个 multislice-3slice JobSet 加入队列时,其 Job 会单独运行以完成。multislice-1slicemultislice-2slice 会等待,然后一起运行。

验证 Kueue 是否允许工作负载

  1. 检查 Kueue 中已加入队列的工作负载:

    kubectl get workloads
    

    输出类似于以下内容:

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s
    

Kueue 会将一个或多个工作负载加入队列,具体取决于它们所需的 TPU 资源。

监控工作负载

  1. 监控哪些 pod 正在运行:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                READY   STATUS      RESTARTS   AGE
    multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
    multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
    multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
    multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
    multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
    multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
    multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
    multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
    multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
    multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s
    

    看到 GKE 先为 multislice-3slice 安排、创建和运行 Pod。然后,GKE 从 multislice-1slicemultislice-2slice JobSet 运行 Pod。

启用 Kueue 工作负载优先级和抢占

(可选)您可以分配 Kueue 工作负载优先级,用于确定 Kueue 允许加入队列的工作负载的顺序。

  1. 更新 ClusterQueue 以获取抢占政策:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
     preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. 为您要分配给工作负载的每个不同优先级创建 PriorityClass

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
    
  3. 为 JobSet 分配 priorityClassName

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
    

    标准

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
      ```
    

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除项目

  1. 在 Google Cloud 控制台中,进入管理资源页面。

    转到“管理资源”

  2. 在项目列表中,选择要删除的项目,然后点击删除
  3. 在对话框中输入项目 ID,然后点击关闭以删除项目。

逐个删除资源

  1. 删除 Kueue 配额系统:

    kubectl delete -n team-a localqueue
    kubectl delete -n team-b localqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    
  2. 删除 Kueue 清单:

    VERSION=kueue.x-k8s.io/v1beta1
    kubectl delete -f \
        https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml
    
  3. 删除集群:

    gcloud container clusters delete kueue-cohort --region=COMPUTE_REGION
    

后续步骤