Cloud TPU v5p トレーニング
Cloud TPU v5p は、Google Cloud の第 5 世代 Cloud TPU であり、v4 TPU の後継です。v5p は大規模なトレーニング用に最適化されており、基盤となる LLM、拡散モデル、生成 AI を開発するための主要なプラットフォームです。大まかに言うと、v5p は v4 の最大 2 倍の性能を備えながら、Pod に 2 倍の TPU を詰め込み(最大スライスは v4 の 3k に対して 6k)、Pod レベルで最大 4 倍の性能を実現します。また、高いクロック周波数(1.05 Ghz に対して 1.75 Ghz)で動作し、大規模な埋め込み用の SparseCore が追加され、高帯域幅メモリ(HBM)容量を 3 倍に増やしています。
Cloud TPU v5p のコンセプト
Cloud TPU を初めて使用する場合は、TPU のドキュメント ホームをご覧ください。
Cloud TPU のコンセプト(スライス、ホスト、TensorCore など)と、すべての Cloud TPU バージョンに対する Cloud TPU システム アーキテクチャについては、Cloud TPU システム アーキテクチャ ページをご覧ください。
各 Cloud TPU のバージョンでは、トレーニングまたは推論のために特定のアクセラレータ タイプが必要です。これらのアクセラレータ タイプについては、v5p 構成をご覧ください。
TPU リソースを管理する
TPU VM の管理に使用できるすべてのコマンドについては、TPU の管理またはキューに入れられたリソースの管理に関するキューに入れられたリソース ユーザーガイドをご覧ください。
フレームワークの設定
このセクションでは、TPU v5p で JAX または PyTorch を使用したモデル トレーニングの一般的な設定プロセスについて説明します。
JAX を設定する
スライス形状が 4 チップを超える場合、1 つのスライスに複数の VM があります。この場合、--worker=all
フラグを使用して、1 つのコマンドですべての TPU VM にインストールを実行する必要があります。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
次のコマンドを実行して、デバイスの数を確認できます(ここに示す出力は v5p-32 スライスで生成されています)。このコードは、JAX で Cloud TPU TensorCore が表示され、基本オペレーションを実行できることを確認することによって、すべてが正しくインストールされていることをテストします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
出力は次のようになります。
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... 16 4 16 4 16 4 16 4
jax.device_count()
は、指定されたスライス内のチップの合計数を示します。jax.local_device_count()
は、このスライス内の単一の VM からアクセス可能なチップの数を示します。
# Check the number of chips in the given slice by summing the count of chips # from all VMs through the # jax.local_device_count() API call. gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
出力は次のようになります。
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.]
--node=all
を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \ --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
このドキュメントの JAX チュートリアルを試して、JAX を使用して v5p トレーニングを始めましょう。
PyTorch を設定する
PJRT ランタイムは v5p でサポートされている唯一のランタイムで、PyTorch 2.1+ ではすべての TPU バージョンのデフォルト ランタイムとして PJRT が使用されます。このセクションでは、すべてのワーカーに PyTorch/XLA 2.2.0 を使用して v5p Pod で PJRT を使用する方法について説明します。
依存関係をインストールする
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' sudo apt-get update sudo apt-get install libopenblas-dev -y pip3 install numpy pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html '
PJRT で Python スクリプトを使用してインストールの検証を行い、使用可能な TPU デバイスを表示します(ここに示す出力は v5p-32 スライスで生成されています)。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} --zone ${ZONE} --worker=all \ --command=' PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))" '
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3']
--node=all
を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \ --command=' PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))" '
このドキュメントの PyTorch チュートリアルを試して、PyTorch を使用した v5p トレーニングを開始します。
モニタリングとプロファイル
Cloud TPU v5p では、前の世代の Cloud TPU と同じ方法を使用したモニタリングとプロファイリングがサポートされています。プロファイリングの詳細については、Cloud TPU ツールでモデルをプロファイリングするをご覧ください。モニタリングの詳細については、Cloud TPU VM のモニタリングをご覧ください。
トレーニング チュートリアル
このセクションでは、シングルスライスのトレーニング チュートリアルに重点を置いて説明します。これらのチュートリアルをマルチスライス トレーニングに適応させるには、SSH コマンドに --node=all
フラグを追加します。詳細とベスト プラクティスについては、マルチスライスの概要をご覧ください。
JAX チュートリアル
Diffusion 2.1 をトレーニングする
このチュートリアルでは、Cloud TPU v5p で Pokémon データセットを使用して、HuggingFace から Stable Diffusion モデルをトレーニングする方法について説明します。
Stable Diffusion モデルは、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。詳しくは、次のリソースをご覧ください。
設定
環境変数を作成します。
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5p-32 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
コマンドフラグの説明
変数 説明 PROJECT_ID Google Cloud プロジェクト名 ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。 ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。 RUNTIME_VERSION v5p の場合、RUNTIME_VERSION に v2-alpha-tpuv5 を使用します。 SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM] -> サービス アカウント で確認できるサービス アカウントのアドレスです。 例: tpu-service-account@myprojectID。iam.gserviceaccount.com TPU_NAME キューに格納されたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。 QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当て ID。キューに格納されたリソースについては、キューに格納されたリソースのドキュメントをご覧ください。 QUOTA_TYPE reserved
、spot
のいずれかを設定できます。どちらも指定されていない場合、QUOTA_TYPE はデフォルトでon-demand
になります。Cloud TPU でサポートされている割り当てのさまざまなタイプについては、割り当てをご覧ください。VALID_UNTIL_DURATION リクエストが有効である期間。さまざまな有効期間の詳細については、キューに入れられたリソースをご覧ください。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_UNTIL_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
キューに格納されたリソースが
ACTIVE
状態になると、TPU VM に SSH 接続できるようになります。 次のコマンドを実行して、キューに入れられたリソースの状態を確認します。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
キューに入れられたリソースが
ACTIVE
状態の場合、出力は次のようになります。state: ACTIVE
JAX とその依存関係をインストールします。
# compatible with v5p: only jax version 0.4.19 and later \ # jax 0.4.19 requires py 3.10 \ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} --zone=${ZONE} --worker=all \ --command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
HuggingFace のリポジトリをダウンロードし、要件をインストールします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/diffusers.git && cd diffusers && pip install . && pip install tensorflow clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
モデルのトレーニング
事前にマッピングされたバッファ(4GB)を使用してモデルをトレーニングします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='export PATH=$PATH:$HOME/.local/bin && cd diffusers/examples/text_to_image && JAX_PLATFORMS=tpu,cpu python3 train_text_to_image_flax.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1 --dataset_name=lambdalabs/pokemon-blip-captions --resolution=256 --center_crop --random_flip --train_batch_size=1 --mixed_precision=bf16 --max_train_steps=150 --learning_rate=1e-05 --max_grad_norm=1 --output_dir=sd-pokemon-model --from_pt'
クリーンアップ
セッションの終了時に TPU とキューに入れられたリソース リクエストを削除するか、「FAILED」状態のキューに入れられたリソース リクエストを削除します。キューに入れられたリソースを削除するには、2 ステップで、スライスを削除した後、キューに入れられたリソース リクエストを削除します。
gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID} --zone=${ZONE} --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} --project ${PROJECT_ID} --zone ${ZONE} --quiet
または、--force
を使用して、1 ステップでスライスとキューに入れられたリソース リクエストを削除します。
# With --force gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
ベンチマークの結果
Stable Diffusion のトレーニング スクリプトは、v5p-8、v5p-32、v5p-128 で実行されました。次の表では、スループットを示します。
v5p-8 |
v5p-32 |
v5p-128 |
|
---|---|---|---|
トレーニング ステップ |
150 |
150 |
150 |
グローバル バッチサイズ |
32 |
64 |
64 |
スループット(例/秒) |
12.10 |
18.08 |
19.10 |
MaxText
このチュートリアルでは、Cloud TPU で合成データセットを使用して MaxText モデルをトレーニングする方法について説明します。
MaxText は、Cloud TPU をターゲットとする純粋な Python/JAX で記述された、高性能で任意に拡張可能なオープンソース LLM です。MaxText は、自然言語処理(NLP)の研究開発の新境地を開拓するための、アクセスしやすく適応性の高いツールを使用して研究者と開発者を支援します。
このチュートリアルを実行する前に、Cloud TPU 環境を設定する必要があります。
環境変数を設定する
export PROJECT_ID=your_project_ID export TPU_NAME=your_tpu_name # user defined TPU name export ACCELERATOR_TYPE=v5p-256 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export RUN_NAME=your_experiment_run_name # user defined name for this run export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs:// export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
コマンドフラグの説明
変数 説明 PROJECT_ID Google Cloud プロジェクト名 TPU_NAME TPU のユーザー定義の名前。 ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。 ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。 RUNTIME_VERSION v5p の場合は、ランタイム バージョンに v2-alpha-tpuv5 を使用します。 RUN_NAME ユーザーが指定したテスト実行名。 マルチスライスに推奨されるオプションの設定:
export NETWORK_NAME=your_network_name export FIREWALL_RULE_NAME=your_firewall_rule_name
マルチスライス ワークロードを実行していて、最適なネットワーク パフォーマンスが必要な場合は、最大伝送単位(MTU)が 8,896 バイトの専用ネットワークを作成して、適切なファイアウォール ルールを構成することを検討してください。このステップは省略できますが、特にデータセンター ネットワーク(DCN)でスライス数をスケールアップする場合に、パフォーマンスを大幅に改善できます。なお、ネットワークを作成するには、プロジェクトに
compute.networks.create
権限が必要です。次の例では、専用ネットワークとファイアウォール ルールを作成する方法を示します。専用ネットワークを作成します。
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
ファイアウォール ルールを作成します。
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
MaxText リポジトリのクローンを作成します
git clone https://github.com/google/maxtext.git
モデルのトレーニング
以降のセクションでは、MaxText をトレーニングするための 2 つのオプションについて説明します。
オプション 1
Cloud TPU のプロビジョニングや依存関係のインストールから、モデルの実行とリソースの破棄まで、ワークフロー全体を管理するスクリプトが必要な場合は、
multihost_job.py
を使用できます。cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \ --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \ --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} #user defined run name \ --BUCKET_NAME=${GCS_BUCKET_NAME} \ #used to store logs and configs --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
スクリプトを開始すると、ログに次のようなメッセージが表示されます。ログの場所は出力メッセージで参照されます。TPU のプロビジョニングが完了したら、最初のリンクをクリックしてすべてのワーカーのログにアクセスします。
------------------------------------ multihost_job finished running, TPUs are starting up to run your job remotely. Logs for your job are displayed here: https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22
_log%22%2529;?project=PROJECT_ID To see the output of a single host, you may edit the slice and worker number in the `log_file_path` property here: https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID When your job is finished, the main command log is in your Cloud Storage bucket: https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID View the status of the created TPUs using: gcloud compute tpus queued-resources list --filter=RUN_NAME --zone=ZONE --project=PROJECT_ID
オプション 2
プロビジョニングされた Cloud TPU でトレーニング スクリプトを複数回実行するには、multihost_runner.py
スクリプトを使ってリソースを使用します。
変数を設定して TPU を作成します。
export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export VALID_DURATION=1d export QUOTA_TYPE=quota_type
--node-count ${NODE_COUNT} \ --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
TPU リソースを作成します。
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
QueuedResource
がACTIVE
状態になると、SSH を使用して TPU VM に接続できるようになります。describe
コマンドを使用して、キューに入れられたリソースのステータスを確認します。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} --project ${PROJECT_ID} --zone ${ZONE}
キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。
state: ACTIVE
SSH を使用して TPU に接続します
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
依存関係をインストールする
export TPU_NAME=your_tpu_name export MAXTEXT_OUTPUT_PATH=output-path
cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \ --COMMAND='bash setup.sh'
32b.sh、64b.sh などのさまざまな構成スクリプトを使用してモデルを実行します。TPU VM からスクリプトを実行する場合は、
--INTERNAL_IP=true
フラグを追加する必要があります。python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \ --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
クリーンアップ
ベンチマークの結果
MaxText トレーニング スクリプトは、32 ~ 1,160 バイトを bf16 適合率で実行しました。これらの実行結果を次の表に示します。
パラメータの数 |
アクセラレータ タイプ |
TFLOP/チップ/秒 |
モデルの FLOPS 使用率 (MFU) |
---|---|---|---|
32B |
v5p-128 |
3.28E+02 |
71.47% |
64B |
v5p-128 |
3.23E+02 |
70.31% |
128B |
v5p-256 |
3.15E+02 |
68.68% |
128B |
v5p-512 |
3.15E+02 |
68.53% |
256B |
v5p-1024 |
3.16E+02 |
68.82% |
512B |
v5p-1024 |
2.94E+02 |
63.99% |
1024B |
v5p-2048 |
2.49E+02 |
64.05% |
1024B |
v5p-4096 |
2.97E+02 |
64.80% |
1160B |
v5p-7680 |
2.95E+02 |
64.27% |
1160B |
v5p-12288 |
3.04E+02 |
66.23% |
256B パラメータ モデルは、bf16 と int8 の両方の重み付けを使用して、v5p-512 と v5p-1024 でテストされています。次の表では、これらのテスト結果を示します。
v5p-512 |
v5p-512 |
v5p-1024 |
v5p-1024 |
|
---|---|---|---|---|
グローバル バッチサイズ (トークン) |
5.24E+05 |
5.24E+05 |
1.05E+06 |
1.05E+06 |
適合率 |
bf16 |
int8 |
bf16 |
int8 |
TFLOP/チップ/秒 |
307 |
408 |
308 |
414 |
モデルの FLOPS 使用率 (MFU) |
66.98% |
88.85% |
67.09% |
90.23% |
TensorFlow チュートリアル
単一ホスト v5p で ResNet をトレーニングする
このチュートリアルでは、架空のデータセットを使用して v5p-8
TPU で ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。
設定
環境変数を作成します。
export PROJECT_ID=your-project-ID export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east1-c export RUNTIME_VERSION=tpu-vm-tf-2.17.0-pjrt export TPU_NAME=your-tpu-name export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
このチュートリアルでは、
ACCELERATOR_TYPE
としてv5p-8
を使用します。-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --${QUOTA_TYPE}
キューに入れられたリソースが
ACTIVE
状態になると、SSH を使用して TPU VM に接続できるようになります。キューに格納されたリソースの状態を確認するには、次のコマンドを使用します。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
SSH を使用して TPU に接続します
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
いくつかの環境変数を設定します
export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export NEXT_PLUGGABLE_DEVICE_USE_C_API=true export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
モデル リポジトリのディレクトリに移動し、要件をインストールします。
cd ${MODELS_REPO} && git checkout r2.15.0 pip install -r official/requirements.txt
モデルのトレーニング
トレーニング スクリプトを実行します。
python3 official/vision/train.py \ --tpu=local \ --experiment=resnet_imagenet \ --mode=train_and_eval \ --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \ --model_dir=${MODEL_DIR} \ --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
クリーンアップ
マルチホスト v5p で ResNet をトレーニングする
このチュートリアルでは、架空のデータセットを使用して v5p-16
以上での ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。
環境変数を作成します。
export PROJECT_ID=your_project_ID export TPU_NAME=your_tpu_name export ZONE=us-east1-c export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=tpu-vm-tf-2.17.0-pod-pjrt export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE
はv5p-16
かそれ以上にすることができます。-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --${QUOTA_TYPE}
キューに入れられたリソースが
ACTIVE
状態になると、SSH を使用して TPU VM に接続できるようになります。describe
コマンドを使用して、キューに入れられたリソースのステータスをクエリします。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
SSH を使用して TPU(ワーカーゼロ)に接続します
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
いくつかの環境変数を設定します
export TPU_NAME=your_tpu_name export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export TPU_LOAD_LIBRARY=0
モデル リポジトリのディレクトリに移動し、要件をインストールします。
cd $MODELS_REPO && git checkout r2.15.0 pip install -r official/requirements.txt
モデルのトレーニング
トレーニング スクリプトを実行します。
python3 official/vision/train.py \ --tpu=${TPU_NAME} \ --experiment=resnet_imagenet \ --mode=train_and_eval \ --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \ --model_dir=${MODEL_DIR} \ --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
クリーンアップ
PyTorch/XLA
Llama 2
このチュートリアルでは、ML 計算グラフ(GSPMD)の一般および拡張可能な並列化により、PyTorch/XLA で HuggingFace のリポジトリのフォークを使用して、v5p で Llama 2 7B モデルをトレーニングする方法について説明します。
設定
プロジェクト ID、アクセラレータ タイプ、ゾーン、ランタイム バージョン、TPU 名の変数を作成します。
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export QUOTA_TYPE=quota_type export VALID_DURATION=1d
TPU リソースの作成
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
QueuedResource
がACTIVE
状態になると、SSH を使用して TPU VM に接続できるようになります。describe
コマンドを使用して、キューに入れられたリソースのステータスを確認します。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。
state: ACTIVE
Pytorch/XLA と必要な依存関係をインストールします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' sudo apt-get update sudo apt-get install libopenblas-dev -y pip3 install numpy pip3 install typing-extensions pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html '
HuggingFace のリポジトリをダウンロードし、要件をインストールします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install git+file://$PWD pip3 install datasets accelerate evaluate scikit-learn'
7B モデル構成をダウンロードします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
モデルのトレーニング
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_BF16=1 export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_jf_spmd_threshold_for_windowed_einsum_mib=0" export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=20000 export PROFILE_LOGDIR=/tmp/home/ cd transformers python examples/pytorch/language-modeling/run_clm.py \ --tokenizer_name hf-internal-testing/llama-tokenizer \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 96 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --do_train \ --output_dir /tmp/output \ --overwrite_output_dir \ --config_name ~/config.json \ --save_strategy no \ --logging_strategy no \ --remove_unused_columns no \ --optim adafactor \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --block_size 2048 \ --spmd_2d_sharding 1 \ --spmd_grad_chkpt '
マルチスライス環境で実行している場合は、フラグ --spmd_dcn_parallelism
をスライス数に設定する必要があります。
SPMD_USER_GUIDE には、HF スクリプトのさまざまな環境変数とトグルを説明する詳細なユーザーガイドが用意されています。LIBTPU_INIT_ARGS は PyTorch/XLA に組み込まれ、今後のリリースではデフォルトで有効になる予定です。
クリーンアップ
ベンチマークの結果
次の表では、3 つの Llama 2 モデルのサイズすべてのスループットを示します。
v5p-8 |
v5p-128 |
v5p-128 |
|
---|---|---|---|
モデルのサイズ |
70 億人 |
13B |
70B |
グローバル バッチサイズ |
96 |
1024 |
128 |
シャーディング メッシュ形状 |
(4, 1) |
(64, 1) |
(16, 4) |
モデルの FLOPS 使用率 (MFU) |
56.67% |
55.80% |
51.85% |
サポートとフィードバック
フィードバックをぜひお寄せください。フィードバックを共有したり、サポートをリクエストしたりするには、Cloud TPU サポートまたはフィードバック フォームにご記入ください。