TPU v6e を使用してモデルをトレーニングする

このドキュメントでは、Cloud TPU v6e(Trillium とも呼ばれます)でモデルをトレーニングする方法について説明します。環境設定、パフォーマンスの最適化、JAX と PyTorch/XLA を使用した実践的なトレーニングの例について説明します。

TPU v6e(Trillium とも呼ばれます)は、Google の第 6 世代の TPU です。API やログなどのすべての技術的な側面から、Trillium はこのドキュメント全体で v6e と呼んでいます。Pod あたり 256 個のチップを備えた TPU v6e のアーキテクチャは、v5e と多くの類似点があります。TPU v6e は、トランスフォーマー、テキスト画像変換、畳み込みニューラル ネットワーク(CNN)のトレーニング、ファインチューニング、サービングに最適化されています。TPU v6e システム アーキテクチャと構成の詳細については、TPU v6e をご覧ください。

Cloud TPU v6e で推論を実行する方法については、次のチュートリアルをご覧ください。

始める前に

始める前に、次のことを行う必要があります。

  • 課金を有効にした Google Cloud アカウントおよびプロジェクトを作成する
  • Google Cloud CLI アルファ版コンポーネントをインストールする
  • Cloud TPU API を有効にする
  • Cloud TPU サービス エージェントを作成する
  • Cloud TPU サービス アカウントを作成して権限を付与する

詳細については、Cloud TPU 環境を設定するをご覧ください。

割り当てと権限を確認する

プロジェクトに次の割り当てがあることを確認します。

GKE と XPK を使用している場合は、 Google Cloud コンソールで追加の権限が必要です。詳細については、Google Cloud コンソールに必要な権限 をご覧ください。

TPU をプロビジョニングする

TPU v6e は、次の方法でプロビジョニングして管理できます。

  • GKE: GKE を使用して、コンテナ化された ML ワークロードのアクセラレータ プールとして TPU をプロビジョニングして管理できます。詳細については、GKE の TPU についてをご覧ください。
  • GKE と XPK: XPK は、GKE でのクラスタの作成とワークロードの実行を簡素化するコマンドライン ツールです。これは、ML の実務担当者が Kubernetes の深い専門知識を必要とせずに TPU をプロビジョニングしてトレーニング ジョブを実行できるように設計されています。詳細については、XPK GitHub リポジトリをご覧ください。
  • Cloud TPU キューに格納されたリソース: キューに格納されたリソースを使用すると、利用可能になったときにプロビジョニングされる TPU 容量をリクエストできます。キューで待機できるバッチジョブやフォールト トレラント ワークロードに最適です。リクエストの時間枠を指定できます。詳細については、キューに格納されたリソースを管理するをご覧ください。

GKE と XPK を使用して v6e Cloud TPU をプロビジョニングする

v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して Cloud TPU をプロビジョニングし、モデルをトレーニングまたはサービングできます。GKE クラスタで Cloud TPU 構成を計画する方法については、GKE で Cloud TPU を計画するをご覧ください。以降のセクションでは、単一 NIC とマルチ NIC をサポートする XPK クラスタを作成するコマンドについて説明します。

単一 NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME ユーザーが割り当てた XPK クラスタの名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。
ZONE サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数。
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 作成するスライスの数。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

マルチ NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME ユーザーが割り当てた XPK クラスタの名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。
ZONE サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数。
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 使用する追加のノード ネットワーク。

例: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 作成するスライスの数(マルチスライスのみで必要)。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

JAX または PyTorch を設定する

使用するプロビジョニングと管理の方法に応じて、次のリソースで Cloud TPU に JAX または PyTorch を設定する方法を確認してください。

MaxText で XPK を設定して実行するには、XPK を使用して MaxText を大規模に実行する をご覧ください。

ネットワーク パフォーマンスを最適化する

このセクションでは、最大伝送単位(MTU)の構成、Multislice 環境でのマルチ NIC の使用、TCP 設定の改善によってネットワーク パフォーマンスを最適化する方法について説明します。

MTU を構成する

ネットワーク パフォーマンスを最大限に高めるには、8,896 MTU(最大伝送単位)のネットワークを使用します。

デフォルトでは、Virtual Private Cloud(VPC)では 1,460 バイトの MTU のみが提供され、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300~8,896 バイト(両端を含む)の任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。

既存またはデフォルトのネットワークの MTU 設定の変更について詳しくは、VPC ネットワークの MTU 設定を変更するをご覧ください。

次の例では、8,896 MTU のネットワークと、ネットワーク内の TCP、ICMP、UDP トラフィックを許可する対応するファイアウォール ルールを作成します。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

your-resource-name は、ネットワークとファイアウォールのベース名に置き換えます。

マルチスライスのマルチ NIC オプションを使用する

マルチスライス環境を使用している場合は、セカンダリ サブネットに必要な次の環境変数を設定します。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。

  1. セカンダリ ネットワークを作成します。

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. セカンダリ ネットワークのサブネットワークを作成します。

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. 新しいサブネットワーク内のトラフィックを許可するファイアウォール ルールを作成します。

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. セカンダリ ネットワークの Cloud Router を作成します。

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Cloud Router の NAT 構成を作成します。

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

