JobSet と Kueue を使用してマルチスライス ワークロードをオーケストレートする


このチュートリアルでは、Google Kubernetes Engine(GKE)と Kueue で TPU マルチスライスを使用して Jax ワークロードを実行する方法について説明します。Kueue は Job キューイングを実装しており、チーム間で公平にリソースを共有するための割り当てや階層に基づいて Job を待機するタイミングや開始するタイミングを決定します。

このチュートリアルでは、TPU リソースを同時に実行する必要がある複数のマルチスライス ワークロードをオーケストレートする方法について説明します。

GKE で TPU を使用する前に、次の学習プログラムを完了することをおすすめします。

  1. Cloud TPU システム アーキテクチャで、現在の TPU バージョンの可用性について学習する。
  2. GKE での TPU マルチスライスについて学習する。

目標

このチュートリアルは、GKE のクラスタがすでに存在し、マルチスライス ワークロードを初めて実行する GKE 管理者を対象としています。

このチュートリアルでは、次の手順について説明します。

  1. 3 つの v5e TPU スライスを含む GKE クラスタで環境を準備します。各 TPU スライスには、8 個のチップを含む 2x4 トポロジがあります。そのため、合計で 24 個の TPU v5e チップがあります。
  2. Kueue リソースを作成して、ワークロード間で割り当てが公平に共有されるようにします。
  3. マルチスライス ワークロードを実行します。

始める前に

始める前に、次の作業が完了していることを確認してください。

  • Google Kubernetes Engine API を有効にする。
  • Google Kubernetes Engine API の有効化
  • このタスクに Google Cloud CLI を使用する場合は、gcloud CLI をインストールして初期化する。すでに gcloud CLI をインストールしている場合は、gcloud components update を実行して最新のバージョンを取得する。

環境を準備する

  1. Google Cloud コンソールで、Cloud Shell インスタンスを起動します。
    Cloud Shell を開く

  2. デフォルトの環境変数を設定します。

    gcloud config set project PROJECT_ID
    gcloud config set compute/region COMPUTE_REGION
    

    次の値を置き換えます。

バージョン 1.29.2-gke.1521000 以降を実行する Autopilot クラスタでは、デフォルトで TPU が有効になります。Autopilot クラスタの TPU は、ワークロード仕様で構成されます。詳細については、JobSet を使用してマルチスライス ワークロードを定義するをご覧ください。

GKE クラスタを作成する

Cloud Shell で、GKE クラスタを作成します。

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

Standard

gcloud container clusters create multislice-cluster \
    --location=LOCATION

LOCATION は、クラスタを作成するロケーションに置き換えます。ct5lp-hightpu-4t マシンタイプの容量があることを確認します。クラスタの作成には数分かかることもあります。

GKE Autopilot モードを使用している場合は、Kueue リソースを作成するセクションに進みます。バージョン 1.29.2-gke.1521000 以降を実行する Autopilot クラスタでは、デフォルトで TPU が有効になります。

3 つの Standard モード TPU スライス ノードプールを作成する

  1. nodepool1 という名前で 1 つ目のノードプールを作成します。

    gcloud beta container node-pools create nodepool1 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

    NODE_LOCATION は、ノードを作成するクラスタ リージョンの 1 つ以上のゾーンに置き換えます。

  2. nodepool2 という名前で 2 つ目のノードプールを作成します。

    gcloud beta container node-pools create nodepool2 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    
  3. nodepool3 という名前の 3 つ目のノードプールを作成します。

    gcloud beta container node-pools create nodepool3 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --num-nodes=2 \
        --project=PROJECT_ID
    

GKE は 3 つのノードプールを作成します。それぞれのノードプールは個別の TPU スライスです。

Kueue リソースを作成する

  1. 次の kueue.yaml マニフェストを作成します。

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. kueue.yaml マニフェストを適用します。

    kubectl apply -f kueue.yaml
    

    GKE は次の Kueue リソースを作成します。

  • ResourceFlavor: クラスタ内のリソースの抽象化。この例では、GKE は 2x4 トポロジを持つ 3 つの TPU スライスを作成します。各 TPU スライスには、8 個のチップ(合計 24 個の TPU チップ)を含む 2x4 トポロジがあります。
  • ClusterQueue: ワークロードとクラスタ リソースを管理するグローバル キュー。
  • LocalQueue: 通常は単一のテナント(ユーザー)が実行する密接に関連したワークロードをグループ化します。各 LocalQueue は、ワークロードを実行するためにリソースが割り当てられる ClusterQueue を参照します。Kueue ワークロード はバッチ ワークロードを表す抽象化です。この場合の各ワークロードは JobSet です。

JobSet を使用してマルチスライス ワークロードを定義する

このセクションでは、3 つの JobSet を作成します。これらの JobSet は、スライス内の TPU チップの総数を出力する Jax ワークロードを実行し、60 秒間スリープしてモデルのトレーニング時間をシミュレートした後に終了します。

  1. 次の jobsets-multislice.yaml マニフェストを作成します。

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice  
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue  
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice  
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue  
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
  2. jobsets-multislice.yaml マニフェストを適用します。

    kubectl apply -f jobsets-multislice.yaml
    

GKE は、次のリソース リクエストを使用して Job を作成します。

  • multislice-1slice JobSet は、合計 1 つの TPU スライスを必要とする 1 つの Job を作成します。
  • multislice-2slice JobSet は、合計 2 つの TPU スライスを必要とする 2 つの Job を作成します。
  • multislice-3slice JobSet は、合計 3 つの TPU スライスを必要とする 3 つの Job を作成します。

クラスタには TPU スライスが 3 つしかないため、すべての JobSet を一度に実行できるわけではありません。Kueue が 3 つすべての multislice-3slice JobSet をキューに登録すると、その Job は完了するまで単独で実行されます。multislice-1slicemultislice-2slice は待機し、その後一緒に実行されます。

Kueue がワークロードを承諾したことを確認する

  1. Kueue でキューに登録されたワークロードを確認します。

    kubectl get workloads
    

    出力は次のようになります。

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s
    

Kueue は、必要な TPU リソースに応じて 1 つ以上のワークロードをキューに登録します。

ワークロードをモニタリングする

  1. 実行中の Pod をモニタリングします。

    kubectl get pods
    

    出力は次のようになります。

    NAME                                READY   STATUS      RESTARTS   AGE
    multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
    multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
    multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
    multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
    multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
    multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
    multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
    multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
    multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
    multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s
    

    GKE が最初に multislice-3slice の Pod をスケジュール、作成、実行したことがわかります。次に、GKE は multislice-1slicemultislice-2slice JobSet から Pod を実行しています。

Kueue ワークロードの優先度とプリエンプションを有効にする

必要に応じて、Kueue ワークロードに優先度を割り当てて Kueue がキューに登録されたワークロードを承諾する順序を決めることができます。

  1. ClusterQueue を更新してプリエンプション ポリシーを設定します。

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
     preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. ワークロードに割り当てる優先度ごとに PriorityClass を作成します。

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
    
  3. priorityClassName を JobSet に割り当てます。

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
      ```
    

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

プロジェクトを削除する

  1. In the Google Cloud console, go to the Manage resources page.

    Go to Manage resources

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

個々のリソースを削除する

  1. Kueue 割り当てシステムを削除します。

    kubectl delete -n team-a localqueue
    kubectl delete -n team-b localqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    
  2. Kueue マニフェストを削除します。

    VERSION=kueue.x-k8s.io/v1beta1
    kubectl delete -f \
        https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml
    
  3. クラスタを削除します。

    gcloud container clusters delete kueue-cohort --region=COMPUTE_REGION
    

次のステップ