Trillium(v6e)の概要

v6e は、このドキュメント、TPU API、ログで Trillium を指すために使用されます。v6e は Google の第 6 世代の TPU を表します。

Pod あたり 256 個のチップを使用する v6e アーキテクチャは、v5e と多くの類似点があります。このシステムは、トランスフォーマー、text-to-image、畳み込みニューラル ネットワーク(CNN)のトレーニング、微調整、サービス提供に最適化されています。

v6e システム アーキテクチャと構成については、v6e のドキュメントをご覧ください。

この概要ドキュメントでは、JAXPyTorchTensorFlow フレームワークを使用したモデルのトレーニングとサービス提供のプロセスについて説明します。各フレームワークでは、キューに格納されたリソースまたは Google Kubernetes Engine(GKE)を使用して TPU をプロビジョニングできます。GKE の設定は、XPK コマンドまたは GKE コマンドを使用して行うことができます。

v6e を使用してモデルをトレーニングまたはサービングする一般的な手順

  1. プロジェクトを準備する Google Cloud
  2. 容量を確保する
  3. TPU 環境を設定する
  4. Cloud TPU 環境をプロビジョニングする
  5. モデルの トレーニングまたは推論ワークロードを実行する
  6. クリーンアップ

Google Cloud プロジェクトを準備する

  1. Google アカウントにログインします。Google アカウントをまだお持ちでない場合は、新しいアカウントを登録します。
  2. Google Cloud コンソールで、プロジェクト セレクタページから Cloud プロジェクトを選択するか作成します。
  3. Google Cloud プロジェクトに対する課金を有効にします。Google Cloud の使用にはすべて課金が必要です。
  4. gcloud alpha components をインストールします。
  5. 次のコマンドを実行して、gcloud コンポーネントの最新バージョンをインストールします。

    gcloud components update
    
  6. Cloud Shell で、次の gcloud コマンドを使用して TPU API を有効にします。Google Cloud コンソールから有効にすることもできます。

    gcloud services enable tpu.googleapis.com
    
  7. Compute Engine API の TPU サービス アカウントで権限を有効にする

    サービス アカウントにより、Cloud TPU サービスが他の Google Cloud サービスにアクセスできるようになります。ユーザー管理のサービス アカウントは、Google Cloud で推奨される方法です。ガイドに沿ってロールを作成し、付与します。次のロールが必要です。

    • TPU 管理者
    • ストレージ管理者
    • ログ書き込み
    • モニタリング指標の書き込み

    a. GKE のユーザー アカウント(XPK)で XPK 権限を設定します。

  8. Google アカウントで認証し、デフォルトのプロジェクト ID とゾーンを設定します。
    auth login は、Google ユーザー認証情報を使用して Google Cloud にアクセスする gcloud を承認します。
    PROJECT_ID は Google Cloud プロジェクト名です。
    ZONE は、TPU を作成するゾーンです。

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. TPU VM のサービス ID を作成します。

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

容量を確保する

TPU 割り当てをリクエストし、容量に関する質問に回答するには、Cloud TPU サポートの営業担当者またはアカウントにお問い合わせください。

Cloud TPU 環境をプロビジョニングする

v6e TPU は、GKE、GKE と XPK(GKE をラップする CLI ツール)、またはキューに入れられたリソースでプロビジョニングして管理できます。

前提条件

  • プロジェクトに十分な TPUS_PER_TPU_FAMILY 割り当てがあることを確認します。これは、Google Cloud プロジェクト内でアクセスできるチップの最大数を指定します。
  • v6e は、次の構成でテストされています。
    • python 3.10 以降
    • ナイトリー ソフトウェアのバージョン:
      • 夜間 JAX 0.4.32.dev20240912
      • ナイトリー LibTPU 0.1.dev20240912+nightly
    • 安定版ソフトウェア バージョン:
      • JAX + v0.4.37 の JAX Lib
  • プロジェクトに次の TPU 割り当てがあることを確認します。

    • TPU VM の割り当て
    • IP アドレスの割り当て
    • Hyperdisk-balance 割り当て

  • ユーザー プロジェクトの権限

環境変数