マルチネットワーク スライスを作成したら、XPK クラスタを設定し、XPK ワークロード作成コマンド--command ifconfig フラグを追加して、両方のネットワーク インターフェース カード(NIC)が使用されていることを確認できます。

  1. 次の workload create コマンドを使用して、 Google Cloud コンソールログに ifconfig コマンドの出力を表示し、eth0 と eth1 の両方の MTU が 8,896 に設定されていることを確認します。

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    デバッグログを有効にする Vertex AI TensorBoard を使用する場合は、次のオプション引数をコマンドに追加します。

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Google Cloud コンソールログで XPK ワークロードの出力を確認して、eth0 と eth1 の両方の MTU が 8,896 に設定されていることを確認します。

TCP 設定を改善する

キューに格納されたリソースを使用して Cloud TPU をプロビジョニングした場合は、次のコマンドを実行して TCP 受信バッファの上限を増やし、ネットワーク パフォーマンスを改善できます。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

メモリ割り当てのパフォーマンスを最適化する

tcmalloc ライブラリは、Cloud TPU VM でデフォルトで使用され、メモリ割り当てが頻繁に発生するモデルのパフォーマンスを向上させます。これは、LD_PRELOAD 環境変数を使用して構成されます。

ただし、一部のワークロード(埋め込みテーブルの割り当てが非常に大きい DLRM など)では、tcmalloc によって速度が低下する可能性があります。このような場合は、トレーニング スクリプトを実行する前に、シェル セッションで LD_PRELOAD 変数を設定解除することで、標準の malloc 関数に戻すことができます。

unset LD_PRELOAD

SkyPilot を使用する

Cloud TPU v6e は SkyPilot で使用できます。SkyPilot は、AI ワークロードの実行、管理、スケーリングのプロセスを簡素化するオープンソース フレームワークです。v6e に関連するロケーションと料金の情報を SkyPilot に追加できます。詳細については、SkyPilot TPU v6e の例をご覧ください。

トレーニング サンプル

以降のセクションでは、Cloud TPU v6e で MaxText、MaxDiffusion、PyTorch の各モデルをトレーニングする例について説明します。

これらの例は、次のソフトウェア バージョンでテストされています。

  • Python 3.10 以降
  • ナイトリー ソフトウェア バージョン:
    • ナイトリー JAX 0.4.32.dev20240912
    • ナイトリー LibTPU 0.1.dev20240912+nightly
  • 安定版ソフトウェア バージョン:
    • JAX + JAX Lib バージョン 0.4.37

Cloud TPU v6e で MaxText と MaxDiffusion をトレーニングする

以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。

一般的な手順は次のとおりです。

  1. ワークロードのベースイメージをビルドします。
  2. XPK を使用してワークロードを実行します。
    1. ワークロードのトレーニング コマンドをビルドします。
    2. ワークロードをデプロイします。
  3. ワークロードを追跡して指標を表示します。
  4. 不要な XPK ワークロードを削除します。
  5. 不要になった XPK クラスタを削除します。

ベースイメージをビルドする

MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。

  1. 使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Google Cloud CLI を使用するように Docker を構成します。

    gcloud auth configure-docker
    
  3. 次のコマンドまたは JAX AI イメージを使用して Docker イメージをビルドします。JAX AI イメージの詳細については、JAX AI イメージをご覧ください。

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. アクティブな gcloud CLI 構成でプロジェクト ID を設定します。

    gcloud config set project ${PROJECT_ID}
    
  5. ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。

    1. CLOUD_IMAGE_NAME 環境変数を設定します。

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 画像をアップロードします。

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

XPK を使用してワークロードを実行する

  1. MaxText または MaxDiffusion によって設定されたデフォルト値を使用していない場合は、次の環境変数を設定します。

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. モデル スクリプトをビルドします。このスクリプトは、後でトレーニング コマンドとしてコピーされます。

    モデル スクリプトはまだ実行しないでください。

    MaxText

    MaxText は、高パフォーマンスでスケーラビリティに優れたオープンソースの LLM です。ピュア Python と JAX で記述され、トレーニングと推論で Google Cloud TPU および GPU をターゲットとします。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発したオープン ウェイト LLM のファミリーです。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral は、Mistral AI が開発した最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用します。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama は、Meta が開発したオープン ウェイト LLM のファミリーです。

    PyTorch で Llama3 を実行する方法の例については、torchprime リポジトリの torch_xla モデルをご覧ください。

    MaxDiffusion

    MaxDiffusion は、さまざまな潜在拡散モデルのリファレンス実装のコレクションです。ピュア Python と JAX で記述され、Cloud TPU や GPU などの XLA デバイスで実行されます。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在テキスト画像変換モデルです。

    MaxDiffusion を実行するには、次のトレーニング スクリプトに示すように、特定の Git ブランチをインストールする必要があります。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. 次の変数をエクスポートします。

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境変数の説明

    変数 説明
    CLUSTER_NAME XPK クラスタの名前。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    NUM_SLICES TPU スライスの数。
    YOUR_MODEL_SCRIPT トレーニング コマンドとして実行するモデル スクリプト。
  4. 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するために --base-docker-image フラグを指定するか、--docker-image フラグと使用するイメージを指定する必要があります。

    次のオプションのフラグを追加できます。

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    出力には、ワークロードを追跡するためのリンクが含まれます。リンクを開き、[ログ] タブをクリックしてワークロードをリアルタイムで追跡します。

