Trillium(v6e)の概要

v6e は、このドキュメント、TPU API、ログで Trillium を指すために使用されます。v6e は Google の第 6 世代の TPU を表します。

Pod あたり 256 個のチップを使用する v6e は、v5e と多くの類似点があります。このシステムは、トランスフォーマー、text-to-image、畳み込みニューラル ネットワーク(CNN)のトレーニング、微調整、サービス提供に適した最適なプロダクトとなるように最適化されています。

v6e システム アーキテクチャ

Cloud TPU の構成については、v6e のドキュメントをご覧ください。

このドキュメントでは、JAXPyTorchTensorFlow フレームワークを使用したモデル トレーニングの設定プロセスについて説明します。各フレームワークでは、キューに格納されたリソースまたは Google Kubernetes Engine(GKE)を使用して TPU をプロビジョニングできます。GKE の設定は、XPK コマンドまたは GKE コマンドを使用して行うことができます。

Google Cloud プロジェクトを準備する

  1. Google アカウントにログインします。Google アカウントをまだお持ちでない場合は、新しいアカウントを登録します。
  2. Google Cloud コンソールで、プロジェクト セレクタページから Cloud プロジェクトを選択するか作成します。
  3. Google Cloud プロジェクトに対する課金を有効にします。Google Cloud の使用にはすべて課金が必要です。
  4. gcloud alpha components をインストールします。
  5. 次のコマンドを実行して、gcloud コンポーネントの最新バージョンをインストールします。

    gcloud components update
    
  6. Cloud Shell で、次の gcloud コマンドを使用して TPU API を有効にします。Google Cloud コンソールから有効にすることもできます。

    gcloud services enable tpu.googleapis.com
    
  7. Compute Engine API の TPU サービス アカウントで権限を有効にする

    サービス アカウントにより、Cloud TPU サービスが他の Google Cloud サービスにアクセスできるようになります。ユーザー管理のサービス アカウントは、Google Cloud で推奨される方法です。ガイドに沿ってロールを作成し、付与します。次のロールが必要です。

    • TPU 管理者
    • ストレージ管理者
    • ログ書き込み
    • モニタリング指標の書き込み

    a. GKE のユーザー アカウント(XPK)で XPK 権限を設定します。

  8. プロジェクト ID とゾーンの環境変数を作成します。

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. TPU VM のサービス ID を作成します。

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

容量を確保する

TPU 割り当てをリクエストし、容量に関する質問に回答するには、Cloud TPU サポートの営業担当者またはアカウントにお問い合わせください。

Cloud TPU 環境をプロビジョニングする

v6e TPU は、GKE、GKE と XPK(GKE をラップする CLI ツール)、またはキューに登録されたリソースでプロビジョニングして管理できます。

前提条件

  • プロジェクトに十分な TPUS_PER_TPU_FAMILY 割り当てがあることを確認します。これは、Google Cloud プロジェクト内でアクセスできるチップの最大数を指定します。
  • v6e は、次の構成でテストされています。
    • python 3.10 以降
    • ナイトリー ソフトウェアのバージョン:
      • 夜間 JAX 0.4.32.dev20240912
      • ナイトリー LibTPU 0.1.dev20240912+nightly
    • 安定版ソフトウェアのバージョン:
      • JAX + v0.4.35 の JAX Lib
  • プロジェクトに次の TPU 割り当てがあることを確認します。
    • TPU VM の割り当て
    • IP アドレスの割り当て
    • Hyperdisk-balance 割り当て
  • ユーザー プロジェクトの権限

環境変数

Cloud Shell で、次の環境変数を作成します。

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

コマンドフラグの説明

変数 説明
NODE_ID キューに入れられたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 で新しいプロジェクトを作成します。
ゾーン サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM -> サービス アカウント で確認できるサービス アカウントのメールアドレスです。

例: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES 作成するスライス数(マルチスライスの場合のみ必要)
QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。
VALID_DURATION キューに入れられたリソース リクエストが有効である期間。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

ネットワーク パフォーマンスの最適化

最適なパフォーマンスを得るには、8,896 MTU(最大伝送単位)のネットワークを使用します。