Cloud Shell で、次の環境変数を作成します。

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for provisioning Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

コマンドフラグの説明

変数 説明
NODE_ID キューに入れられたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 で新しいプロジェクトを作成します。
ゾーン サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM -> サービス アカウント で確認できるサービス アカウントのメールアドレスです。

例: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES 作成するスライスの数(マルチスライスの場合のみ必要)。
QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。
VALID_DURATION キューに入れられたリソース リクエストが有効である期間。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

ネットワーク パフォーマンスの最適化

最適なパフォーマンスを得るには、8,896 MTU(最大伝送単位)のネットワークを使用します。

デフォルトでは、Virtual Private Cloud(VPC)は 1,460 バイトの MTU のみを提供します。これにより、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300 ~ 8,896 バイトの任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。

既存またはデフォルトのネットワークの MTU 設定の変更の詳細については、VPC ネットワークの MTU 設定を変更するをご覧ください。

次の例では、8,896 MTU のネットワークを作成します。

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME}
 --allow tcp,icmp,udp --project=${PROJECT}

マルチ NIC の使用(マルチスライス向けのオプション)

マルチスライス環境を使用している場合、セカンダリ サブネットには次の環境変数が必要です。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT_ID}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"

gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT_ID}" \
  --enable-logging

マルチネットワーク スライスを作成したら、XPK クラスタを設定し、XPK ワークロードの一部として --command ifconfig を実行して、両方の NIC が使用されていることを確認できます。

次の xpk workload コマンドを使用して、ifconfig コマンドの出力を Cloud コンソールのログに表示し、eth0 と eth1 の両方で mtu=8896 であることを確認します。

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

eth0 と eth1 の両方に mtu=8,896 があることを確認します。マルチ NIC が実行されていることを確認するには、XPK ワークロードの一部としてコマンド --command "ifconfig" を実行します。次に、Cloud コンソール ログでその xpk ワークロードの出力を確認して、eth0 と eth1 の両方に mtu=8896 があることを確認します。

TCP 設定の改善

キューに格納されたリソース インターフェースを使用して作成された TPU の場合は、次のコマンドを実行して TCP 受信バッファの上限を増やし、ネットワーク パフォーマンスを改善できます。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "$PROJECT" \
  --zone "$ZONE" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

キューに格納されたリソースを使用したプロビジョニング

割り振られた容量は、キューに入れられたリソースの create コマンドを使用してプロビジョニングできます。

  1. TPU のキューに入れられたリソース リクエストを作成します。

    --reserved フラグは、オンデマンド リソースではなく、予約済みリソースに対してのみ必要です。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]
    
      # The following flags are only needed if you are using Multislice.
      --node-count node-count  # Number of slices in a Multislice \
      --node-prefix node-prefix # An optional user-defined node prefix;
       the default is QUEUED_RESOURCE_ID.

    キューに格納されたリソース リクエストが正常に作成されると、[response] フィールド内の状態は「WAITING_FOR_RESOURCES」または「FAILED」のいずれかになります。キューに追加されたリソース リクエストの状態が「WAITING_FOR_RESOURCES」の場合、リソースはキューに追加され、十分な TPU 容量が割り当てられたときにプロビジョニングされます。キューに格納されたリソース リクエストが「FAILED」状態の場合、失敗の理由が出力に表示されます。指定した時間内に v6e がプロビジョニングされず、状態が「FAILED」になった場合、キューに格納されたリソース リクエストは期限切れになります。詳細については、キューに格納されたリソースの公開ドキュメントをご覧ください。

    キューに追加されたリソース リクエストが「ACTIVE」状態になると、SSH を使用して TPU VM に接続できます。list コマンドまたは describe コマンドを使用して、キューに格納されたリソースのステータスをクエリします。

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    キューに格納されたリソースが「ACTIVE」状態の場合、出力は次のようになります。

      state:
       state: ACTIVE
    
  2. TPU VM を管理します。TPU VM を管理するオプションについては、TPU VM の管理をご覧ください。

  3. SSH を使用して TPU VM に接続する

    TPU スライスの各 TPU VM にバイナリをインストールしてコードを実行できます。スライスに含まれる VM の数を決定するには、VM のタイプのセクションをご覧ください。

    バイナリをインストールするかコードを実行するには、SSH を使用して tpu-vm ssh コマンドを使用して VM に接続します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    SSH を使用して特定の VM に接続するには、0 ベースのインデックスに続けて --worker フラグを使用します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    スライス形状が 8 チップを超える場合、1 つのスライスに複数の VM があります。この場合は、gcloud alpha compute tpus tpu-vm ssh コマンドの --worker=all パラメータと --command パラメータを使用して、すべての VM でコマンドを同時に実行します。次に例を示します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. キューに格納されているリソースの削除

    セッションの終了時にキューに格納されているリソースを削除するか、「FAILED」状態のキューに入れられたリソース リクエストを削除します。キューに格納されたリソースを削除するには、次の 2 つの手順でスライスを削除してから、キュー内のリソースを削除します。

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

