GKE で JetStream と Pathways を使用してマルチホスト TPU を使用して LLM をサービングする


このガイドでは、複数のノードで Tensor Processing Unit(TPU)を使用して、Google Kubernetes Engine(GKE)で Llama 3.1 405B などの最先端の大規模言語モデル(LLM)をサービングする方法について説明します。

このガイドでは、ポータブルなオープンソース テクノロジー(Kubernetes、JetStreamPathways on CloudLeaderWorkerSet(LWS)API)を使用して、GKE の詳細な制御、拡張性、復元力、移植性、費用対効果を活用して、GKE に AI/ML ワークロードをデプロイしてサービングする方法について説明します。

背景

大規模言語モデルのサイズは大きくなり、単一のホスト TPU スライスに収まらなくなりました。ML 推論では、Cloud 上の Pathways を使用して、相互接続された複数の TPU ノードにまたがる GKE で大規模なマルチホスト推論を実行できます。このガイドでは、マルチホスト TPU スライスを使用して GKE クラスタをプロビジョニングし、Pathways on Cloud バイナリを使用して MaxText フレームワークで JetStream サーバーを起動し、マルチホスト推論リクエストを行う方法について説明します。

GKE で TPU を使用して JetStreamMaxTextPathways で LLM をサービングすることで、マネージド Kubernetes のメリット(費用効率、拡張性、高可用性など)をすべて活用した、プロダクション レディな堅牢なサービング ソリューションを構築できます。このセクションでは、このチュートリアルで使用されている重要なテクノロジーについて説明します。

TPU について

TPU は、Google が独自に開発した特定用途向け集積回路(ASIC)であり、TensorFlowPyTorchJAX などのフレームワークを使用して構築された ML モデルと AI モデルを高速化するために使用されます。

GKE で TPU を使用する前に、次の学習プログラムを完了することをおすすめします。

  1. Cloud TPU システム アーキテクチャで、現在の TPU バージョンの可用性について学習する。
  2. GKE の TPU についてを確認する。

このチュートリアルでは、Llama 3.1-405B モデルのサービングについて説明します。GKE は、低レイテンシでプロンプトをサービングするモデルの要件に基づいて構成された TPU トポロジを使用して、マルチホスト TPU v6e ノードにモデルをデプロイします。

Pathways on Cloud

Pathways は、アクセラレータの大規模なオーケストレーション レイヤです。Pathways は、現在のモデルの最先端のパフォーマンスを維持しながら、新しいシステムと ML 研究のアイデアの探求を可能にするように明示的に設計されています。Pathways を使用すると、単一の JAX クライアント プロセスで 1 つ以上の大規模な TPU スライスにまたがる計算を調整できるため、数百または数千の TPU チップにまたがる ML コンピューティングを効率化できます。

JetStream

JetStream は、Google が開発したオープンソースの推論サービング フレームワークです。JetStream を使用すると、TPU と GPU で高性能、高スループット、メモリ最適化された推論が可能になります。JetStream では、連続バッチ処理、KV キャッシュの最適化、量子化手法などの高度なパフォーマンス最適化により、LLM を簡単にデプロイできます。JetStream では、PyTorch/XLA と JAX TPU のサービングにより、パフォーマンスを最適化できます。

MaxText

MaxText は、FlaxOrbaxOptax などのオープンソースの JAX ライブラリ上に構築された、パフォーマンス、スケーラビリティ、適応性に優れた JAX LLM 実装です。MaxText のデコーダ専用の LLM 実装は Python で記述されています。XLA コンパイラの活用により、カスタム カーネルを構築しなくても高いパフォーマンスを実現できます。

MaxText がサポートする最新のモデルとパラメータ サイズの詳細については、MaxText プロジェクト リポジトリをご覧ください。

Llama 3.1 405B

Llama 3.1 405B は、テキスト生成、翻訳、質問応答など、さまざまな自然言語処理タスク用に設計された Meta の大規模言語モデルです。GKE は、この規模のモデルの分散トレーニングとサービングの実現に必要なインフラストラクチャを提供します。

詳細については、Llama のドキュメントをご覧ください。

アーキテクチャ

このセクションでは、このチュートリアルで使用する GKE アーキテクチャについて説明します。このアーキテクチャには、TPU をプロビジョニングし、モデルをデプロイしてサービングするための JetStream コンポーネントと Pathways コンポーネントをホストする GKE Standard クラスタが含まれています。