MaxText で JAX をデバッグする

補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。

Vertex AI を使用して MaxText の JAX をモニタリングする

TensorBoard を使用するには、 Google Cloud ユーザー アカウントに aiplatform.user ロールが必要です。次のコマンドを実行して、このロールを付与します。

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Vertex AI マネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。

  1. 使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。これは、16 台未満の VM を使用する小規模なワークロードでは問題にならない可能性があります。

  2. Vertex AI の cloud-accelerator-diagnostics などの依存関係をインストールします。

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard を作成するで説明されているように、--create-vertex-tensorboard フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。

  4. --use-vertex-tensorboard フラグとオプションの --experiment-name フラグを使用して、XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI テストを作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。

ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vertex AI TensorBoard のリンクは、 Google Cloud コンソールでも確認できます。 Google Cloud コンソールで Vertex AI Experiments に移動します。プルダウンから適切なリージョンを選択します。

TensorBoard ディレクトリも、${BASE_OUTPUT_DIR} で指定した Cloud Storage バケットに書き込まれます。

XPK ワークロードを削除する

ジョブの接頭辞またはジョブ ステータスに基づいて 1 つ以上のワークロードを削除するには、xpk workload delete コマンドを使用します。このコマンドは、実行する必要がなくなった XPK ワークロードを送信した場合や、キュー内に停止しているジョブがある場合に便利です。

XPK クラスタを削除する

クラスタを削除するには、xpk cluster delete コマンドを使用します。

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

MaxDiffusion のベンチマーク結果

MaxDiffusion のトレーニング スクリプトを v6e-4、v6e-16、2 つの v6e-16 で実行しました。次の表に、測定されたスループットを示します。

v6e-4 v6e-16 2 つの v6e-16
トレーニング ステップ数 0.069 0.073 0.13
グローバル バッチサイズ 8 32 64
スループット(例/秒) 115.9 438.4 492.3

Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする

このセクションでは、WikiText データセットを使用して、Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。

Hugging Face および Llama 3 モデルにアクセスする

この例では、Hugging Face ユーザー アクセス トークンが必要です。ユーザー アクセス トークンの作成については、Hugging Face のユーザー アクセス トークンに関するドキュメントをご覧ください。

また、Hugging Face の Llama-3-8B モデルにアクセスする権限も必要です。アクセス権を取得するには、HuggingFace の Meta-Llama-3-8B モデルにアクセスしてアクセス権をリクエストしてください。

Cloud TPU VM を作成する

この例では、8 チップを搭載した Cloud TPU v6e を作成します。

  1. 環境変数を設定します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. Cloud TPU VM を作成します。

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

インストール

Hugging Face Transformers と依存関係の pytorch-tpu/transformers フォークをインストールします。この例は、次の依存関係のバージョンでテストされています。

  • torch: 2.5.0 との互換性あり
  • torch_xla[tpu]: 2.5.0 との互換性あり
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

モデル構成ファイルを設定する

次のセクション(モデルを実行する)のトレーニング コマンドでは、2 つの JSON 構成ファイルを使用して、モデル パラメータと Fully Sharded Data Parallel(FSDP)構成を定義します。FSDP シャーディングを使用すると、モデルの重みを複数の TPU にシャーディングすることで、トレーニング中に大きなバッチサイズを使用できます。小さなモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な可能性があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. モデル パラメータ構成ファイルを作成します。Llama-3-8B のモデル パラメータ構成は次のとおりです。他のモデルについては、Hugging Face で構成ファイルを確認してください。たとえば、Llama-2-7B 構成をご覧ください。

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. FSDP 構成ファイルを作成します。

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    FSDP の詳細については、SPMD を使用した完全なシャード データ並列処理 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを Cloud TPU VM にアップロードします。

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

モデルを実行する

前のセクションで作成した構成ファイルを使用して run_clm.py スクリプトを実行し、WikiText データセットで Llama-3-8B モデルをトレーニングします。Cloud TPU v6e-8 でトレーニング スクリプトを実行すると、約 10 分かかります。

  1. 次のコマンドを使用して、Cloud TPU で Hugging Face にログインします。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. モデル トレーニングを実行します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

PyTorch/XLA のトラブルシューティングを行う

前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは PROFILE_LOGDIR 変数で指定された場所に保存されます。この場所に保存されている xplane.pb ファイルを抽出し、tensorboard を使用して、TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。

PyTorch/XLA が想定どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。