GKE または XPK で v6e TPU をプロビジョニングする

v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して TPU をプロビジョニングし、モデルをトレーニングまたは提供できます。GKE クラスタで TPU 構成を計画する方法については、GKE で TPU を計画するをご覧ください。以降のセクションでは、単一 NIC とマルチ NIC をサポートする XPK クラスタを作成するコマンドについて説明します。

単一 NIC をサポートする XPK クラスタを作成するコマンド

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
   gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
   gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network ${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
   python3 xpk.py cluster create --cluster $CLUSTER_NAME \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=$NUM_SLICES \
   --tpu-type=$TPU_TYPE \
   --zone=$ZONE  \
   --project=$PROJECT \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME XPK クラスタにユーザーが割り当てた名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 で新しいプロジェクトを作成します。
ゾーン サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

NUM_SLICES 作成するスライス数。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

マルチ NIC をサポートする XPK クラスタを作成するコマンド

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_1}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_1}" \
   --network="${NETWORK_NAME_1}" \
   --range=10.11.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_1}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
  gcloud compute routers create "${ROUTER_NAME}" \
    --project="${PROJECT}" \
    --network="${NETWORK_NAME_1}" \
    --region="${REGION}"
  gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_2}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
   gcloud compute routers create "${ROUTER_NAME}" \
     --project="${PROJECT}" \
     --network="${NETWORK_NAME_2}" \
     --region="${REGION}"
   gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE  \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME XPK クラスタにユーザーが割り当てた名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 で新しいプロジェクトを作成します。
ゾーン サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

NODE_POOL_ARGUMENTS 使用する追加のノード ネットワーク。

例: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 作成するスライスの数(マルチスライスの場合のみ必要)。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

フレームワークの設定

このセクションでは、JAXPyTorchTensorFlow フレームワークを使用した ML モデル トレーニングの一般的な設定プロセスについて説明します。キューに入れられたリソースまたは GKE を使用して TPU をプロビジョニングできます。GKE の設定は、XPK または Kubernetes コマンドを使用して行うことができます。

JAX を設定する

このセクションでは、XPK の有無にかかわらず GKE で JAX ワークロードを実行する例と、キューに入れられたリソースを使用する例を示します。

GKE を使用して JAX を設定する

次の例では、Kubernetes YAML ファイルを使用して 2X2 の単一ホストを設定します。

単一ホストの単一スライス

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 4

マルチホスト上の単一スライス

次の例では、Kubernetes YAML ファイルを使用して 4X4 マルチホスト ノードプールを設定します。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 16

マルチホストのマルチスライス

次の例では、Kubernetes YAML ファイルを使用して 2 つの 4X4 マルチホスト ノードプールを設定します。

前提条件として、v0.2.3 以降の JobSet をインストールする必要があります。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 32

詳細については、GKE ドキュメントのマルチスライス ワークロードを実行するをご覧ください。

パフォーマンスを向上させるには、hostNetwork を有効化します。

マルチ NIC

GKE でマルチ NIC を利用する場合は、Kubernetes Pod マニフェストに追加のアノテーションが必要です。次の例は、TPU 以外のマルチ NIC ワークロードのマニフェストです。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Kubernetes Pod に exec すると、次のコードを使用して追加の NIC が表示されます。

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