デフォルトでは、Virtual Private Cloud(VPC)は 1,460 バイトの MTU のみを提供します。これにより、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300 ~ 8,896 バイトの任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。

既存またはデフォルトのネットワークの MTU 設定の変更の詳細については、VPC ネットワークの MTU 設定を変更するをご覧ください。

次の例では、8,896 MTU のネットワークを作成します。

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

マルチ NIC の使用(マルチスライス向けのオプション)

マルチスライス環境を使用している場合、セカンダリ サブネットには次の環境変数が必要です。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

マルチネットワーク スライスが作成されたら、XPK ワークロードの一部として --command ifconfig を実行して、両方の NIC が使用されていることを確認できます。次に、Cloud コンソールのログで、その XPK ワークロードの出力を確認します。eth0 と eth1 の両方に mtu=8896 があることを確認します。

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

eth0 と eth1 の両方に mtu=8,896 があることを確認します。マルチ NIC が実行されていることを確認するには、XPK ワークロードの一部としてコマンド --command "ifconfig" を実行します。次に、Cloud コンソール ログでその xpk ワークロードの出力を確認して、eth0 と eth1 の両方に mtu=8896 があることを確認します。

TCP 設定の改善

キューに格納されたリソース インターフェースを使用して作成された TPU の場合、次のコマンドを実行して rto_minquickack のデフォルトの TCP 設定を変更することで、ネットワーク パフォーマンスを改善できます。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

キューに入れられたリソースによるプロビジョニング(Cloud TPU API)

容量は、キューに入れられたリソースの create コマンドを使用してプロビジョニングできます。

  1. TPU のキューに入れられたリソース リクエストを作成します。

    --reserved フラグは、オンデマンド リソースではなく、予約済みリソースに対してのみ必要です。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    キューに格納されたリソース リクエストが正常に作成されると、[response] フィールド内の状態は「WAITING_FOR_RESOURCES」または「FAILED」のいずれかになります。キューに格納されたリソース リクエストの状態が「WAITING_FOR_RESOURCES」の場合、キューに格納されたリソースはキューに追加され、十分な TPU 容量があるときにプロビジョニングされます。キューに格納されたリソース リクエストが「FAILED」状態の場合、失敗の理由が出力に表示されます。指定した時間内に v6e がプロビジョニングされず、状態が「FAILED」になった場合、キューに格納されたリソース リクエストは期限切れになります。詳細については、キューに格納されたリソースの公開ドキュメントをご覧ください。

    キューに入れられたリソース リクエストが「ACTIVE」状態になると、SSH を使用して TPU VM に接続できます。list コマンドまたは describe コマンドを使用して、キューに格納されたリソースのステータスをクエリします。

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    キューに格納されたリソースが「ACTIVE」状態の場合、出力は次のようになります。

      state:
       state: ACTIVE
    
  2. TPU VM を管理します。TPU VM を管理するオプションについては、TPU VM の管理をご覧ください。

  3. SSH を使用して TPU VM に接続する

    TPU スライスの各 TPU VM にバイナリをインストールしてコードを実行できます。スライスに含まれる VM の数を決定するには、VM のタイプのセクションをご覧ください。

    バイナリをインストールするかコードを実行するには、SSH を使用して tpu-vm ssh コマンドを使用して VM に接続します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    SSH を使用して特定の VM に接続するには、0 ベースのインデックスに続けて --worker フラグを使用します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    スライス形状が 8 チップを超える場合、1 つのスライスに複数の VM があります。この場合は、gcloud alpha compute tpus tpu-vm ssh コマンドの --worker=all パラメータと --command パラメータを使用して、すべての VM でコマンドを同時に実行します。次に例を示します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. キューに格納されているリソースの削除

    セッションの終了時にキューに格納されているリソースを削除するか、「FAILED」状態のキューに入れられたリソース リクエストを削除します。キューに格納されたリソースを削除するには、次の 2 つの手順でスライスを削除してから、キュー内のリソースを削除します。

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

v6e で GKE を使用する

v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して TPU をプロビジョニングし、モデルをトレーニングまたは提供できます。TPU と v6e で GKE を使用する方法については、GKE で TPU を計画するをご覧ください。

