GKE で JAX、Ray Train、TPU Trillium を使用して LLM をトレーニングする

このチュートリアルでは、MaxTextRay Train、TPU を使用して Google Kubernetes Engine(GKE)で Llama 3 8B 大規模言語モデル(LLM)をトレーニングする方法について説明します。

このチュートリアルでは、必要なクラウド インフラストラクチャの構成から、マルチホスト TPU でトレーニング ワークロードを送信して正常に実行するまで、エンドツーエンドの完全なチュートリアルを提供します。

このチュートリアルは、分散マルチホスト TPU スライスで大規模なモデルをトレーニングする方法を学習するプラットフォーム管理者とオペレーター、データおよび AI スペシャリストを対象としています。

背景

GKE、KubeRay、MaxText、TPU を組み合わせることで、大規模なモデル トレーニングのための強力でスケーラブルなプラットフォームが実現します。このセクションでは、このガイドで使用されている重要なテクノロジーについて説明します。

JAX

JAX は、アクセラレータ指向の配列計算とプログラム変換のための Python ライブラリで、高パフォーマンスの数値計算と大規模な ML 用に設計されています。

JAX は、jax.gradjax.jitjax.vmap などの数値関数を変換するための拡張可能なシステムを提供します。XLA コンパイラを利用して、GPU や TPU などのアクセラレータで効率的にスケーリングする高度に最適化されたコードを作成します。JAX の主な強みはコンポーザビリティです。ユーザーはこれらの変換を組み合わせて、分散実行用の複雑で高パフォーマンスの数値プログラムを構築できます。

MaxText

MaxText は、拡張性とカスタマイズ性を重視して設計された、高パフォーマンスのオープンソース大規模言語モデル(LLM)です。MaxText は JAX 上に構築されており、Cloud TPU と GPU で効率的に実行できるように最適化されています。

TPU

Tensor Processing Unit(TPU)は、機械学習ワークロードを最適化するために Google が作成したカスタム設計のアクセラレータです。汎用 CPU や並列処理 GPU とは異なり、TPU はディープ ラーニングの基盤となる大規模な行列とテンソルの計算に特化しているため、この特定のタスクを効率的に実行できます。TPU の主な利点は、パフォーマンス拡張です。

このチュートリアルでは、第 6 世代 TPU である TPU Trillium を使用します。詳細については、TPU Trillium を使用するメリットをご覧ください。

KubeRay

KubeRay は、Kubernetes で Ray アプリケーションをデプロイ、管理、モニタリングするための統一された方法を提供する Kubernetes オペレーターです。KubeRay オペレーターは、Ray on GKE アドオンを介してインストールおよび管理されます。これは、GKE 上の Ray クラスタをデプロイして管理するおすすめの方法です。

目標

このチュートリアルでは、次の方法を説明します。

  1. マルチホスト TPU ノードプールを使用して GKE クラスタを設定します。
  2. 分散トレーニング環境を管理するように KubeRay を構成します。
  3. MaxText、Ray、JAX の依存関係を含むカスタム Docker イメージをビルドします。
  4. Ray Train の JaxTrainer を使用して、TPU スライス全体で MaxText トレーニング ループをオーケストレートする Python トレーニング スクリプトを作成します。
  5. 必要な TPU リソースを使用してヘッドノードとワーカーノードをプロビジョニングする RayCluster カスタム リソースを定義します。
  6. トレーニング ジョブを RayCluster に送信し、進行状況をモニタリングします。
  7. Cloud Storage を使用して、モデルのチェックポイントを保存します。

始める前に

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  • gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  • gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • このチュートリアルでは TPU Trillium(v6e)を利用するため、利用可能なリージョンまたはゾーンを選択します。詳細については、Cloud TPU の割り当てをご覧ください。

環境を準備する

このチュートリアルでは、Cloud Shell を使用します。Cloud Shell には、このチュートリアルで使用する gcloudhelmkubectl コマンドライン ツールがプリインストールされています。

  1. Google Cloud コンソールに移動します。

  2. Google Cloud コンソール ウィンドウの上部にある [Cloud Shell をアクティブにする] Shell をアクティブにするボタン ボタンをクリックします。

    Google Cloud コンソールの新しいフレーム内で Cloud Shell セッションが開き、コマンドライン プロンプトが表示されます。

  3. Python 仮想環境を作成してアクティブにします。

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Ray CLI とその他の依存関係をインストールします。

    pip install "ray[default]==2.49.1"
    
  5. 次の環境変数を設定します。

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    次のように置き換えます。

    • GS_BUCKET: Cloud Storage バケットの名前。
    • KSA_NAME: Kubernetes サービス アカウントの名前。
    • CLUSTER_NAME: 新しいクラスタの名前。
    • REGION: TPU Trillium の容量が使用可能なリージョン。
    • ZONE: TPU Trillium の容量が使用可能なゾーン。詳細については、GKE での TPU の可用性をご覧ください。
    • ARTIFACT_REGISTRY: Artifact Registry リポジトリの名前。

