Orchestrate multislice workloads with JobSet and Kueue

This tutorial shows you how to orchestrate multislice workloads on Google Kubernetes Engine (GKE). You run a JAX workload with TPU Multislice, JobSet, and Kueue. Kueue implements Job queueing, deciding when Jobs should wait and when they should start, based on quotas and a hierarchy for sharing resources fairly among teams.

This tutorial walks through orchestrating multiple multislice workloads that require TPU resources and run concurrently.

Before using TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability in the Cloud TPU system architecture.
  2. Learn about TPU Multislice in GKE.

Objectives

This tutorial is intended for GKE administrators who have an existing GKE cluster and want to run multislice workloads for the first time.

This tutorial covers the following steps:

  1. Prepare your environment with a GKE cluster that has three v5e TPU slices. Each TPU slice has a 2x4 topology with 8 chips, for a total of 24 TPU v5e chips.
  2. Create Kueue resources to ensure that quota is shared fairly between workloads.
  3. Run your multislice workloads.

Before you begin

Before you start, make sure that you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.

Prepare the environment

  1. In the Google Cloud console, start a Cloud Shell instance:
    Open Cloud Shell

  2. Set the default environment variables:

    gcloud config set project PROJECT_ID
    gcloud config set compute/region COMPUTE_REGION
    

    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • COMPUTE_REGION: the Compute Engine region.

By default, TPUs are enabled in Autopilot clusters that run version 1.29.2-gke.1521000 or later. TPUs on Autopilot clusters are configured in the workload specification. For more information, see the Define your multislice workloads with JobSet section.

Create a GKE cluster

In Cloud Shell, create a GKE cluster:

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

Standard

gcloud container clusters create multislice-cluster \
    --location=LOCATION

Replace LOCATION with the location where you want to create your cluster. Make sure that the cluster has enough capacity for the ct5lp-hightpu-4t machine type. Creating the cluster might take several minutes.
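
To check ahead of time which zones offer this machine type, you can list it with the gcloud CLI. This is an optional sanity check, not one of the original steps:

# List the zones that offer the ct5lp-hightpu-4t machine type.
gcloud compute machine-types list \
    --filter="name=ct5lp-hightpu-4t" \
    --format="value(zone)"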

If you use GKE Autopilot mode, skip ahead to the Create Kueue resources section. By default, TPUs are enabled in Autopilot clusters that run version 1.29.2-gke.1521000 or later.

Create three Standard mode TPU slice node pools

  1. Create the first node pool, named nodepool1:

    gcloud beta container node-pools create nodepool1 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

    Replace NODE_LOCATION with one or more zones in the cluster region where you want to create the nodes.

  2. Create the second node pool, named nodepool2:

    gcloud beta container node-pools create nodepool2 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    
  3. Create the third node pool, named nodepool3:

    gcloud beta container node-pools create nodepool3 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

GKE creates three node pools. Each node pool is a separate TPU slice.
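
To confirm that all three slice node pools exist, you can list them. This is an optional verification step:

# List the node pools in the cluster, with their machine types and node counts.
gcloud container node-pools list \
    --cluster=multislice-cluster \
    --location=LOCATION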

Create Kueue resources

  1. Create the following kueue.yaml manifest:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. Apply the kueue.yaml manifest:

    kubectl apply -f kueue.yaml
    

    GKE creates the following Kueue resources:

  • ResourceFlavor: an abstraction of the resources in your cluster. In this example, GKE creates three TPU slices with a 2x4 topology. Each TPU slice has a 2x4 topology with 8 chips (24 TPU chips in total).
  • ClusterQueue: a global queue that manages workloads and cluster resources.
  • LocalQueue: groups closely related workloads that are typically run by a single tenant (user). Each LocalQueue points to a ClusterQueue, from which resources are allocated to run its workloads. A Kueue workload is an abstraction that represents a batch workload; in this case, each workload is a JobSet.
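
You can verify that these resources exist and that the ClusterQueue exposes the expected quota. This is an optional check; the exact output columns depend on your Kueue version:

# Inspect the Kueue resources created by the manifest.
kubectl get resourceflavors
kubectl get clusterqueues
kubectl get localqueues -n default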

Define your multislice workloads with JobSet

In this section, you create three JobSets. These JobSets run a JAX workload that prints the global number of TPU chips in the slice, sleeps for 60 seconds to simulate some model training time, and then exits.

  1. Create the following jobsets-multislice.yaml manifest:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
  2. Apply the jobsets-multislice.yaml manifest:

    kubectl apply -f jobsets-multislice.yaml
    

GKE creates Jobs with the following resource requests:

  • The multislice-1slice JobSet creates one Job that requires one TPU slice in total.
  • The multislice-2slice JobSet creates two Jobs that require two TPU slices in total.
  • The multislice-3slice JobSet creates three Jobs that require three TPU slices in total.

Because the cluster has only three TPU slices, not all of the JobSets can run at the same time. When Kueue enqueues all three JobSets, the multislice-3slice Jobs run alone to completion. multislice-1slice and multislice-2slice wait and then run together afterwards.
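
To watch how the JobSets themselves progress, you can also query them directly. This is an optional check that relies on the JobSet CRD already installed in your cluster:

# List the JobSets and their current status.
kubectl get jobsets -n default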

Verify that Kueue admitted the workloads

  1. Check the workloads queued in Kueue:

    kubectl get workloads
    

    The output is similar to the following:

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s
    

Kueue enqueues one or more workloads, depending on the TPU resources they require.
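
If a workload remains unadmitted, you can inspect its conditions to see why. The workload name below is taken from the sample output above; substitute one of your own:

# Show the workload's conditions, including the reason it is still pending.
kubectl describe workload jobset-multislice-2slice-ffb02 -n default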

Monitor the workloads

  1. Monitor which Pods are running:

    kubectl get pods
    

    The output is similar to the following:

    NAME                                READY   STATUS      RESTARTS   AGE
    multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
    multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
    multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
    multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
    multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
    multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
    multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
    multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
    multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
    multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s
    

    Notice that GKE scheduled, created, and ran the Pods for multislice-3slice first. Then, GKE ran the Pods from the multislice-1slice and multislice-2slice JobSets.
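
    To confirm that a slice printed its device count, you can read the logs of one of its Pods. The Pod name below comes from the sample output; substitute one of your own. The log should include a Global device count line, where the number depends on how many chips the workload can see:

    kubectl logs multislice-2slice-slice-0-0-f4hf7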

Enable Kueue workload priority and preemption

Optionally, you can assign priorities to Kueue workloads, which determine the order in which Kueue admits queued workloads.

  1. Update the ClusterQueue so that it has a preemption policy:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
      preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. Create a PriorityClass for each distinct priority level that you want to assign to workloads (a complementary high-priority class is sketched after this list):

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
    
  3. Assign a priorityClassName to your JobSet:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
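
For preemption to take effect, you also need at least one higher priority class. The following is a minimal sketch; the name high-priority and the value 1000 are illustrative choices for this example, not part of the original manifests:

apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: high-priority  # hypothetical class name for this example
value: 1000  # any value greater than the low-priority value (100)
globalDefault: false
description: "This high priority class can preempt lower-priority workloads."

With withinClusterQueue: LowerPriority set on the ClusterQueue, a JobSet that sets priorityClassName: high-priority can cause Kueue to preempt admitted low-priority workloads when TPU quota runs out.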

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the project

  1. In the Google Cloud console, go to the Manage resources page.

    Go to Manage resources

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

Delete individual resources

  1. Delete the Kueue quota system:

    kubectl delete -n default localqueue multislice-queue
    kubectl delete clusterqueue cluster-queue
    kubectl delete resourceflavor vlp-24
    
  2. Delete the Kueue manifests:

    VERSION=KUEUE_VERSION
    kubectl delete -f \
        https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml

    Replace KUEUE_VERSION with the Kueue release tag that you installed, for example a v0.x.y tag; the download URL expects a release tag, not an API version.
    
  3. Delete the cluster:

    gcloud container clusters delete multislice-cluster --location=LOCATION
    

What's next