次の図は、このアーキテクチャのコンポーネントを示しています。

JetStream コンポーネントと Pathways コンポーネントを含むマルチホスト TPU ノードプールを使用した GKE クラスタのアーキテクチャ。

このアーキテクチャには次のコンポーネントが含まれています。

  • GKE Standard リージョン クラスタ。
  • JetStream デプロイと Pathways コンポーネントをホストするマルチホスト TPU スライス ノードプール。
  • アクセラレータ リソースを管理し、ユーザージョブのアクセラレータの割り当てを調整する Pathways resource manager
  • Pathways resource manager と連携してコンパイルされたプログラムを実行する場所を決定する Pathways client
  • アクセラレータ マシンで実行されて計算を行い、IFRT プロキシ サーバーを介してワークロードにデータを送り返す Pathways worker
  • OSS の Interim Framework Runtime(IFRT) API を実装し、ワークロードと Pathways コンポーネント間の通信ブリッジとして機能する IFRT proxy client
  • IFRT proxy client からリクエストを受け取り、Pathways client に転送して作業を分散する IFRT proxy server
  • 推論リクエストを受け取り、実行プロセスを Pathways workers に委任する JAX ベースの推論サーバーを提供する JetStream-Pathways コンテナ。
  • Service コンポーネントは、インバウンド トラフィックをすべての JetStream HTTP レプリカに分散します。
  • JetStream HTTP は、JetStream の必須フォーマットのラッパーとしてリクエストを受け取り、JetStream の GRPC クライアントに送信する HTTP サーバーです。