XPK で GKE を使用して JAX を設定する

例については、xpk README をご覧ください。

MaxText で XPK を設定して実行するには、MaxText の実行方法をご覧ください。

キューに入れられたリソースを使用して JAX を設定する

gcloud alpha compute tpus tpu-vm ssh を使用して、スライス内のすべての TPU VM に JAX を同時にインストールします。マルチスライスの場合は、--node=all を追加します。


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

次の Python コードを実行して、スライドで使用可能な TPU コアの数を確認し、すべてが正しくインストールされていることをテストできます(ここに表示されている出力は、v6e-16 スライスで生成されたものです)。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

JAX の設定に関するトラブルシューティング

一般的なヒントとして、GKE ワークロード マニフェストで詳細なロギングを有効にします。次に、ログを GKE サポートに提供します。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

エラー メッセージ

no endpoints available for service 'jobset-webhook-service'

このエラーは、ジョブセットが正しくインストールされていないことを意味します。jobset-controller-manager Deployment Kubernetes Pod が実行されていることを確認します。詳細については、JobSet のトラブルシューティングに関するドキュメントをご覧ください。

TPU initialization failed: Failed to connect

GKE ノード バージョンが 1.30.4-gke.1348000 以降であることを確認します(GKE 1.31 はサポートされていません)。

PyTorch を設定する

このセクションでは、PyTorch/XLA を使用して v6e で PJRT の使用を開始する方法について説明します。Python 3.10 が推奨される Python バージョンです。

XPK で GKE を使用して PyTorch を設定する

PyTorch の依存関係がすでにインストールされている XPK で、次の Docker コンテナを使用できます。

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

XPK ワークロードを作成するには、次のコマンドを使用します。

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

--base-docker-image を使用すると、現在の作業ディレクトリが新しい Docker にビルドされた新しい Docker イメージが作成されます。

キューに入れられたリソースを使用して PyTorch を設定する

キューに登録されたリソースを使用して PyTorch をインストールし、v6e で小さなスクリプトを実行する手順は次のとおりです。

SSH を使用して依存関係をインストールし、VM にアクセスする

マルチスライスの場合は、--node=all を追加します。

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

サイズが大きく、頻繁な割り当てがあるモデルのパフォーマンスを改善する

サイズ設定に頻繁な割り当てがあるモデルの場合、tcmalloc を使用すると、デフォルトの malloc 実装よりもパフォーマンスが大幅に向上するため、TPU VM で使用されるデフォルトの malloctcmalloc です。しかし、ワークロードによっては(たとえば、埋め込みテーブルへの割り当てが非常に大きい DLRM の場合)、tcmalloc では速度が低下する可能性があります。その場合は、次の変数の設定を解除して、代わりにデフォルトの malloc を使用してください。

unset LD_PRELOAD

Python スクリプトを使用して v6e VM で計算を行います。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

これにより、次のような出力が生成されます。

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow の設定

v6e 公開プレビュー版では、tf-nightly ランタイム バージョンのみがサポートされています。

次のコマンドを実行して、v6e 互換の TensorFlow バージョンで tpu-runtime をリセットできます。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH を使用して worker-0 にアクセスします。

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

worker-0 に TensorFlow をインストールします。

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

TPU_NAME 環境変数をエクスポートします。

export TPU_NAME=v6e-16

次の Python スクリプトを実行して、スライドで使用可能な TPU コアの数を確認し、すべてが正しくインストールされていることをテストできます(ここに示す出力は v6e-16 スライスで生成されています)。

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

出力は次のようになります。

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e(SkyPilot 対応)

SkyPilot では TPU v6e を使用できます。次の手順で、v6e 関連の場所/料金情報を SkyPilot に追加します。

  1. ~/.sky/catalogs/v5/gcp/vms.csv の末尾に次のように追加します。

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. YAML ファイルで次のリソースを指定します。

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. TPU v6e を使用してクラスタを起動します。

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. SSH を使用して TPU v6e に接続します。ssh tpu_v6

推論チュートリアル

