TPU v6e を使用してモデルをトレーニングする
このドキュメントでは、Cloud TPU v6e(Trillium とも呼ばれます)でモデルをトレーニングする方法について説明します。環境設定、パフォーマンスの最適化、JAX と PyTorch/XLA を使用した実践的なトレーニングの例について説明します。
TPU v6e(Trillium とも呼ばれます)は、Google の第 6 世代の TPU です。API やログなどのすべての技術的な側面から、Trillium はこのドキュメント全体で v6e と呼んでいます。Pod あたり 256 個のチップを備えた TPU v6e のアーキテクチャは、v5e と多くの類似点があります。TPU v6e は、トランスフォーマー、テキスト画像変換、畳み込みニューラル ネットワーク(CNN)のトレーニング、ファインチューニング、サービングに最適化されています。TPU v6e システム アーキテクチャと構成の詳細については、TPU v6e をご覧ください。
Cloud TPU v6e で推論を実行する方法については、次のチュートリアルをご覧ください。
- v6e での JetStream MaxText 推論
- v6e での JetStream PyTorch 推論
- v6e での MaxDiffusion 推論
- v6e での vLLM 推論
- Pathways を使用してマルチホスト推論を実行する
始める前に
始める前に、次のことを行う必要があります。
- 課金を有効にした Google Cloud アカウントおよびプロジェクトを作成する
- Google Cloud CLI アルファ版コンポーネントをインストールする
- Cloud TPU API を有効にする
- Cloud TPU サービス エージェントを作成する
- Cloud TPU サービス アカウントを作成して権限を付与する
詳細については、Cloud TPU 環境を設定するをご覧ください。
割り当てと権限を確認する
プロジェクトに次の割り当てがあることを確認します。
GKE と XPK を使用している場合は、 Google Cloud コンソールで追加の権限が必要です。詳細については、Google Cloud コンソールに必要な権限 をご覧ください。
TPU をプロビジョニングする
TPU v6e は、次の方法でプロビジョニングして管理できます。
- GKE: GKE を使用して、コンテナ化された ML ワークロードのアクセラレータ プールとして TPU をプロビジョニングして管理できます。詳細については、GKE の TPU についてをご覧ください。
- GKE と XPK: XPK は、GKE でのクラスタの作成とワークロードの実行を簡素化するコマンドライン ツールです。これは、ML の実務担当者が Kubernetes の深い専門知識を必要とせずに TPU をプロビジョニングしてトレーニング ジョブを実行できるように設計されています。詳細については、XPK GitHub リポジトリをご覧ください。
- Cloud TPU キューに格納されたリソース: キューに格納されたリソースを使用すると、利用可能になったときにプロビジョニングされる TPU 容量をリクエストできます。キューで待機できるバッチジョブやフォールト トレラント ワークロードに最適です。リクエストの時間枠を指定できます。詳細については、キューに格納されたリソースを管理するをご覧ください。
GKE と XPK を使用して v6e Cloud TPU をプロビジョニングする
v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して Cloud TPU をプロビジョニングし、モデルをトレーニングまたはサービングできます。GKE クラスタで Cloud TPU 構成を計画する方法については、GKE で Cloud TPU を計画するをご覧ください。以降のセクションでは、単一 NIC とマルチ NIC をサポートする XPK クラスタを作成するコマンドについて説明します。
単一 NIC をサポートする XPK クラスタを作成する
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
コマンドフラグの説明
変数 | 説明 |
CLUSTER_NAME | ユーザーが割り当てた XPK クラスタの名前。 |
PROJECT_ID | Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。 |
ZONE | サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。 |
TPU_TYPE | アクセラレータ タイプをご覧ください。 |
NUM_SLICES | 作成するスライスの数。 |
CLUSTER_ARGUMENTS | 使用するネットワークとサブネットワーク。
例: |
NUM_SLICES | 作成するスライスの数。 |
NETWORK_NAME | 使用するセカンダリ ネットワークの名前。 |
NETWORK_FW_NAME | 使用するセカンダリ ネットワーク ファイアウォールの名前。 |
マルチ NIC をサポートする XPK クラスタを作成する
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
コマンドフラグの説明
変数 | 説明 |
CLUSTER_NAME | ユーザーが割り当てた XPK クラスタの名前。 |
PROJECT_ID | Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。 |
ZONE | サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。 |
TPU_TYPE | アクセラレータ タイプをご覧ください。 |
NUM_SLICES | 作成するスライスの数。 |
CLUSTER_ARGUMENTS | 使用するネットワークとサブネットワーク。
例: |
NODE_POOL_ARGUMENTS | 使用する追加のノード ネットワーク。
例: |
NUM_SLICES | 作成するスライスの数(マルチスライスのみで必要)。 |
NETWORK_NAME | 使用するセカンダリ ネットワークの名前。 |
NETWORK_FW_NAME | 使用するセカンダリ ネットワーク ファイアウォールの名前。 |
JAX または PyTorch を設定する
使用するプロビジョニングと管理の方法に応じて、次のリソースで Cloud TPU に JAX または PyTorch を設定する方法を確認してください。
- GKE Autopilot: TPU アプリケーションを準備する
- GKE Standard: ワークロードを準備する
- GKE と XPK: XPK README
- JAX を使用したシングルホスト Cloud TPU: JAX を使用して Cloud TPU VM で計算を実行する
- JAX を使用したマルチホスト Cloud TPU: TPU スライスで JAX コードを実行する
- PyTorch を使用したシングルホスト Cloud TPU: PyTorch を使用して Cloud TPU VM で計算を実行する
- PyTorch を使用したマルチホスト Cloud TPU: TPU スライスで PyTorch コードを実行する
MaxText で XPK を設定して実行するには、XPK を使用して MaxText を大規模に実行する をご覧ください。
ネットワーク パフォーマンスを最適化する
このセクションでは、最大伝送単位(MTU)の構成、Multislice 環境でのマルチ NIC の使用、TCP 設定の改善によってネットワーク パフォーマンスを最適化する方法について説明します。
MTU を構成する
ネットワーク パフォーマンスを最大限に高めるには、8,896 MTU(最大伝送単位)のネットワークを使用します。
デフォルトでは、Virtual Private Cloud(VPC)では 1,460 バイトの MTU のみが提供され、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300~8,896 バイト(両端を含む)の任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。
既存またはデフォルトのネットワークの MTU 設定の変更について詳しくは、VPC ネットワークの MTU 設定を変更するをご覧ください。
次の例では、8,896 MTU のネットワークと、ネットワーク内の TCP、ICMP、UDP トラフィックを許可する対応するファイアウォール ルールを作成します。
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
your-resource-name は、ネットワークとファイアウォールのベース名に置き換えます。
マルチスライスのマルチ NIC オプションを使用する
マルチスライス環境を使用している場合は、セカンダリ サブネットに必要な次の環境変数を設定します。
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。
セカンダリ ネットワークを作成します。
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \ --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
セカンダリ ネットワークのサブネットワークを作成します。
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 --region=${REGION} \ --project=${PROJECT_ID}
新しいサブネットワーク内のトラフィックを許可するファイアウォール ルールを作成します。
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \ --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
セカンダリ ネットワークの Cloud Router を作成します。
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
Cloud Router の NAT 構成を作成します。
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
マルチネットワーク スライスを作成したら、XPK クラスタを設定し、XPK ワークロード作成コマンドに --command ifconfig
フラグを追加して、両方のネットワーク インターフェース カード(NIC)が使用されていることを確認できます。
次の
workload create
コマンドを使用して、 Google Cloud コンソールログにifconfig
コマンドの出力を表示し、eth0 と eth1 の両方の MTU が 8,896 に設定されていることを確認します。python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
デバッグログを有効にするか Vertex AI TensorBoard を使用する場合は、次のオプション引数をコマンドに追加します。
--enable-debug-logs \ --use-vertex-tensorboard
Google Cloud コンソールログで XPK ワークロードの出力を確認して、eth0 と eth1 の両方の MTU が 8,896 に設定されていることを確認します。
TCP 設定を改善する
キューに格納されたリソースを使用して Cloud TPU をプロビジョニングした場合は、次のコマンドを実行して TCP 受信バッファの上限を増やし、ネットワーク パフォーマンスを改善できます。
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
メモリ割り当てのパフォーマンスを最適化する
tcmalloc
ライブラリは、Cloud TPU VM でデフォルトで使用され、メモリ割り当てが頻繁に発生するモデルのパフォーマンスを向上させます。これは、LD_PRELOAD
環境変数を使用して構成されます。
ただし、一部のワークロード(埋め込みテーブルの割り当てが非常に大きい DLRM など)では、tcmalloc
によって速度が低下する可能性があります。このような場合は、トレーニング スクリプトを実行する前に、シェル セッションで LD_PRELOAD
変数を設定解除することで、標準の malloc
関数に戻すことができます。
unset LD_PRELOAD
SkyPilot を使用する
Cloud TPU v6e は SkyPilot で使用できます。SkyPilot は、AI ワークロードの実行、管理、スケーリングのプロセスを簡素化するオープンソース フレームワークです。v6e に関連するロケーションと料金の情報を SkyPilot に追加できます。詳細については、SkyPilot TPU v6e の例をご覧ください。
トレーニング サンプル
以降のセクションでは、Cloud TPU v6e で MaxText、MaxDiffusion、PyTorch の各モデルをトレーニングする例について説明します。
これらの例は、次のソフトウェア バージョンでテストされています。
- Python
3.10
以降 - ナイトリー ソフトウェア バージョン:
- ナイトリー JAX
0.4.32.dev20240912
- ナイトリー LibTPU
0.1.dev20240912+nightly
- ナイトリー JAX
- 安定版ソフトウェア バージョン:
- JAX + JAX Lib バージョン 0.4.37
Cloud TPU v6e で MaxText と MaxDiffusion をトレーニングする
以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。
一般的な手順は次のとおりです。
- ワークロードのベースイメージをビルドします。
- XPK を使用してワークロードを実行します。
- ワークロードのトレーニング コマンドをビルドします。
- ワークロードをデプロイします。
- ワークロードを追跡して指標を表示します。
- 不要な XPK ワークロードを削除します。
- 不要になった XPK クラスタを削除します。
ベースイメージをビルドする
MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。
使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Google Cloud CLI を使用するように Docker を構成します。
gcloud auth configure-docker
次のコマンドまたは JAX AI イメージを使用して Docker イメージをビルドします。JAX AI イメージの詳細については、JAX AI イメージをご覧ください。
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
アクティブな gcloud CLI 構成でプロジェクト ID を設定します。
gcloud config set project ${PROJECT_ID}
ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。
CLOUD_IMAGE_NAME
環境変数を設定します。export CLOUD_IMAGE_NAME=${USER}_runner
画像をアップロードします。
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
XPK を使用してワークロードを実行する
MaxText または MaxDiffusion によって設定されたデフォルト値を使用していない場合は、次の環境変数を設定します。
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
モデル スクリプトをビルドします。このスクリプトは、後でトレーニング コマンドとしてコピーされます。
モデル スクリプトはまだ実行しないでください。
MaxText
MaxText は、高パフォーマンスでスケーラビリティに優れたオープンソースの LLM です。ピュア Python と JAX で記述され、トレーニングと推論で Google Cloud TPU および GPU をターゲットとします。
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発したオープン ウェイト LLM のファミリーです。
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral は、Mistral AI が開発した最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用します。
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama は、Meta が開発したオープン ウェイト LLM のファミリーです。
PyTorch で Llama3 を実行する方法の例については、torchprime リポジトリの torch_xla モデルをご覧ください。
MaxDiffusion
MaxDiffusion は、さまざまな潜在拡散モデルのリファレンス実装のコレクションです。ピュア Python と JAX で記述され、Cloud TPU や GPU などの XLA デバイスで実行されます。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在テキスト画像変換モデルです。
MaxDiffusion を実行するには、次のトレーニング スクリプトに示すように、特定の Git ブランチをインストールする必要があります。
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash \ run_name=sdxl-ddp-v6e
次の変数をエクスポートします。
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
環境変数の説明
変数 説明 CLUSTER_NAME
XPK クラスタの名前。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 NUM_SLICES
TPU スライスの数。 YOUR_MODEL_SCRIPT
トレーニング コマンドとして実行するモデル スクリプト。 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するために
--base-docker-image
フラグを指定するか、--docker-image
フラグと使用するイメージを指定する必要があります。次のオプションのフラグを追加できます。
--enable-debug-logs
フラグを含めると、デバッグ ロギングを有効にできます。詳細については、MaxText で JAX をデバッグするをご覧ください。--use-vertex-tensorboard
フラグを含めると、Vertex AI Experiments を作成し、Vertex AI TensorBoard にデータをアップロードできます。詳細については、Vertex AI を使用して MaxText で JAX をモニタリングするをご覧ください。
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command="${YOUR_MODEL_SCRIPT}"
出力には、ワークロードを追跡するためのリンクが含まれます。リンクを開き、[ログ] タブをクリックしてワークロードをリアルタイムで追跡します。
MaxText で JAX をデバッグする
補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。
- XPK ワークロード リスト
- XPK インスペクタ
- XPK ワークロードを作成するときに、
--enable-debug-logs
フラグを使用してワークロード ログの詳細ログを有効にします。
Vertex AI を使用して MaxText の JAX をモニタリングする
TensorBoard を使用するには、 Google Cloud ユーザー アカウントに aiplatform.user
ロールが必要です。次のコマンドを実行して、このロールを付与します。
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Vertex AI マネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。
使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。これは、16 台未満の VM を使用する小規模なワークロードでは問題にならない可能性があります。
Vertex AI の
cloud-accelerator-diagnostics
などの依存関係をインストールします。# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Vertex AI TensorBoard を作成するで説明されているように、
--create-vertex-tensorboard
フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。--use-vertex-tensorboard
フラグとオプションの--experiment-name
フラグを使用して、XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI テストを作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。
ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Vertex AI TensorBoard のリンクは、 Google Cloud コンソールでも確認できます。 Google Cloud コンソールで Vertex AI Experiments に移動します。プルダウンから適切なリージョンを選択します。
TensorBoard ディレクトリも、${BASE_OUTPUT_DIR}
で指定した Cloud Storage バケットに書き込まれます。
XPK ワークロードを削除する
ジョブの接頭辞またはジョブ ステータスに基づいて 1 つ以上のワークロードを削除するには、xpk workload delete
コマンドを使用します。このコマンドは、実行する必要がなくなった XPK ワークロードを送信した場合や、キュー内に停止しているジョブがある場合に便利です。
XPK クラスタを削除する
クラスタを削除するには、xpk cluster delete
コマンドを使用します。
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
MaxDiffusion のベンチマーク結果
MaxDiffusion のトレーニング スクリプトを v6e-4、v6e-16、2 つの v6e-16 で実行しました。次の表に、測定されたスループットを示します。
v6e-4 | v6e-16 | 2 つの v6e-16 | |
---|---|---|---|
トレーニング ステップ数 | 0.069 | 0.073 | 0.13 |
グローバル バッチサイズ | 8 | 32 | 64 |
スループット(例/秒) | 115.9 | 438.4 | 492.3 |
Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする
このセクションでは、WikiText データセットを使用して、Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。
Hugging Face および Llama 3 モデルにアクセスする
この例では、Hugging Face ユーザー アクセス トークンが必要です。ユーザー アクセス トークンの作成については、Hugging Face のユーザー アクセス トークンに関するドキュメントをご覧ください。
また、Hugging Face の Llama-3-8B モデルにアクセスする権限も必要です。アクセス権を取得するには、HuggingFace の Meta-Llama-3-8B モデルにアクセスしてアクセス権をリクエストしてください。
Cloud TPU VM を作成する
この例では、8 チップを搭載した Cloud TPU v6e を作成します。
環境変数を設定します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east1-d export ACCELERATOR_TYPE=v6e-8 export RUNTIME_VERSION=v2-alpha-tpuv6e
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 Cloud TPU VM を作成します。
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
インストール
Hugging Face Transformers と依存関係の pytorch-tpu/transformers
フォークをインストールします。この例は、次の依存関係のバージョンでテストされています。
torch
: 2.5.0 との互換性ありtorch_xla[tpu]
: 2.5.0 との互換性ありjax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
モデル構成ファイルを設定する
次のセクション(モデルを実行する)のトレーニング コマンドでは、2 つの JSON 構成ファイルを使用して、モデル パラメータと Fully Sharded Data Parallel(FSDP)構成を定義します。FSDP シャーディングを使用すると、モデルの重みを複数の TPU にシャーディングすることで、トレーニング中に大きなバッチサイズを使用できます。小さなモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な可能性があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。
モデル パラメータ構成ファイルを作成します。Llama-3-8B のモデル パラメータ構成は次のとおりです。他のモデルについては、Hugging Face で構成ファイルを確認してください。たとえば、Llama-2-7B 構成をご覧ください。
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
FSDP 構成ファイルを作成します。
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
FSDP の詳細については、SPMD を使用した完全なシャード データ並列処理 をご覧ください。
次のコマンドを使用して、構成ファイルを Cloud TPU VM にアップロードします。
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
モデルを実行する
前のセクションで作成した構成ファイルを使用して run_clm.py
スクリプトを実行し、WikiText データセットで Llama-3-8B モデルをトレーニングします。Cloud TPU v6e-8 でトレーニング スクリプトを実行すると、約 10 分かかります。
次のコマンドを使用して、Cloud TPU で Hugging Face にログインします。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
モデル トレーニングを実行します。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
PyTorch/XLA のトラブルシューティングを行う
前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは PROFILE_LOGDIR
変数で指定された場所に保存されます。この場所に保存されている xplane.pb
ファイルを抽出し、tensorboard
を使用して、TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。
PyTorch/XLA が想定どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。