GKE クラスタを作成する

GKE Autopilot クラスタまたは GKE Standard クラスタの TPU で KubeRay を構成できます。フルマネージドの Kubernetes エクスペリエンスを実現するには、Autopilot クラスタを使用することをおすすめします。ワークロードに最適な GKE の運用モードを選択するには、GKE の運用モードについてをご覧ください。

Autopilot

  1. Cloud Shell で、次のコマンドを実行します。

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. クラスタと通信するには、kubectl を構成します。

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standard

  1. Cloud Shell で、次のコマンドを実行して、Ray オペレータ アドオンを有効にする Standard クラスタを作成します。

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    このコマンドは GcsFuseCsiDriver も有効にします。これにより、Pod は Cloud Storage バケットをローカル ファイル システムとしてマウントできます。クラスタの作成には数分かかることもあります。

  2. クラスタと通信するには、kubectl を構成します。

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. マルチホスト TPU スライス ノードプールを作成します。

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE は、4 つの TPU Trillium(v6e)VM で構成されるノードプールをプロビジョニングします。これらは、4x4 トポロジを使用してマルチホスト TPU スライスとして構成され、分散トレーニング ワークロードに対応できます。

Ray オペレーターが有効になっている GKE クラスタは、クラスタに KubeRay と KubeRay TPU Webhook を自動的にインストールします。

Cloud Storage バケットとサービス アカウントを構成する

  1. マルチホスト TPU ノード間で共有されるチェックポイント用の Cloud Storage バケットを作成します。

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Cloud Storage バケットへのアクセスを有効にするには、Kubernetes サービス アカウントを作成します。

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Cloud Storage バケットへのアクセスを有効にするには、必要な IAM ポリシー バインディングをサービス アカウントに追加します。

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

トレーニング スクリプトを作成する

次のスクリプトは、Ray Train の JaxTrainer を使用して分散 MaxText トレーニング ジョブを実行します。このスクリプトは、マルチホスト TPU スライス ノードプールのトレーニング環境を構成し、各ワーカーノードで MaxText トレーニング ジョブを実行します。train_loop_per_worker 関数は MaxText のメイン エントリ ポイントをラップし、Ray の分散スケジューラを使用してマルチホスト TPU スライスで MaxText トレーナーを実行します。

  1. 次の Python スクリプトを maxtext_ray_trainer.py として保存します。

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. カスタム イメージをホストするには、Artifact Registry リポジトリを作成します。

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. トレーニング用の Ray と MaxText の依存関係を含むイメージをビルドするには、Dockerfile を作成します。

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Docker イメージをビルドしてタグ付けし、Artifact Registry に push します。

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

モデルのトレーニング

  1. 次のサンプル マニフェストを maxtext-tpu-cluster.yaml として保存します。

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    上記の RayCluster 仕様では、レプリカごとに 4 つのワーカー(numOfHosts: 4)を含む TPU ワーカー グループを作成します。各ワーカーは 4 つの TPU チップ(google.com/tpu: "4")をリクエストします。ワーカーは、TPU Trillium(tpu-v6e-slice)を実行し、同じコロケーションされたマルチホスト スライスの一部であるノードでスケジュールされます。KubeRay は 4 つのワーカーすべてをアトミックにスケーリングします。必要な JAX 環境変数とスケジューリング用の Pod アフィニティは、変更用 Webhook を介して GKE によってブートストラップされます。

  2. YAML ファイルで必要な値を構成するには、envsubst を使用して RayCluster を作成します。

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. クラスタが使用できるようになり、実行中であることを確認します。

    kubectl get rayclusters maxtext-tpu-cluster
    

    出力例を以下に示します。

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Ray ヘッドサービスを介して Ray ダッシュボードにアクセスするには、ポート転送セッションを確立します。

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. ローカル環境から RayCluster にアクセスできることを確認します。

    ray list nodes --address http://localhost:8265
    

    出力例を以下に示します。

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. JaxTrainer スクリプトを RayCluster に送信し、RayJob が正常に完了することを確認します。

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    上記のコマンドは、JaxTrainer Ray コードを呼び出す Python スクリプトを RayCluster に送信します。ray job submit コマンドには、モデル構成に渡す MaxText に固有の引数が含まれています。

    ターミナルに、次のような出力が表示されます。

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

クリーンアップ

このチュートリアルで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

  1. RayCluster を削除します。

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. GKE クラスタを削除します。

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Cloud Storage バケットを削除します。

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Artifact Registry リポジトリを削除します。

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

次のステップ