GKE で Saxml を実行してマルチホスト TPU を使用して LLM を提供する


このチュートリアルでは、Saxml を使用して、Google Kubernetes Engine(GKE)でマルチホスト TPU スライス ノードプールを利用し、大規模言語モデル(LLM)をデプロイして提供する方法について説明します。これにより、効率的なスケーラブルなアーキテクチャを実現できます。

背景

Saxml は、PaxmlJAXPyTorch の各フレームワークを提供する試験運用版のシステムです。TPU を使用すると、これらのフレームワークでデータ処理を高速化できます。GKE で TPU のデプロイのデモを行うため、このチュートリアルでは 175B の LmCloudSpmd175B32Test テストモデルを使用します。GKE は、このテストモデルをそれぞれ 4x8 トポロジの 2 つの v5e TPU スライス ノードプールにデプロイします。

テストモデルを適切にデプロイするために、TPU トポロジはモデルのサイズに基づいて定義されています。N0 億の 16 ビットモデルには約 2 倍(2 x N)の GB 数のメモリが必要ですが、175B LmCloudSpmd175B32Test モデルには約 350 GB のメモリが必要です。TPU v5e シングル TPU チップの容量は 16 GB です。350 GB をサポートするには、GKE に 21 個の v5e TPU チップが必要です(350÷16= 21)。TPU 構成のマッピングに基づいて、このチュートリアルの適切な TPU 構成は次のようになります。

  • マシンタイプ: ct5lp-hightpu-4t
  • トポロジ: 4x8(32 個の TPU チップ)

GKE に TPU をデプロイする場合は、モデルの提供に適した TPU トポロジを選択することが重要です。詳細については、TPU 構成の計画をご覧ください。

目標

このチュートリアルは、データモデルを提供するために GKE オーケストレーション機能を使用する MLOps または DevOps エンジニア、プラットフォーム管理者を対象としています。

このチュートリアルでは、次の手順について説明します。

  1. GKE Standard クラスタで環境を準備します。クラスタには、4x8 トポロジの 2 つの v5e TPU スライス ノードプールがあります。
  2. Saxml をデプロイします。Saxml には、管理者サーバー、モデルサーバーとして機能する Pod のグループ、事前に構築された HTTP サーバー、ロードバランサが必要です。
  3. Saxml を使用して LLM を提供します。

次の図は、このチュートリアルで実装するアーキテクチャを示しています。

GKE 上のマルチホスト TPU のアーキテクチャ。
図: GKE 上のマルチホスト TPU のアーキテクチャ例。

始める前に

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role colunn to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      [IAM] に移動
    2. プロジェクトを選択します。
    3. [ アクセスを許可] をクリックします。
    4. [新しいプリンシパル] フィールドに、ユーザー ID を入力します。 これは通常、Google アカウントのメールアドレスです。

    5. [ロールを選択] リストでロールを選択します。
    6. 追加のロールを付与するには、 [別のロールを追加] をクリックして各ロールを追加します。
    7. [保存] をクリックします。

環境を準備する

  1. Google Cloud コンソールで、Cloud Shell インスタンスを起動します。
    Cloud Shell を開く

  2. デフォルトの環境変数を設定します。

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export REGION=COMPUTE_REGION
      export ZONE=COMPUTE_ZONE
      export GSBUCKET=PROJECT_ID-gke-bucket
    

    次の値を置き換えます。

GKE Standard クラスタを作成する

Cloud Shell で以下の操作を行います。

  1. GKE 用 Workload Identity 連携を使用する Standard クラスタを作成します。

    gcloud container clusters create saxml \
        --zone=${ZONE} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    VERSION は、GKE のバージョン番号に置き換えます。GKE は、バージョン 1.27.2-gke.2100 以降で TPU v5e をサポートしています。詳細については、GKE での TPU の可用性をご覧ください。

    クラスタの作成には数分かかることもあります。

  2. tpu1 という名前で 1 つ目のノードプールを作成します。

    gcloud container node-pools create tpu1 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    
  3. tpu2 という名前で 2 つ目のノードプールを作成します。

    gcloud container node-pools create tpu2 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    

次のリソースを作成しました。

  • 4 つの CPU ノードを持つ Standard クラスタ。
  • 4x8 トポロジを持つ 2 つの v5e TPU スライス ノードプール。各ノードプールは、それぞれ 4 つの TPU チップを持つ 8 つの TPU スライスノードを表します。

175B モデルは、少なくとも 4x8 トポロジ スライス(32 個の v5e TPU チップ)を持つマルチホスト v5e TPU スライスで提供する必要があります。

Cloud Storage バケットを作成する

Saxml 管理者サーバーの構成を保存する Cloud Storage バケットを作成します。実行中の管理者サーバーは、その状態と公開モデルの詳細を定期的に保存します。

Cloud Shell で次のコマンドを実行します。

gcloud storage buckets create gs://${GSBUCKET}

GKE 用 Workload Identity 連携を使用してワークロード アクセスを構成する