フレームワークの設定

このセクションでは、JAXPyTorchTensorFlow フレームワークを使用した ML モデル トレーニングの一般的な設定プロセスについて説明します。キューに入れられたリソースまたは GKE を使用して TPU をプロビジョニングできます。GKE の設定は、XPK または Kubernetes コマンドを使用して行うことができます。

キューに入れられたリソースを使用して JAX を設定する

gcloud alpha compute tpus tpu-vm ssh を使用して、スライス内のすべての TPU VM に JAX を同時にインストールします。マルチスライスの場合は、--node=all を追加します。


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

次の Python コードを実行して、スライドで使用可能な TPU コアの数を確認し、すべてが正しくインストールされていることをテストできます(ここに表示されている出力は、v6e-16 スライスで生成されたものです)。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

JAX のセットアップに関するトラブルシューティング

一般的なヒントとして、GKE ワークロード マニフェストで詳細なロギングを有効にします。次に、ログを GKE サポートに提供します。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

エラー メッセージ

no endpoints available for service 'jobset-webhook-service'

このエラーは、ジョブセットが正しくインストールされていないことを意味します。jobset-controller-manager Deployment Kubernetes Pod が実行されていることを確認します。詳細については、JobSet のトラブルシューティングに関するドキュメントをご覧ください。

TPU initialization failed: Failed to connect

GKE ノード バージョンが 1.30.4-gke.1348000 以降であることを確認します(GKE 1.31 はサポートされていません)。

PyTorch を設定する

このセクションでは、PyTorch/XLA を使用して v6e で PJRT の使用を開始する方法について説明します。Python 3.10 が推奨される Python バージョンです。

XPK で GKE を使用して PyTorch を設定する

PyTorch の依存関係がすでにインストールされている XPK で、次の Docker コンテナを使用できます。

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

XPK ワークロードを作成するには、次のコマンドを使用します。

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

--base-docker-image を使用すると、現在の作業ディレクトリが新しい Docker にビルドされた新しい Docker イメージが作成されます。

キューに格納されたリソースを使用して PyTorch を設定する

キューに登録されたリソースを使用して PyTorch をインストールし、v6e で小さなスクリプトを実行する手順は次のとおりです。

SSH を使用して依存関係をインストールし、VM にアクセスします。

マルチスライスの場合は、--node=all を追加します。

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

サイズが大きく、頻繁な割り当てがあるモデルのパフォーマンスを改善する

サイズ設定に頻繁な割り当てがあるモデルの場合、tcmalloc を使用すると、デフォルトの malloc 実装よりもパフォーマンスが大幅に向上するため、TPU VM で使用されるデフォルトの malloctcmalloc です。しかし、ワークロードによっては(たとえば、埋め込みテーブルへの割り当てが非常に大きい DLRM の場合)、tcmalloc では速度が低下する可能性があります。その場合は、次の変数の設定を解除して、代わりにデフォルトの malloc を使用してください。

unset LD_PRELOAD

Python スクリプトを使用して v6e VM で計算を行います。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

これにより、次のような出力が生成されます。

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow の設定

v6e 公開プレビュー版では、tf-nightly ランタイム バージョンのみがサポートされています。

次のコマンドを実行して、v6e 互換の TensorFlow バージョンで tpu-runtime をリセットできます。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH を使用して worker-0 にアクセスします。

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

worker-0 に TensorFlow をインストールします。

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

TPU_NAME 環境変数をエクスポートします。

export TPU_NAME=v6e-16

次の Python スクリプトを実行して、スライドで使用可能な TPU コアの数を確認し、すべてが正しくインストールされていることをテストできます(ここに示す出力は v6e-16 スライスで生成されています)。

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

出力は次のようになります。

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e(SkyPilot 対応)

TPU v6e は SkyPilot で使用できます。次の手順で、v6e 関連の場所/料金情報を SkyPilot に追加します。

  1. ~/.sky/catalogs/v5/gcp/vms.csv の末尾に次のように追加します。

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. YAML ファイルで次のリソースを指定します。

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. TPU v6e を使用してクラスタを起動します。

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. SSH を使用して TPU v6e に接続します。ssh tpu_v6

