本教程介绍了如何在 Google Kubernetes Engine (GKE) 上编排多切片工作负载。您可以使用 TPU 多切片、JobSet 和 Kueue 运行 Jax 工作负载。Kueue 会根据配额和层次结构,在团队之间共享资源,从而实现 Job 排队,确定 Job 应等待的时间和应开始的时间。
如果您使用的工作负载需要 TPU 资源才能并发运行,我们建议您完成本教程。
使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:
- 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
- 了解 GKE 中的 TPU 多切片。
目标
本教程适用于已有 GKE 集群并希望首次运行多切片工作负载的 GKE 管理员。
本教程介绍以下步骤:
- 使用具有三个 v5e TPU 切片的 GKE 集群准备环境。每个 TPU 切片都有一个包含 8 个芯片的
2x4
拓扑。因此,总共有 24 个 TPU v5e TPU 芯片。 - 创建 Kueue 资源以确保在工作负载之间公平共享配额。
- 运行多切片工作负载。
准备工作
在开始之前,请确保您已执行以下任务:
- 启用 Google Kubernetes Engine API。 启用 Google Kubernetes Engine API
- 如果您要使用 Google Cloud CLI 执行此任务,请安装并初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行
gcloud components update
以获取最新版本。
准备环境
在 Google Cloud 控制台中,启动 Cloud Shell 实例:
打开 Cloud Shell设置默认环境变量:
gcloud config set project PROJECT_ID gcloud config set compute/region COMPUTE_REGION
替换以下值:
- PROJECT_ID:您的 Google Cloud 项目 ID。
- COMPUTE_REGION:Compute Engine 区域。
默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。Autopilot 集群上的 TPU 在工作负载规范中配置。如需了解详情,请参阅使用 JobSet 定义多切片工作负载部分。
创建 GKE 集群
在 Cloud Shell 中创建一个 GKE 集群:
Autopilot
gcloud container clusters create-auto multislice-cluster \
--location=LOCATION \
--cluster-version 1.29.2-gke.1521000 \
--release-channel rapid
标准
gcloud container clusters create multislice-cluster \
--location=LOCATION
将 LOCATION 替换为要在其中创建集群的位置。确保集群有足够的容量用于 ct5lp-hightpu-4t
机器类型。集群创建过程可能需要几分钟时间。
如果您使用 GKE Autopilot 模式,请跳到创建 Kueue 资源部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。
创建三个 Standard 模式 TPU 切片节点池
创建第一个节点池,名为
nodepool1
:gcloud beta container node-pools create nodepool1 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --num-nodes=2 \ --project=PROJECT_ID
将 NODE_LOCATION 替换为您要在其中创建节点的集群区域中的一个或多个可用区。
创建第二个节点池,名为
nodepool2
:gcloud beta container node-pools create nodepool2 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --num-nodes=2 \ --project=PROJECT_ID
创建第三个节点池,名为
nodepool3
:gcloud beta container node-pools create nodepool3 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --num-nodes=2 \ --project=PROJECT_ID
GKE 会创建三个节点池。每个节点池都是一个单独的 TPU 切片。
创建 Kueue 资源
创建以下
kueue.yaml
清单:apiVersion: kueue.x-k8s.io/v1beta1 kind: ResourceFlavor metadata: name: "vlp-24" spec: nodeLabels: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: ClusterQueue metadata: name: "cluster-queue" spec: namespaceSelector: {} queueingStrategy: BestEffortFIFO resourceGroups: - coveredResources: ["google.com/tpu"] flavors: - name: "vlp-24" resources: - name: "google.com/tpu" nominalQuota: 24 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: LocalQueue metadata: namespace: default name: multislice-queue spec: clusterQueue: cluster-queue
应用
kueue.yaml
清单:kubectl apply -f kueue.yaml
GKE 会创建以下 Kueue 资源:
- ResourceFlavor:集群中资源的抽象。在此示例中,GKE 会创建三个具有
2x4
拓扑的 TPU 切片。每个 TPU 切片都有一个包含 8 个芯片的2x4
拓扑(总共 24 个 TPU 芯片)。 - ClusterQueue:管理工作负载和集群资源的全局队列。
- LocalQueue:对密切相关的工作负载分组,这些工作负载通常由单个租户(用户)运行。每个 LocalQueue 都指向一个 ClusterQueue,系统会从 ClusterQueue 分配资源以运行其工作负载。Kueue 工作负载是表示批量工作负载的抽象,在此情况下,每个工作负载都是一个 JobSet。
使用 JobSet 定义多切片工作负载
在本部分中,您将创建三个 JobSet。这些 JobSet 运行 Jax 工作负载,该工作负载会输出切片中的 TPU 芯片的全局数量,然后休眠 60 秒以模拟一些模型训练时间,然后退出。
创建以下
jobsets-multislice.yaml
清单:Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-1slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-2slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 2 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-3slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 3 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4
标准
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-1slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-2slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 2 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-3slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 3 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4
应用
jobsets-multislice.yaml
清单:kubectl apply -f jobsets-multislice.yaml
GKE 会使用以下资源请求创建 Job:
multislice-1slice
JobSet 创建一个 Job,总共需要一个 TPU 切片。multislice-2slice
JobSet 创建两个 Job,总共需要两个 TPU 切片。multislice-3slice
JobSet 创建三个 Job,总共需要三个 TPU 切片。
由于集群只有三个 TPU 切片,因此并非所有 JobSet 可以同时运行。当 Kueue 将所有三个 multislice-3slice
JobSet 加入队列时,其 Job 会单独运行以完成。multislice-1slice
和 multislice-2slice
会等待,然后一起运行。
验证 Kueue 是否允许工作负载
检查 Kueue 中已加入队列的工作负载:
kubectl get workloads
输出类似于以下内容:
NAME QUEUE ADMITTED BY AGE jobset-multislice-1slice-2530a multislice-queue 3s jobset-multislice-2slice-ffb02 multislice-queue 4s jobset-multislice-3slice-8c695 multislice-queue cluster-queue 10s
Kueue 会将一个或多个工作负载加入队列,具体取决于它们所需的 TPU 资源。
监控工作负载
监控哪些 pod 正在运行:
kubectl get pods
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE multislice-1slice-slice-0-0-pf2ll 1/1 Running 0 1s multislice-1slice-slice-0-1-55g62 1/1 Running 0 1s multislice-2slice-slice-0-0-f4hf7 1/1 Running 0 3s multislice-2slice-slice-0-1-c8kv7 1/1 Running 0 3s multislice-2slice-slice-1-0-7h46t 1/1 Running 0 3s multislice-2slice-slice-1-1-lj9hb 1/1 Running 0 3s multislice-3slice-slice-0-0-wzq9t 0/1 Completed 0 2m31s multislice-3slice-slice-0-1-zf4dp 0/1 Completed 0 2m30s multislice-3slice-slice-1-0-hbfn5 0/1 Completed 0 2m31s multislice-3slice-slice-1-1-45fgl 0/1 Completed 0 2m30s multislice-3slice-slice-2-0-wjbp4 0/1 Completed 0 2m30s multislice-3slice-slice-2-1-lwnvs 0/1 Completed 0 2m30s
看到 GKE 先为
multislice-3slice
安排、创建和运行了 pod。然后,GKE 从multislice-1slice
和multislice-2slice
JobSet 运行 Pod。
启用 Kueue 工作负载优先级和抢占
(可选)您可以分配 Kueue 工作负载优先级,用于确定 Kueue 允许已加入队列的工作负载的顺序。
更新
ClusterQueue
,使其具有抢占政策:apiVersion: kueue.x-k8s.io/v1beta1 kind: ResourceFlavor metadata: name: "vlp-24" spec: nodeLabels: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: ClusterQueue metadata: name: "cluster-queue" spec: namespaceSelector: {} resourceGroups: - coveredResources: ["google.com/tpu"] flavors: - name: "vlp-24" resources: - name: "google.com/tpu" nominalQuota: 24 preemption: reclaimWithinCohort: Any withinClusterQueue: LowerPriority --- apiVersion: kueue.x-k8s.io/v1beta1 kind: LocalQueue metadata: namespace: default name: multislice-queue spec: clusterQueue: cluster-queue
为要分配给工作负载的每个不同优先级创建一个
PriorityClass
:apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: name: low-priority value: 100 globalDefault: false description: "This low priority class should be used for some Pods only."
为 JobSet 分配
priorityClassName
:Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: low-priority labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 priorityClassName: low-priority containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4 # Number of TPU chips per worker
标准
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: low-priority labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 priorityClassName: low-priority containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4 # Number of TPU chips per worker ```
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
删除项目
- In the Google Cloud console, go to the Manage resources page.
- In the project list, select the project that you want to delete, and then click Delete.
- In the dialog, type the project ID, and then click Shut down to delete the project.
逐个删除资源
删除 Kueue 配额系统:
kubectl delete -n team-a localqueue kubectl delete -n team-b localqueue kubectl delete clusterqueue kubectl delete clusterqueue kubectl delete clusterqueue kubectl delete resourceflavor kubectl delete resourceflavor kubectl delete resourceflavor
删除 Kueue 清单:
VERSION=kueue.x-k8s.io/v1beta1 kubectl delete -f \ https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml
删除集群:
gcloud container clusters delete kueue-cohort --region=COMPUTE_REGION
后续步骤
- 详细了解 Kueue。
- 了解如何在 GKE 上通过命名空间之间的配额共享实现 Job 排队系统。