次のチュートリアルでは、TPU v6e で推論を実行する方法について説明します。

トレーニング チュートリアル

以降のセクションでは、TPU v6e で MaxText、MaxDiffusion、PyTorch モデルをトレーニングするチュートリアルについて説明します。

v6e Cloud TPU VM での MaxText と MaxDiffusion のトレーニング

以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。

一般的な手順は次のとおりです。

  1. ワークロードのベースイメージをビルドします。
  2. XPK を使用してワークロードを実行します。
    1. ワークロードのトレーニング コマンドをビルドします。
    2. ワークロードをデプロイします。
  3. ワークロードを追跡して指標を表示する。
  4. XPK ワークロードが不要な場合は削除します。
  5. 不要になった XPK クラスタを削除します。

ベースイメージをビルドする

MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。

  1. 使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Google Cloud CLI を使用するように Docker を構成します。

    gcloud auth configure-docker
    
  3. 次のコマンドまたは JAX Stable Stack を使用して Docker イメージをビルドします。JAX Stable Stack の詳細については、JAX Stable Stack を使用して Docker イメージを作成するをご覧ください。

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
    
  4. ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
JAX Stable Stack を使用して Docker イメージを作成する

MaxText と MaxDiffusion の Docker イメージは、JAX Stable Stack ベースイメージを使用してビルドできます。

JAX Stable Stack は、JAX を orbaxflaxoptax などのコア パッケージと、TPU プログラム ユーティリティやその他の重要なツールを駆動する適格な libtpu.so とともにバンドルすることで、MaxText と MaxDiffusion の一貫した環境を提供します。これらのライブラリは、互換性を確保し、MaxText と MaxDiffusion のビルドと実行のための安定した基盤を提供するためにテストされています。これにより、互換性のないパッケージ バージョンによる競合の発生を防ぐことができます。

JAX Stable Stack には、完全にリリースされ、適格性のある libtpu.so が含まれています。これは、TPU プログラムのコンパイル、実行、ICI ネットワーク構成を駆動するコア ライブラリです。libtpu リリースは、JAX で以前に使用されていたナイトリー ビルドに代わるもので、HLO/StableHLO IR で PJRT レベルの適格性テストを行い、TPU での XLA 計算の一貫した機能を保証します。

JAX Stable Stack を使用して MaxText と MaxDiffusion の Docker イメージをビルドするには、docker_build_dependency_image.sh スクリプトを実行するときに、MODE 変数を stable_stack に設定し、BASEIMAGE 変数を使用するベースイメージに設定します。

次の例では、us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 をベースイメージとして指定します。

bash docker_build_dependency_image.sh MODE=stable_stack
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1

使用可能な JAX Stable Stack ベースイメージの一覧については、Artifact Registry の JAX Stable Stack イメージをご覧ください。

XPK を使用してワークロードを実行する

  1. MaxText または MaxDiffusion で設定されたデフォルト値を使用しない場合は、次の環境変数を設定します。

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. モデル スクリプトを作成します。このスクリプトは、後続のステップでトレーニング コマンドとしてコピーされます。

    モデル スクリプトはまだ実行しないでください。

    MaxText

    MaxText は、ピュア Python と JAX で記述された、高パフォーマンスでスケーラビリティに優れたオープンソースの LLM です。トレーニングと推論に TPU と GPU をターゲットとしています。 Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発したオープン重み LLM のファミリーです。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral は、Mistral AI によって開発された最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用しています。

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama は、Meta が開発したオープン ウェイト LLM のファミリーです。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion は、Cloud TPU や GPU などの XLA デバイスで実行される、純粋な Python と JAX で記述されたさまざまな潜在拡散モデルのリファレンス実装のコレクションです。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。

    MaxDiffusion を実行するには、次の git checkout コマンドに示すように、特定の Git ブランチをインストールする必要があります。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    トレーニング スクリプト:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
    
        
  3. 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するには、--base-docker-image フラグを指定するか、--docker-image フラグと使用するイメージを指定する必要があります。

    省略可: --enable-debug-logs フラグを含めると、デバッグ ロギングを有効にできます。詳細については、MaxText で JAX をデバッグするをご覧ください。

    省略可: --use-vertex-tensorboard フラグを指定して Vertex AI Experiments を作成し、Vertex AI TensorBoard にデータをアップロードできます。詳細については、Vertex AI を使用して MaxText で JAX をモニタリングするをご覧ください。

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
        --tpu-type=$ACCELERATOR_TYPE \
        --num-slices=$NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command $YOUR-MODEL-SCRIPT

    次の変数をエクスポートします。

    export ClUSTER_NAME=CLUSTER_NAME: XPK クラスタの名前。 export ACCELERATOR_TYPEACCELERATOR_TYPE: TPU のバージョンとサイズ。例: v6e-256export NUM_SLICES=NUM_SLICES: TPU スライスの数。 export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: トレーニング コマンドとして実行するモデル スクリプト。

    出力には、次のようなワークロードのフォローリンクが含まれます。

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    リンクを開き、[ログ] タブをクリックして、ワークロードをリアルタイムで追跡します。