推論チュートリアル

以降のセクションでは、JetStream を使用して MaxText モデルと PyTorch モデルをサービングするチュートリアルと、TPU v6e で MaxDiffusion モデルをサービングするチュートリアルについて説明します。

JetStream での MaxText

このチュートリアルでは、JetStream を使用して TPU v6e で MaxText(JAX)モデルを提供する方法について説明します。JetStream は、XLA デバイス(TPU)での大規模言語モデル(LLM)推論用にスループットとメモリが最適化されたエンジンです。このチュートリアルでは、Llama2-7B モデルの推論ベンチマークを実行します。

始める前に

  1. 4 つのチップを含む TPU v6e を作成します。

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. SSH を使用して TPU に接続します。

    gcloud compute tpus tpu-vm ssh TPU_NAME

チュートリアルを実行する

JetStream と MaxText を設定し、モデル チェックポイントを変換して推論ベンチマークを実行するには、GitHub リポジトリの手順に沿って操作します。

クリーンアップ

TPU を削除します。

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch TPU での vLLM

以下は、TPU VM で vLLM を使用する方法を示す簡単なチュートリアルです。本番環境で Trillium に vLLM をデプロイするベスト プラクティスの例については、数日以内に GKE ユーザーガイドを公開する予定です(お待ちください)。

始める前に

  1. 4 つのチップを含む TPU v6e を作成します。

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    コマンドフラグの説明

    変数 説明
    NODE_ID キューに格納されたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。
    PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 で新しいプロジェクトを作成します。
    ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM -> サービス アカウント で確認できるサービス アカウントのメールアドレスです。

    例: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. SSH を使用して TPU に接続します。

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

TPU で vLLM を設定する

  1. vLLM リポジトリのクローンを作成し、vLLM ディレクトリに移動します。

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. 既存の torch パッケージと torch-xla パッケージをクリーンアップします。

    pip uninstall torch torch-xla -y
    
  3. PyTorch と PyTorch XLA をインストールします。

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. JAX と Pallas をインストールします。

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. 他のビルド依存関係をインストールします。

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

モデルへのアクセス権を取得する

HuggingFace リポジトリの Llama3 ファミリーのモデルを使用するには、同意契約に署名する必要があります。

Hugging Face トークンをまだ生成していない場合は、新しいトークンを生成します。

  1. [Your Profile] > [Settings] > [Access Tokens] の順にクリックします。
  2. [New Token] を選択します。
  3. 任意の名前と、少なくとも Read ロールを指定します。
  4. [Generate a token] を選択します。
  5. 生成されたトークンをクリップボードにコピーし、環境変数として設定して、huggingface-cli で認証します。

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

ベンチマーク データをダウンロードする

  1. /data ディレクトリを作成し、Hugging Face から ShareGPT データセットをダウンロードします。

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

vLLM サーバーを起動する

次のコマンドは、Hugging Face Model Hub から TPU VM の /tmp ディレクトリにモデルの重みをダウンロードし、さまざまな入力シェイプを事前コンパイルして、モデルのコンパイルを ~/.cache/vllm/xla_cache に書き込みます。

詳細については、vLLM のドキュメントをご覧ください。

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

vLLM ベンチマークを実行する

vLLM ベンチマーク スクリプトを実行します。

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

クリーンアップ

TPU を削除します。

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

JetStream での PyTorch

このチュートリアルでは、JetStream を使用して TPU v6e で PyTorch モデルを提供する方法について説明します。JetStream は、XLA デバイス(TPU)での大規模言語モデル(LLM)推論用にスループットとメモリが最適化されたエンジンです。このチュートリアルでは、Llama2-7B モデルの推論ベンチマークを実行します。

始める前に

  1. 4 つのチップを含む TPU v6e を作成します。

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. SSH を使用して TPU に接続します。

    gcloud compute tpus tpu-vm ssh TPU_NAME

チュートリアルを実行する

JetStream-PyTorch を設定し、モデル チェックポイントを変換して推論ベンチマークを実行するには、GitHub リポジトリの手順に沿って操作します。

クリーンアップ

TPU を削除します。

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

MaxDiffusion 推論