アプリケーションに Kubernetes ServiceAccount を割り当て、IAM サービス アカウントとして機能するようにその Kubernetes ServiceAccount を構成します。

  1. クラスタと通信を行うように kubectl を構成します。

    gcloud container clusters get-credentials saxml --zone=${ZONE}
    
  2. アプリケーションで使用する Kubernetes ServiceAccount を作成します。

    kubectl create serviceaccount sax-sa --namespace default
    
  3. アプリケーションの IAM サービス アカウントを作成します。

    gcloud iam service-accounts create sax-iam-sa
    
  4. IAM サービス アカウントの IAM ポリシー バインディングを追加して、Cloud Storage に対する読み取りと書き込みを行います。

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. 2 つのサービス アカウントの間に IAM ポリシー バインディングを追加して、Kubernetes ServiceAccount が IAM サービス アカウントの権限を借用できるようにします。このバインドで、Kubernetes ServiceAccount が IAM サービス アカウントとして機能するようになるため、Kubernetes ServiceAccount が Cloud Storage に対して読み書きを行うことができます。

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. Kubernetes サービス アカウントに IAM サービス アカウントのメールアドレスでアノテーションを付けます。これにより、サンプルアプリが Google Cloud サービスへのアクセスに使用するサービス アカウントを認識できます。そのため、アプリが標準の Google API クライアント ライブラリを使用して Google Cloud サービスにアクセスする場合は、その IAM サービス アカウントを使用します。

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    

Saxml をデプロイする

このセクションでは、Saxml 管理者サーバーと Saxml モデルサーバーをデプロイします。

Saxml 管理者サーバーをデプロイする

  1. 次の sax-admin-server.yaml マニフェストを作成します。

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

  2. 次のようにマニフェストを適用します。

    kubectl apply -f sax-admin-server.yaml
    
  3. 管理者サーバーの Pod が稼働していることを確認します。

    kubectl get deployment
    

    出力は次のようになります。

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
    

Saxml モデルサーバーをデプロイする

マルチホスト TPU スライスで実行されるワークロードでは、同じ TPU スライス内のピアを検出するために、各 Pod に安定したネットワーク識別子が必要です。これらの識別子を定義するには、IndexedJobStatefulSet ヘッドレス Service または JobSet を使用します。これにより、JobSet に属するすべての Job に対してヘッドレス Service が自動的に作成されます。次のセクションでは、JobSet を使用してモデルサーバー Pod の複数のグループを管理する方法について説明します。

  1. v0.2.3 以降の JobSet をインストールします。

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION は、JobSet のバージョンに置き換えます。例: v0.2.3

  2. JobSet コントローラが jobset-system Namespace で実行されていることを確認します。

    kubectl get pod -n jobset-system
    

    出力は次のようになります。

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. 2 つの TPU スライス ノードプールに 2 つのモデルサーバーをデプロイします。次の sax-model-server-set マニフェストを保存します。

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

    このマニフェストの内容:

    • replicas: 2 は、Job のレプリカの数です。各ジョブはモデルサーバーを表します。したがって、8 つの Pod のグループになります。
    • parallelism: 8completions: 8 は、各ノードプール内のノード数と等しくなります。
    • Pod が失敗した場合に Job を失敗としてマークするには、backoffLimit: 0 を 0 にする必要があります。
    • ports.containerPort: 8471 は、VM 通信用のデフォルト ポートです。
    • GKE はマルチスライス トレーニングを実行していないため、name: MEGASCALE_NUM_SLICES は環境変数の設定を解除します。
  4. 次のようにマニフェストを適用します。

    kubectl apply -f sax-model-server-set.yaml
    
  5. Saxml 管理サーバーと Model Server Pod のステータスを確認します。

    kubectl get pods
    

    出力は次のようになります。

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

この例では、16 個のモデルサーバー コンテナがあります。sax-model-server-set-sax-model-server-0-0-nj4smsax-model-server-set-sax-model-server-1-0-8mlf5 は、各グループの 2 つのプライマリ モデルサーバーです。

Saxml クラスタには、それぞれ 4x8 トポロジを持つ 2 つの v5e TPU スライス ノードプールにデプロイされた 2 つのモデルサーバーがあります。

Saxml HTTP Server とロードバランサをデプロイする

  1. 次のビルド済みイメージの HTTP サーバー イメージを使用します。次の sax-http.yaml マニフェストを保存します。

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer

    BUCKET_NAME は、Cloud Storage バケット名に置き換えます。

  2. sax-http.yaml マニフェストを適用します。

    kubectl apply -f sax-http.yaml
    
  3. HTTP サーバー コンテナの作成が完了するまで待ちます。

    kubectl get pods
    

    出力は次のようになります。

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  4. Service に外部 IP アドレスが割り当てられるまで待ちます。

    kubectl get svc
    

    出力は次のようになります。

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
    

Saxml を使用する

v5e TPU マルチホスト スライスの Saxml でモデルを読み込んでデプロイし、提供します。

モデルを読み込む

  1. Saxml のロードバランサの IP アドレスを取得します。

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. 2 つの v5e TPU スライス ノードプールに LmCloudSpmd175B テストモデルを読み込みます。

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    テストモデルにはファインチューニングされたチェックポイントがなく、重みはランダムに生成されます。モデルの読み込みには最大 10 分かかります。

    出力は次のようになります。

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
  3. モデルの準備状況を確認します。

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    出力は次のようになります。

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    モデルが完全に読み込まれました。

  4. モデルに関する情報を取得します。

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    出力は次のようになります。

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
    

モデルを提供する

プロンプト リクエストを処理します。

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

出力には、モデルのレスポンスの例が表示されます。テストモデルにはランダムな重みがあるため、このレスポンスは意味をなさない可能性があります。

モデルの公開を停止する

次のコマンドを実行して、モデルを非公開にします。

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

出力は次のようになります。

{
  "model": "/sax/test/spmd"
}

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

デプロイされたリソースを削除する

  1. このチュートリアル用に作成したクラスタを削除します。

    gcloud container clusters delete saxml --zone ${ZONE}
    
  2. サービス アカウントを削除します。

    gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    
  3. Cloud Storage バケットを削除します。

    gcloud storage rm -r gs://${GSBUCKET}
    

次のステップ