このチュートリアルでは、MaxText、Ray Train、TPU を使用して Google Kubernetes Engine(GKE)で Llama 3 8B 大規模言語モデル(LLM)をトレーニングする方法について説明します。
このチュートリアルでは、必要なクラウド インフラストラクチャの構成から、マルチホスト TPU でトレーニング ワークロードを送信して正常に実行するまで、エンドツーエンドの完全なチュートリアルを提供します。
このチュートリアルは、分散マルチホスト TPU スライスで大規模なモデルをトレーニングする方法を学習するプラットフォーム管理者とオペレーター、データおよび AI スペシャリストを対象としています。
背景
GKE、KubeRay、MaxText、TPU を組み合わせることで、大規模なモデル トレーニングのための強力でスケーラブルなプラットフォームが実現します。このセクションでは、このガイドで使用されている重要なテクノロジーについて説明します。
JAX
JAX は、アクセラレータ指向の配列計算とプログラム変換のための Python ライブラリで、高パフォーマンスの数値計算と大規模な ML 用に設計されています。
JAX は、jax.grad、jax.jit、jax.vmap などの数値関数を変換するための拡張可能なシステムを提供します。XLA コンパイラを利用して、GPU や TPU などのアクセラレータで効率的にスケーリングする高度に最適化されたコードを作成します。JAX の主な強みはコンポーザビリティです。ユーザーはこれらの変換を組み合わせて、分散実行用の複雑で高パフォーマンスの数値プログラムを構築できます。
MaxText
MaxText は、拡張性とカスタマイズ性を重視して設計された、高パフォーマンスのオープンソース大規模言語モデル(LLM)です。MaxText は JAX 上に構築されており、Cloud TPU と GPU で効率的に実行できるように最適化されています。
TPU
Tensor Processing Unit(TPU)は、機械学習ワークロードを最適化するために Google が作成したカスタム設計のアクセラレータです。汎用 CPU や並列処理 GPU とは異なり、TPU はディープ ラーニングの基盤となる大規模な行列とテンソルの計算に特化しているため、この特定のタスクを効率的に実行できます。TPU の主な利点は、パフォーマンス拡張です。
このチュートリアルでは、第 6 世代 TPU である TPU Trillium を使用します。詳細については、TPU Trillium を使用するメリットをご覧ください。
KubeRay
KubeRay は、Kubernetes で Ray アプリケーションをデプロイ、管理、モニタリングするための統一された方法を提供する Kubernetes オペレーターです。KubeRay オペレーターは、Ray on GKE アドオンを介してインストールおよび管理されます。これは、GKE 上の Ray クラスタをデプロイして管理するおすすめの方法です。
目標
このチュートリアルでは、次の方法を説明します。
- マルチホスト TPU ノードプールを使用して GKE クラスタを設定します。
- 分散トレーニング環境を管理するように KubeRay を構成します。
- MaxText、Ray、JAX の依存関係を含むカスタム Docker イメージをビルドします。
- Ray Train の
JaxTrainerを使用して、TPU スライス全体で MaxText トレーニング ループをオーケストレートする Python トレーニング スクリプトを作成します。 - 必要な TPU リソースを使用してヘッドノードとワーカーノードをプロビジョニングする
RayClusterカスタム リソースを定義します。 - トレーニング ジョブを
RayClusterに送信し、進行状況をモニタリングします。 - Cloud Storage を使用して、モデルのチェックポイントを保存します。
始める前に
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
Install the Google Cloud CLI.
-
外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。
-
gcloud CLI を初期化するには、次のコマンドを実行します。
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Install the Google Cloud CLI.
-
外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。
-
gcloud CLI を初期化するには、次のコマンドを実行します。
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Grant roles to your user account. Run the following command once for each of the following IAM roles:
roles/container.admin, roles/iam.serviceAccountAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
Replace the following:
PROJECT_ID: Your project ID.USER_IDENTIFIER: The identifier for your user account. For example,myemail@example.com.ROLE: The IAM role that you grant to your user account.
- このチュートリアルでは TPU Trillium(v6e)を利用するため、利用可能なリージョンまたはゾーンを選択します。詳細については、Cloud TPU の割り当てをご覧ください。
環境を準備する
このチュートリアルでは、Cloud Shell を使用します。Cloud Shell には、このチュートリアルで使用する gcloud、helm、kubectl コマンドライン ツールがプリインストールされています。
Google Cloud コンソールに移動します。
Google Cloud コンソール ウィンドウの上部にある [Cloud Shell をアクティブにする]
ボタンをクリックします。Google Cloud コンソールの新しいフレーム内で Cloud Shell セッションが開き、コマンドライン プロンプトが表示されます。
Python 仮想環境を作成してアクティブにします。
python3 -m venv ray-env source ray-env/bin/activateRay CLI とその他の依存関係をインストールします。
pip install "ray[default]==2.49.1"次の環境変数を設定します。
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY次のように置き換えます。
GS_BUCKET: Cloud Storage バケットの名前。KSA_NAME: Kubernetes サービス アカウントの名前。CLUSTER_NAME: 新しいクラスタの名前。REGION: TPU Trillium の容量が使用可能なリージョン。ZONE: TPU Trillium の容量が使用可能なゾーン。詳細については、GKE での TPU の可用性をご覧ください。ARTIFACT_REGISTRY: Artifact Registry リポジトリの名前。
GKE クラスタを作成する
GKE Autopilot クラスタまたは GKE Standard クラスタの TPU で KubeRay を構成できます。フルマネージドの Kubernetes エクスペリエンスを実現するには、Autopilot クラスタを使用することをおすすめします。ワークロードに最適な GKE の運用モードを選択するには、GKE の運用モードについてをご覧ください。
Autopilot
Cloud Shell で、次のコマンドを実行します。
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGIONクラスタと通信するには、
kubectlを構成します。gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONE
Standard
Cloud Shell で、次のコマンドを実行して、Ray オペレータ アドオンを有効にする Standard クラスタを作成します。
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator \ --addons GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONEこのコマンドは
GcsFuseCsiDriverも有効にします。これにより、Pod は Cloud Storage バケットをローカル ファイル システムとしてマウントできます。クラスタの作成には数分かかることもあります。クラスタと通信するには、
kubectlを構成します。gcloud container clusters get-credentials CLUSTER_NAME \ --location=LOCATIONマルチホスト TPU スライス ノードプールを作成します。
gcloud container node-pools create v6e-16 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4
GKE は、4 つの TPU Trillium(v6e)VM で構成されるノードプールをプロビジョニングします。これらは、4x4 トポロジを使用してマルチホスト TPU スライスとして構成され、分散トレーニング ワークロードに対応できます。
Ray オペレーターが有効になっている GKE クラスタは、クラスタに KubeRay と KubeRay TPU Webhook を自動的にインストールします。
Cloud Storage バケットとサービス アカウントを構成する
マルチホスト TPU ノード間で共有されるチェックポイント用の Cloud Storage バケットを作成します。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}Cloud Storage バケットへのアクセスを有効にするには、Kubernetes サービス アカウントを作成します。
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}Cloud Storage バケットへのアクセスを有効にするには、必要な IAM ポリシー バインディングをサービス アカウントに追加します。
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
トレーニング スクリプトを作成する
次のスクリプトは、Ray Train の JaxTrainer を使用して分散 MaxText トレーニング ジョブを実行します。このスクリプトは、マルチホスト TPU スライス ノードプールのトレーニング環境を構成し、各ワーカーノードで MaxText トレーニング ジョブを実行します。train_loop_per_worker 関数は MaxText のメイン エントリ ポイントをラップし、Ray の分散スケジューラを使用してマルチホスト TPU スライスで MaxText トレーナーを実行します。
次の Python スクリプトを
maxtext_ray_trainer.pyとして保存します。カスタム イメージをホストするには、Artifact Registry リポジトリを作成します。
gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \ --repository-format=docker --location=${REGION} && \ gcloud auth configure-docker ${REGION}-docker.pkg.devトレーニング用の Ray と MaxText の依存関係を含むイメージをビルドするには、
Dockerfileを作成します。Docker イメージをビルドしてタグ付けし、Artifact Registry に push します。
export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest gcloud builds submit --tag ${DOCKER_IMAGE}
モデルのトレーニング
次のサンプル マニフェストを
maxtext-tpu-cluster.yamlとして保存します。上記の RayCluster 仕様では、レプリカごとに 4 つのワーカー(
numOfHosts: 4)を含む TPU ワーカー グループを作成します。各ワーカーは 4 つの TPU チップ(google.com/tpu: "4")をリクエストします。ワーカーは、TPU Trillium(tpu-v6e-slice)を実行し、同じコロケーションされたマルチホスト スライスの一部であるノードでスケジュールされます。KubeRay は 4 つのワーカーすべてをアトミックにスケーリングします。必要な JAX 環境変数とスケジューリング用の Pod アフィニティは、変更用 Webhook を介して GKE によってブートストラップされます。YAML ファイルで必要な値を構成するには、
envsubstを使用して RayCluster を作成します。envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -クラスタが使用できるようになり、実行中であることを確認します。
kubectl get rayclusters maxtext-tpu-cluster出力例を以下に示します。
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 4 4 40 798027216Ki 0 ready 11mRay ヘッドサービスを介して Ray ダッシュボードにアクセスするには、ポート転送セッションを確立します。
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &ローカル環境から RayCluster にアクセスできることを確認します。
ray list nodes --address http://localhost:8265出力例を以下に示します。
======== List: 2025-09-13 03:53:16.988269 ======== Stats: ------------------------------ Total: 5 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56 10.84.0.9 (...)JaxTrainer スクリプトを RayCluster に送信し、RayJob が正常に完了することを確認します。
ray job submit \ --address http://localhost:8265 \ -- python /app/maxtext_ray_trainer.py \ /app/maxtext/src/MaxText/configs/base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=1 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-8b-4096-tp4-4x4上記のコマンドは、JaxTrainer Ray コードを呼び出す Python スクリプトを RayCluster に送信します。
ray job submitコマンドには、モデル構成に渡す MaxText に固有の引数が含まれています。ターミナルに、次のような出力が表示されます。
(RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster] ------------------------------------------ Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded ------------------------------------------
クリーンアップ
このチュートリアルで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。
RayCluster を削除します。
kubectl delete raycluster maxtext-tpu-clusterGKE クラスタを削除します。
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONECloud Storage バケットを削除します。
gsutil rm -r gs://${GS_BUCKET}Artifact Registry リポジトリを削除します。
gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
次のステップ
- Ray on Kubernetes について学習する。
- GKE で TPU を使用して vLLM をサービングする方法を確認する。
- GKE で TPU を使用して SDXL をサービングする方法を確認する。
- GKE の TPU の詳細を確認する。