このチュートリアルでは、TPU v6e で MaxDiffusion モデルを提供する方法について説明します。このチュートリアルでは、Stable Diffusion XL モデルを使用して画像を生成します。

始める前に

  1. 4 つのチップを含む TPU v6e を作成します。

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. SSH を使用して TPU に接続します。

    gcloud compute tpus tpu-vm ssh TPU_NAME

Conda 環境を作成する

  1. Miniconda のディレクトリを作成します。

    mkdir -p ~/miniconda3
  2. Miniconda インストーラ スクリプトをダウンロードします。

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Miniconda をインストールします。

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Miniconda インストーラ スクリプトを削除します。

    rm -rf ~/miniconda3/miniconda.sh
  5. Miniconda を PATH 変数に追加します。

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. ~/.bashrc を再読み込みして、PATH 変数に変更を適用します。

    source ~/.bashrc
  7. 新しい Conda 環境を作成します。

    conda create -n tpu python=3.10
  8. Conda 環境を有効にします。

    source activate tpu

MaxDiffusion を設定する

  1. MaxDiffusion リポジトリのクローンを作成し、MaxDiffusion ディレクトリに移動します。

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. mlperf-4.1 ブランチに切り替えます。

    git checkout mlperf4.1
  3. MaxDiffusion をインストールします。

    pip install -e .
  4. 依存関係をインストールします。

    pip install -r requirements.txt
  5. JAX をインストールします。

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

画像を生成

  1. 環境変数を設定して TPU ランタイムを構成します。

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. src/maxdiffusion/configs/base_xl.yml で定義されたプロンプトと構成を使用して画像を生成します。

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

クリーンアップ

TPU を削除します。

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

トレーニング チュートリアル

以降のセクションでは、MaxText のトレーニングに関するチュートリアルについて説明します。

TPU v6e での MaxDiffusion モデルと PyTorch モデル。

MaxText と MaxDiffusion

以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。

一般的な手順は次のとおりです。

  1. ワークロードのベースイメージをビルドします。
  2. XPK を使用してワークロードを実行します。
    1. ワークロードのトレーニング コマンドをビルドします。
    2. ワークロードをデプロイします。
  3. ワークロードを追跡して指標を表示する。
  4. XPK ワークロードが不要な場合は削除します。
  5. 不要になった XPK クラスタを削除します。

ベースイメージをビルドする

MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。

  1. 使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Google Cloud CLI を使用するように Docker を構成します。

    gcloud auth configure-docker
    
  3. 次のコマンドまたは JAX Stable Stack を使用して Docker イメージをビルドします。JAX Stable Stack の詳細については、JAX Stable Stack を使用して Docker イメージを作成するをご覧ください。

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
JAX Stable Stack を使用して Docker イメージを作成する

MaxText と MaxDiffusion の Docker イメージは、JAX Stable Stack ベースイメージを使用してビルドできます。

JAX Stable Stack は、JAX を orbaxflaxoptax などのコア パッケージと、TPU プログラム ユーティリティやその他の重要なツールを駆動する適格な libtpu.so とともにバンドルすることで、MaxText と MaxDiffusion の一貫した環境を提供します。これらのライブラリは互換性を確認するためにテストされており、MaxText と MaxDiffusion のビルドと実行のための安定した基盤を提供し、互換性のないパッケージ バージョンによる潜在的な競合を排除します。

JAX Stable Stack には、完全にリリースされ、適格性のある libtpu.so が含まれています。これは、TPU プログラムのコンパイル、実行、ICI ネットワーク構成を駆動するコア ライブラリです。libtpu リリースは、JAX で以前に使用されていたナイトリー ビルドに代わるもので、HLO/StableHLO IR で PJRT レベルの適格性テストを行い、TPU での XLA 計算の一貫した機能を保証します。

JAX Stable Stack を使用して MaxText と MaxDiffusion の Docker イメージをビルドするには、docker_build_dependency_image.sh スクリプトを実行するときに、MODE 変数を stable_stack に設定し、BASEIMAGE 変数を使用するベースイメージに設定します。

次の例では、us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 をベースイメージとして指定しています。

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