MaxText で JAX をデバッグする

補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。

Vertex AI を使用して MaxText の JAX をモニタリングする

Vertex AI のマネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。

  1. 使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。16 個未満の VM を使用する小規模なワークロードでは、問題にならない場合があります。
  2. Vertex AI の cloud-accelerator-diagnostics などの依存関係をインストールします。

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard を作成するで説明されているように、--create-vertex-tensorboard フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。

  4. --use-vertex-tensorboard フラグとオプションの --experiment-name フラグを使用して XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI Experiments を作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。

ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vertex AI TensorBoard のリンクは Google Cloud コンソールで確認することもできます。Google Cloud コンソールで [Vertex AI Experiments] に移動します。プルダウンから適切なリージョンを選択します。

TensorBoard ディレクトリも、${BASE_OUTPUT_DIR} で指定した Cloud Storage バケットに書き込まれます。

XPK ワークロードを削除する

xpk workload delete コマンドを使用して、ジョブ接頭辞またはジョブのステータスに基づいて 1 つ以上のワークロードを削除します。このコマンドは、実行する必要がない XPK ワークロードを送信した場合や、キューにスタックしているジョブがある場合に便利です。

XPK クラスタを削除する

クラスタを削除するには、xpk cluster delete コマンドを使用します。

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone $ZONE --project $PROJECT_ID

v6e Cloud TPU VM での Llama と PyTorch/XLA トレーニング

このチュートリアルでは、WikiText データセットを使用して、TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。

Hugging Face と Llama 3 モデルにアクセスする

このチュートリアルを実行するには、Hugging Face ユーザー アクセス トークンが必要です。ユーザー アクセス トークンの作成と使用については、ユーザー アクセス トークンの Hugging Face のドキュメントをご覧ください。

また、Hugging Face の Llama 3 8B モデルにアクセスする権限も必要です。アクセス権を取得するには、Hugging Face の Meta-Llama-3-8B モデルにアクセスしてアクセス権をリクエストします。

TPU VM を作成する

チュートリアルを実行するために、8 個のチップを持つ TPU v6e を作成します。

  1. 環境変数を設定します。

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT=YOUR_PROJECT
    export ZONE=YOUR_ZONE
  2. TPU VM を作成します。

    gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \
        --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT

インストール

Hugging Face Transformers の pytorch-tpu/transformers フォークと依存関係をインストールします。このチュートリアルは、この例で使用されている次の依存関係のバージョンでテストされています。

  • torch: 2.5.0 と互換性あり
  • torch_xla[tpu]: 2.5.0 と互換性あり
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

モデル構成を設定する

次のセクションのトレーニング コマンド(モデルを実行する)では、2 つの JSON 構成ファイルを使用して、モデル パラメータと FSDP(完全にシャーディングされたデータ パラレル)構成を定義します。FSDP シャーディングは、トレーニング中にモデルの重みが大きなバッチサイズに適合するために使用されます。小規模なモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な場合があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. モデル パラメータ構成ファイルを作成します。Llama3-8B のモデル パラメータ構成は次のとおりです。他のモデルについては、Hugging Face で構成を確認してください。たとえば、Llama2-7B 構成をご覧ください。

    cat > llama-config.json <
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
  2. FSDP 構成ファイルを作成します。

    cat > fsdp-config.json <
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF

    FSDP の詳細については、FSDPv2 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを TPU VM にアップロードします。

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=$PROJECT \
        --zone $ZONE