始める前に

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin, roles/resourcemanager.projectIamAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      IAM に移動
    2. プロジェクトを選択します。
    3. [ アクセスを許可] をクリックします。
    4. [新しいプリンシパル] フィールドに、ユーザー ID を入力します。 これは通常、Google アカウントのメールアドレスです。

    5. [ロールを選択] リストでロールを選択します。
    6. 追加のロールを付与するには、 [別のロールを追加] をクリックして各ロールを追加します。
    7. [保存] をクリックします。
      • 16 個の TPU v6e PodSlice Lite チップに十分な割り当てがあることを確認します。このチュートリアルでは、オンデマンド インスタンスを使用します。
      • Google Cloud プロジェクトが Pathways の許可リストに登録されていることを確認します。

      モデルへのアクセス権を取得する

      GKE へのデプロイのために Meta Llama 3.1-405B チェックポイントにアクセスする手順は次のとおりです。

      1. ライセンス同意契約に署名する
      2. Meta Llama のダウンロード ページにアクセスします。
      3. モデルの利用規約を確認して同意し、モデルのダウンロードに必要な URL を取得します。
      4. モデルのチェックポイントをダウンロードするために、適切なモデルのモデル ID を見つます。サポートされているモデルとその ID の一覧については、llama CLI のドキュメントをご覧ください。たとえば、Llama 3.1-405B モデルの場合は Llama 3.1-405B-Instruct:bf16-mp16 を使用します。

      環境を準備する

      このチュートリアルでは、Cloud Shell を使用してGoogle Cloudでホストされているリソースを管理します。Cloud Shell には、このチュートリアルに必要な kubectlgcloud CLI などのソフトウェアがプリインストールされています。

      Cloud Shell を使用して環境を設定するには、次の操作を行います。

      1. Google Cloud コンソールで Cloud Shell 有効化アイコンCloud Shell をアクティブにする)をクリックして、Google Cloud コンソールで Cloud Shell セッションを起動します。これにより、 Google Cloud コンソールの下部ペインでセッションが起動します。

      2. デフォルトの環境変数を設定します。

        gcloud config set project PROJECT_ID
        gcloud config set billing/quota_project PROJECT_ID
        export PROJECT_ID=$(gcloud config get project)
        export CLUSTER_NAME=CLUSTER_NAME
        export BUCKET_NAME=BUCKET_NAME
        export CONTROL_PLANE_LOCATION=CONTROL_PLANE_LOCATION
        export NODE_LOCATION=NODE_LOCATION
        export CLUSTER_VERSION=CLUSTER_VERSION
        export MACHINE_TYPE=ct6e-standard-4t
        export TPU_TYPE=v6e
        export TOPOLOGY=4x4
        export WORKERS_PER_SLICE=4
        

        次の値を置き換えます。

        • PROJECT_ID: 実際の Google Cloud プロジェクト ID
        • CLUSTER_NAME: GKE クラスタの名前。
        • BUCKET_NAME: Cloud Storage バケットの名前。gs:// プレフィックスを指定する必要はありません。
        • CONTROL_PLANE_LOCATION: GKE クラスタ、Cloud Storage バケット、TPU ノードがある Compute Engine リージョン。TPU v6e マシンタイプを使用できるゾーン(us-east1us-east5europe-west4asia-northeast1us-south1 など)が含まれているリージョンです。
        • NODE_LOCATION: TPU リソースが使用可能なゾーン(例: us-east1-d)。
        • CLUSTER_VERSION: GKE バージョン。使用するマシンタイプをサポートしている必要があります。デフォルトの GKE バージョンは、ターゲット TPU で利用できない場合があります。TPU マシンタイプで使用できる最小 GKE バージョンのリストについては、GKE での TPU の可用性をご覧ください。
        • MACHINE_TYPE: v6e マシンタイプ。
        • TPU_TYPE: ノードプール(v6e)の命名に使用される接頭辞。
        • TOPOLOGY: TPU v6e トポロジ。
        • WORKERS_PER_SLICE: ノードプールまたは TPU スライスあたりのノード数。

      Google Cloud リソースを作成して構成する

      必要なリソースを作成する手順は次のとおりです。

      GKE クラスタを作成する

      1. リージョン GKE Standard クラスタを作成し、

        gcloud container clusters create CLUSTER_NAME \
            --project=PROJECT_ID \
            --cluster-version=CLUSTER_VERSION \
            --location=CONTROL_PLANE_LOCATION \
            --scopes=cloud-platform \
            --machine-type=n2-standard-32
        

        クラスタの作成には数分かかることもあります。

        CLUSTER_VERSION は、適切なクラスタ バージョンに置き換えます。

      2. 4x4 トポロジと 4 つのノードを持つ TPU v6e ノードプールを 1 つ作成します。

        gcloud container node-pools create multihost-np \
        --project=PROJECT_ID \
        --location=CONTROL_PLANE_LOCATION \
        --node-locations=NODE_LOCATION \
        --cluster=CLUSTER_NAME \
        --machine-type=MACHINE_TYPE \
        --num-nodes=WORKERS_PER_SLICE \
        --tpu-topology=TOPOLOGY \
        --scopes cloud-platform \
        --placement-type=COMPACT \
        --workload-metadata=GCE_METADATA
        

      Storage オブジェクト アクセス用のサービス アカウントを構成する

      IAM サービス アカウントとして機能するように Kubernetes サービス アカウントを構成します。

      1. アプリケーションの IAM サービス アカウントを作成します。

        gcloud iam service-accounts create jetstream-pathways
        
      2. Cloud Storage を管理する IAM サービス アカウントの IAM ポリシー バインディングを追加します。これは、IAM サービス アカウントがチェックポイントの保存先となるストレージ バケットにアクセスできるようにするためです。

        gcloud projects add-iam-policy-binding ${PROJECT} \
          --member "serviceAccount:jetstream-pathways@${PROJECT}.iam.gserviceaccount.com" \
          --role roles/storage.objectUser
        
        gcloud projects add-iam-policy-binding ${PROJECT} \
          --member "serviceAccount:jetstream-pathways@${PROJECT}.iam.gserviceaccount.com" \
          --role roles/storage.insightsCollectorService
        
      3. Kubernetes サービス アカウントに IAM サービス アカウントのメールアドレスでアノテーションを付けます。

        kubectl annotate serviceaccount default \
        iam.gke.io/gcp-service-account=jetstream-pathways@${PROJECT}.iam.gserviceaccount.com
        

      Artifact Registry で認証するよう Docker を構成する

      許可リストに登録された Pathways イメージを pull できるように、Artifact Registry に対する認証を行うよう Docker を構成します

      gcloud auth login
      gcloud auth configure-docker
      

      チェックポイントの変換

      Meta Llama 3.1-405B チェックポイントを MaxText 互換の int8 推論チェックポイントに変換するには、Llama3.1-405B を使用したチェックポイント変換の手順で操作します。デプロイでは、load_parameters_path フラグを使用してチェックポイントが読み込まれます。

      Pathways の一時ファイルを保存する Cloud Storage バケットを作成する

      コンパイル キャッシュなどの Pathways の一時ファイルを保存する Cloud Storage バケットを作成します。

      export PATHWAYS_BUCKET=PATHWAYS_BUCKET
      gcloud storage buckets create gs://$PATHWAYS_BUCKET
      

      JetStream-MaxText と Pathways をデプロイする

      JetStream-MaxText と Pathways のモデルサーバーをデプロイします。

      GKE クラスタに接続する

      gcloud container clusters get-credentials "${CLUSTER}" --project "${PROJECT}" --location "${ZONE}"
      

      LeaderWorkerSet(LWS)API をデプロイする

      LWS は、ステートフルな分散アプリケーション(特にリーダー / ワーカー アーキテクチャを持つアプリケーション)のデプロイと管理用に設計されたカスタム リソースです。特に、大規模なモデルがシャーディングされ、複数のノード上の複数のデバイスで提供される AI/ML ワークロードに適しています。

      VERSION=v0.6.1
      kubectl apply --server-side -f https://github.com/kubernetes-sigs/lws/releases/download/$VERSION/manifests.yaml
      

      LeaderWorkerSet コントローラが完全に利用可能になるまで待ちます。

      kubectl wait deploy/lws-controller-manager -n lws-system --for=condition=available --timeout=5m
      

      出力例を以下に示します。

      deployment.apps/lws-controller-manager condition met
      

      LeaderWorkerSet コントローラが lws-system Namespace で実行されていることを確認します。

      kubectl get pod -n lws-system
      

      出力例を以下に示します。

      NAME                          READY   STATUS    RESTARTS    AGE
      lws-controller-manager-abcd   1/1     Running   0           40s
      lws-controller-manager-efgh   1/1     Running   0           40s
      

      ワークロード マニフェストをデプロイする

      1. 次のマニフェストを jetstream-pathways-llama-3-1-405b-4x4.yaml として保存します。

        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels:
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.5.3
                  args:
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=PATHWAYS_BUCKET
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:4x4
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  env:
                  - name: LOG_LEVEL
                    value: "INFO"
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=6
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 600
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=PATHWAYS_BUCKET
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
      2. load_parameters_path フィールドの値を、チェックポイント変換プロセスで生成されたチェックポイント パスに設定します。

        • bf16 チェックポイントの場合、パスは gs://OUTPUT_BUCKET_DIRECTORY/bf16/unscanned/checkpoints/0/items のようになります。
        • int8 チェックポイントの場合、gs://OUTPUT_BUCKET_DIRECTORY/int8 のようになります。

        gcs_scratch_location フィールドの値を、作成済みの Pathways バケットに設定します。

        perl -pi -e 's|CHECKPOINT_PATH|gs://OUTPUT_BUCKET_DIRECTORY/int8|g' jetstream-pathways-llama-3-1-405b-4x4.yaml
        perl -pi -e 's|PATHWAYS_BUCKET|gs://PATHWAYS_BUCKET|g' jetstream-pathways-llama-3-1-405b-4x4.yaml
        

      Deployment マニフェストを適用する

      マニフェストを適用してサーバーをデプロイします。

      kubectl apply -f jetstream-pathways-llama-3-1-405b-4x4.yaml
      

      モデルサーバーが起動します。

      モデルサーバーの起動を確認する

      405B モデルの場合、チェックポイントの復元に 10~20 分ほどかかることがあります。enable_model_warmup フラグを有効にしている場合は、モデルのウォームアップにさらに時間がかかることがあります。

      kubectl logs -f jetstream-pathways-0 -c jax-tpu
      

      出力は次のようになります。

      2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
      2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
      2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
      2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
      2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
      2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
      ...
      ...
      ...
      INFO:     Started server process [7]
      INFO:     Waiting for application startup.
      INFO:     Application startup complete.
      INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

      Llama 3.1-405b をサービングする

      Llama 3.1-405b モデルをサービングするには、ポート転送を設定します。

      kubectl port-forward svc/jetstream-svc 8000:8000
      

      ポート転送を使用すると、クラスタの外部から Service にアクセスできます。JetStream-Pathways Deployment には、GKE の ClusterIP Service を介してアクセスできます。ClusterIP Service にはクラスタ内からのみアクセスできます。

      モデルとのやりとりを行う

      新しいターミナルで次のコマンドを実行します。

      curl --request POST \
      --header "Content-type: application/json" \
      -s \
      localhost:8000/generate \
      --data \
      '{
          "prompt": "What are the top 5 programming languages",
          "max_tokens": 200
      }'
      

      モデルのウォームアップにより、最初のリクエストが完了するまでに数秒かかることがあります。出力例を以下に示します。

      {
          "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
      }
      

      次の操作が終わりました。

      1. TPU を使用して MaxText と Pathways を活用する JetStream モデルサーバーを GKE にデプロイしました。
      2. gs://BUCKET_NAME に Llama 3.1-405B int8 チェックポイントを作成しました。
      3. モデルをサービングして操作しました。

      分離型サービング

      分離型サービングは、プレフィル ステージとデコード ステージを異なるホストに分割しながら LLM をサービングする手法です。このアプローチにより、リソース使用率が最適化され、スループットとレイテンシが改善されます。

      • プリフィル: 入力プロンプトのフォワードパスで、Key-Value キャッシュを初期化します。

      • デコード: 出力トークンを段階的に生成する手順。1 ステップあたり 1 つのトークン、1 回の反復あたり 1 つの KV キャッシュ値。

      1. デフォルトの環境変数を設定します。

        export NODE_POOL_NAME=dis-v6e-8
        export NODE_POOL_SIZE=2
        export MACHINE_TYPE=ct6e-standard-4t
        export TOPOLOGY=2x4
        export WORKERS_PER_SLICE=2
        
      2. v6e-8 ノードを使用するノードプールを 2 つ作成します。

        for i in $(seq 1 NODE_POOL_SIZE); do
          gcloud container node-pools create NODE_POOL_NAME-${i}-np \
          --project=PROJECT \
          --cluster=CLUSTER_NAME \
          --location=CONTROL_PLANE_LOCATION \
          --node-locations=NODE_LOCATION \
          --machine-type=MACHINE_TYPE \
          --num-nodes=WORKERS_PER_SLICE \
          --tpu-topology=TOPOLOGY \
          --scopes=cloud-platform \
          --workload-metadata=GCE_METADATA
        done
        

      チェックポイントの変換

      Meta Llama 2-70B チェックポイントを MaxText 互換の int8 推論チェックポイントに変換するには、Llama2-70B を使用したチェックポイント変換の手順で操作します。Meta の利用規約に同意する際に、モデルとして Llama2-70B を選択します。デプロイでは、load_parameters_path フラグを使用してチェックポイントが読み込まれます。

      checkpoint-job.yaml ファイルで次のパラメータを置き換えます。

      - --meta_url=META_URL
      - --model_name=llama-2
      - --model_path=Llama-2-70b-chat
      - --output_directory=gs://BUCKET_NAME/maxtext/llama-2-70b
      

      チェックポイントは、load_parameters_path フラグとともにデプロイで使用されます。

      分離されたサービングで JetStream Pathways をデプロイする

      1. 次のマニフェストを jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml として保存します。

        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            subGroupPolicy:
              subGroupSize: 2
            leaderTemplate:
              metadata:
                labels:
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 2x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.5.3
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --xla_jf_auto_cross_replica_sharding=false
                  - --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false
                  - --xla_tpu_enable_windowed_einsum_for_all_gather=false
                  - --xla_tpu_prefer_latch_optimized_rhs_layouts=true
                  - --xla_tpu_enable_experimental_fusion_cost_model=false
                  - --xla_tpu_dot_dot_fusion_duplicated=false
                  - --xla_tpu_dot_dot_fusion=true
                  - --xla_jf_conv_input_fusion=true
                  - --xla_jf_conv_output_fusion=true
                  - --xla_tpu_rwb_fusion=false
                  - --xla_tpu_copy_fusion_pad_unpad_ratio=0
                  - --xla_tpu_licm_size_inflation_ratio=1
                  - --xla_tpu_copy_elision_analysis_allowance=150000
                  - --xla_tpu_copy_insertion_use_region_analysis_limit=10000
                  - --xla_tpu_order_dot_after_layout=true
                  - --xla_jf_rematerialization_percent_shared_memory_limit=100
                  - --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true
                  - --xla_tpu_enforce_prefetch_fifo_order=false
                  - --xla_tpu_prefetch_interval_picker_size_override=6000000
                  - --xla_tpu_async_copy_bandwidth_scaling_factor=1
                  - --xla_tpu_nd_short_transfer_max_chunks=-1
                  - --xla_tpu_enable_aggressive_broadcast_priority_update=true
                  - --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT
                  - --xla_tpu_memory_bound_loop_optimizer_options=enabled:true
                  - --xla_tpu_enable_copy_fusion=true
                  - --xla_tpu_enable_cross_program_prefetch_freeing=false
                  - --xla_tpu_enable_dot_strength_reduction=true
                  - --xla_tpu_layout_use_dot_grouping=false
                  - --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5
                  - --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false
                  - --xla_tpu_vector_load_fusion_window=1024
                  - --xla_tpu_vector_store_fusion_window=256
                  - --xla_jf_conv_reshape_fusion=false
                  - --xla_tpu_input_conv_multi_users=false
                  - --xla_tpu_enable_multi_level_input_dot_dot_fusion=false
                  - --xla_tpu_enable_multi_level_output_dot_dot_fusion=false
                  - --xla_tpu_dot_dot_fusion_separable_convs_only=false
                  - --xla_tpu_enable_multi_level_nested_loop_fusion=true
                  - --xla_tpu_nested_dot_fusion=true
                  - --xla_tpu_enable_multi_level_nested_dot_fusion=false
                  - --xla_jf_enable_multi_output_fusion=true
                  - --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false
                  - --xla_tpu_enable_flash_attention=true
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=PATHWAYS_BUCKET
                  - --node_type=resource_manager
                  - --instance_count=2
                  - --instance_type=tpuv6e:2x4
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  args:
                  - MaxText/configs/base.yml
                  - tokenizer_path=assets/tokenizer.llama2
                  - load_parameters_path=CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - model_name=llama2-70b
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=1
                  - ici_tensor_parallelism=-1
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=27
                  - checkpoint_is_quantized=true 
                  - quantization=int8
                  - quantize_kvcache=true
                  - compute_axis_order=0,2,1,3
                  - ar_cache_axis_order=0,2,1,3
                  - stack_prefill_result_cache=True
                  - inference_server=ExperimentalMaxtextDisaggregatedServer_8
                  - inference_benchmark_test=True
                  - enable_model_warmup=True
                  env:
                  - name: LOG_LEVEL
                    value: "INFO"
                  imagePullPolicy: Always
                  securityContext:
                    capabilities:
                      add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 240
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 100
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 100
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: 2x4
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=PATHWAYS_BUCKET
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
      2. load_parameters_path フィールドの値を、チェックポイント変換プロセスで生成されたチェックポイント パスに設定します。

        • bf16 チェックポイントの場合、パスは gs://OUTPUT_BUCKET_DIRECTORY/bf16/unscanned/checkpoints/0/items のようになります。
        • int8 チェックポイントの場合、gs://OUTPUT_BUCKET_DIRECTORY/int8 のようになります。

        gcs_scratch_location フィールドの値を、作成済みの Pathways バケットに設定します。

        perl -pi -e 's|CHECKPOINT_PATH|BUCKET_NAME/maxtext/llama-2-70b/int8|g' jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
        perl -pi -e 's|PATHWAYS_BUCKET|gs://PATHWAYS_BUCKET|g' jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
        
      3. 次のようにマニフェストを適用します。

        kubectl apply -f jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
        

        チェックポイントのサイズによっては、モデルサーバーがチェックポイントを復元するまでに時間がかかることがあります。70B モデルの場合、モデルのウォームアップを含めて、チェックポイントの復元に約 8 分かかることがあります。ログをさらに確認して、モデルサーバーの起動を確認し、モデルを操作できるようにポート転送を設定することで、準備ができた時点を特定できます。

      次の操作が終わりました。

      1. TPU と分離型サービングを使用して MaxText と Pathways を活用する JetStream モデルサーバーを GKE にデプロイしました。
      2. gs://BUCKET_NAME に Llama 2-70B int8 チェックポイントを作成しました。
      3. モデルをサービングして操作しました。

      問題のトラブルシューティング

      • Empty reply from server というメッセージが表示された場合は、コンテナがモデルデータのダウンロードを完了していない可能性があります。モデルのサービング準備ができていることを示す「Connected」というメッセージがないか、再度 Pod のログを確認します。
      • Connection refused メッセージが見つかった場合は、ポート転送が有効であることを確認します。

      クリーンアップする

      このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

      デプロイされたリソースを削除する

      このガイドで作成したリソースについて Google Cloud アカウントに課金されないようにするには、次のコマンドを実行し、プロンプトに従って操作します。

      gcloud container clusters delete CLUSTER_NAME --location=CONTROL_PLANE_LOCATION
      
      gcloud iam service-accounts delete jetstream-pathways@PROJECT_ID.iam.gserviceaccount.com
      
      gcloud storage rm --recursive gs://BUCKET_NAME
      

      次のステップ