使用可能な JAX Stable Stack ベースイメージの一覧については、Artifact Registry の JAX Stable Stack イメージをご覧ください。

XPK を使用してワークロードを実行する

  1. MaxText で設定されたデフォルト値または MaxDiffusion を使用していない場合は、次の環境変数を設定します。

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. 次のステップでトレーニング コマンドとしてコピーされるモデル スクリプトを作成します。モデル スクリプトはまだ実行しないでください。

    MaxText

    MaxText は、ピュア Python と JAX で記述された、高性能でスケーラビリティに優れたオープンソースの LLM です。トレーニングと推論のために Google Cloud TPU と GPU をターゲットとしています。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発した、オープンウェイトの大規模言語モデル(LLM)ファミリーです。

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral は、Mistral AI によって開発された最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用しています。

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama は、Meta が開発したオープン ウェイトの大規模言語モデル(LLM)のファミリーです。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion は、Cloud TPU や GPU などの XLA デバイスで実行される、純粋な Python と JAX で記述されたさまざまな潜在拡散モデルのリファレンス実装のコレクションです。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。

    MaxDiffusion を実行するには、特定のブランチをインストールする必要があります。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    トレーニング スクリプト:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するには、--base-docker-image フラグを指定するか、--docker-image フラグと使用するイメージを指定する必要があります。

    省略可: --enable-debug-logs フラグを含めると、デバッグ ロギングを有効にできます。詳細については、MaxText で JAX をデバッグするをご覧ください。

    省略可: --use-vertex-tensorboard フラグを指定して Vertex AI Experiments を作成し、Vertex AI TensorBoard にデータをアップロードできます。詳細については、Vertex AI を使用して MaxText で JAX をモニタリングするをご覧ください。

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    次の変数を置き換えます。

    • CLUSTER_NAME: XPK クラスタの名前。
    • ACCELERATOR_TYPE: TPU のバージョンとサイズ。例: v6e-256
    • NUM_SLICES: TPU スライスの数。
    • YOUR_MODEL_SCRIPT: トレーニング コマンドとして実行するモデル スクリプト。

    出力には、次のようなワークロードのフォローリンクが含まれます。

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    リンクを開き、[ログ] タブをクリックしてワークロードをリアルタイムで追跡します。

MaxText で JAX をデバッグする

補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。

Vertex AI を使用して MaxText の JAX をモニタリングする

Vertex AI のマネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。

  1. 使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。これは、16 個未満の VM を使用する小規模なワークロードでは問題にならない場合があります。
  2. Vertex AI の cloud-accelerator-diagnostics などの依存関係をインストールします。

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard を作成するで説明されているように、--create-vertex-tensorboard フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。

  4. --use-vertex-tensorboard フラグとオプションの --experiment-name フラグを使用して XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI Experiments を作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。

ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vertex AI TensorBoard のリンクは Google Cloud コンソールで確認することもできます。Google Cloud コンソールで [Vertex AI Experiments] に移動します。プルダウンから適切なリージョンを選択します。

TensorBoard ディレクトリも、${BASE_OUTPUT_DIR} で指定した Cloud Storage バケットに書き込まれます。

XPK ワークロードを削除する

xpk workload delete コマンドを使用して、ジョブ接頭辞またはジョブ ステータスに基づいて 1 つ以上のワークロードを削除します。このコマンドは、実行する必要がない XPK ワークロードを送信した場合や、キューにスタックしているジョブがある場合に便利です。

XPK クラスタを削除する

クラスタを削除するには、xpk cluster delete コマンドを使用します。

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama と PyTorch

このチュートリアルでは、WikiText データセットを使用して、TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。また、PyTorch TPU モデルのレシピには、Docker イメージとしてこちらからアクセスできます。

インストール

Hugging Face Transformers の pytorch-tpu/transformers フォークと依存関係を仮想環境にインストールします。

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

モデル構成を設定する