モデルを実行する

前のセクションで作成した構成ファイルを使用して run_clm.py スクリプトを実行し、WikiText データセットで Llama 3 8B モデルをトレーニングします。トレーニング スクリプトを TPU v6e-8 で実行すると、約 10 分かかります。

  1. 次のコマンドを使用して、TPU で Hugging Face にログインします。

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. モデルのトレーニングを実行します。

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA のトラブルシューティング

前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは変数 PROFILE_LOGDIR で指定された場所に保存されます。この場所に保存されている xplane.pb ファイルを抽出し、tensorboard を使用して TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。PyTorch/XLA が期待どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。

v6e での DLRM DCN v2 トレーニング

このチュートリアルでは、TPU v6e で DLRM DCN v2 モデルをトレーニングする方法について説明します。64、128、256 個のチップで TPU v6e をプロビジョニングする必要があります。

マルチホストで実行している場合は、次のコマンドを実行して、適切な TensorFlow バージョンで tpu-runtime をリセットします。単一ホストで実行している場合は、次の 2 つのコマンドを実行する必要はありません。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

worker-0 に SSH 接続する

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

TPU 名を設定します。

export TPU_NAME=${TPU_NAME}

DLRM v2 を実行する

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

script.sh を実行します。

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

推奨ワークロード(DLRM DCN)を実行するには、次のフラグが必要です。

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

ベンチマーク結果

次のセクションでは、v6e での DLRM DCN v2 と MaxDiffusion のベンチマーク結果について説明します。

DLRM DCN v2

DLRM DCN v2 トレーニング スクリプトは、さまざまなスケールで実行されました。スループットは次の表を参照してください。

v6e-64 v6e-128 v6e-256
トレーニングのステップ 7000 7000 7000
グローバル バッチサイズ 131072 262144 524288
スループット(例/秒) 2975334 5111808 10066329

MaxDiffusion

MaxDiffusion のトレーニング スクリプトを v6e-4、v6e-16、2xv6e-16 で実行しました。スループットは次の表を参照してください。

v6e-4 v6e-16 2 個の v6e-16
トレーニングのステップ 0.069 0.073 0.13
グローバル バッチサイズ 8 32 64
スループット(例/秒) 115.9 438.4 492.3

収集のスケジュール設定

Trillium(v6e)には、「コレクション スケジューリング」という新しい機能が含まれています。この機能を使用すると、GKE と Cloud TPU API の両方で単一ホストの 推論ワークロードを実行する複数の TPU スライスを管理できます。これらのスライスをコレクションにグループ化すると、需要に合わせてレプリカの数を簡単に調整できます。ソフトウェアの更新は慎重に制御され、コレクション内のスライスの一部が常に使用可能で、受信トラフィックを処理できるようにします。

GKE で収集のスケジュール設定を使用する方法の詳細については、GKE のドキュメントをご覧ください。

収集のスケジュール設定機能は v6e にのみ適用されます。

Cloud TPU API からの収集スケジューリングを使用する

Cloud TPU API の単一ホスト コレクションは、ワークロードのサービス提供に使用されることを基盤となるインフラストラクチャに示すために、特別なフラグ(--workload-type = availability-optimized)が設定されたキューに入れられたリソースです。

次のコマンドは、Cloud TPU API を使用して単一ホスト コレクションをプロビジョニングします。

gcloud alpha compute tpus queued-resources create my-collection \
   --project=$PROJECT_ID \
   --zone=${ZONE} \
   --accelerator-type $ACCELERATOR_TYPE \
   --node-count ${NODE_COUNT} \
   --workload-type=availability-optimized

モニタリングとプロファイル

Cloud TPU v6e は、以前の世代の Cloud TPU と同じ方法でモニタリングとプロファイリングをサポートしています。モニタリングの詳細については、TPU VM をモニタリングするをご覧ください。