次のセクションのトレーニング コマンド(モデル スクリプトを作成する)では、2 つの JSON 構成ファイルを使用して、モデル パラメータと FSDP(完全にシャーディングされたデータ パラレル)構成を定義します。FSDP シャーディングは、トレーニング中にモデル重みが大きなバッチサイズに適合するために使用されます。小規模なモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な場合があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. モデル パラメータ構成ファイルを作成します。Llama3-8B のモデル パラメータ構成は次のとおりです。他のモデルの場合は、Hugging Face で構成を確認します。たとえば、Llama2-7B 構成をご覧ください。

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. FSDP 構成ファイルを作成します。

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    FSDP について詳しくは、FSDPv2 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを TPU VM にアップロードします。

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    現在の作業ディレクトリに構成ファイルを作成し、XPK で --base-docker-image フラグを使用することもできます。

モデル スクリプトを作成する

--config_name フラグを使用してモデル パラメータ構成ファイルと、--fsdp_config フラグを使用して FSDP 構成ファイルを指定して、モデル スクリプトをビルドします。次のセクションのモデルを実行するで、このスクリプトを TPU で実行します。モデル スクリプトはまだ実行しないでください。

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

モデルを実行する

前の手順のモデル スクリプトを作成するで作成したスクリプトを使用してモデルを実行します。

単一ホストの TPU VM(v6e-4 など)を使用している場合は、TPU VM でトレーニング コマンドを直接実行できます。マルチホスト TPU VM を使用している場合は、次のコマンドを使用して、すべてのホストでスクリプトを同時に実行します。

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

PyTorch/XLA のトラブルシューティング

前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは変数 PROFILE_LOGDIR で指定された場所に保存されます。この場所に保存されている xplane.pb ファイルを抽出し、tensorboard を使用して TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。PyTorch/XLA が期待どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。

DLRM DCN v2 チュートリアル

このチュートリアルでは、TPU v6e で DLRM DCN v2 モデルをトレーニングする方法について説明します。

マルチホストで実行している場合は、次のコマンドを実行して、適切な TensorFlow バージョンで tpu-runtime をリセットします。単一ホストで実行している場合は、次の 2 つのコマンドを実行する必要はありません。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

worker-0 に SSH 接続する

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

TPU 名を設定する

export TPU_NAME=${TPU_NAME}

DLRM v2 を実行する

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

script.sh を実行します。

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

推奨ワークロード(DLRM DCN)を実行するには、次のフラグが必要です。

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

ベンチマーク結果

以降のセクションでは、DLRM DCN v2 と v6e の MaxDiffusion のベンチマーク結果について説明します。

DLRM DCN v2

DLRM DCN v2 トレーニング スクリプトは、さまざまなスケールで実行されました。スループットは次の表を参照してください。

v6e-64 v6e-128 v6e-256
トレーニングのステップ 7000 7000 7000
グローバル バッチサイズ 131072 262144 524288
スループット(例/秒) 2975334 5111808 10066329

MaxDiffusion

MaxDiffusion のトレーニング スクリプトを v6e-4、v6e-16、2xv6e-16 で実行しました。スループットは次の表を参照してください。

v6e-4 v6e-16 2 個の v6e-16
トレーニングのステップ 0.069 0.073 0.13
グローバル バッチサイズ 8 32 64
スループット(例/秒) 115.9 438.4 492.3

コレクション

v6e では、サービング ワークロードを実行するユーザー向けにコレクションという新機能が導入されています。コレクション機能は v6e にのみ適用されます。

コレクションを使用すると、サービング ワークロードの一部となる TPU ノードを Google Cloud に指定できます。これにより、基盤となる Google Cloud インフラストラクチャは、通常の運用中にトレーニング ワークロードに適用される可能性のある中断を制限して効率化できます。

Cloud TPU API のコレクションを使用する

Cloud TPU API の単一ホスト コレクションは、ワークロードのサービス提供に使用されることを基盤となるインフラストラクチャに示すために、特別なフラグ(--workload-type = availability-optimized)が設定されたキューに入れられたリソースです。

次のコマンドは、Cloud TPU API を使用して単一ホスト コレクションをプロビジョニングします。

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

モニタリングとプロファイル

Cloud TPU v6e は、以前の世代の Cloud TPU と同じ方法でモニタリングとプロファイリングをサポートしています。モニタリングの詳細については、TPU VM をモニタリングするをご覧ください。