Cloud TPU v5p トレーニング

Cloud TPU v5p は、Google Cloud の第 5 世代 Cloud TPU であり、v4 TPU の後継です。v5p は大規模なトレーニング用に最適化されており、基盤となる LLM、拡散モデル、生成 AI を開発するための主要なプラットフォームです。大まかに言うと、v5p は v4 の最大 2 倍の性能を備えながら、Pod に 2 倍の TPU を詰め込み(最大スライスは v4 の 3k に対して 6k)、Pod レベルで最大 4 倍の性能を実現します。また、高いクロック周波数(1.05 Ghz に対して 1.75 Ghz)で動作し、大規模な埋め込み用の SparseCore が追加され、高帯域幅メモリ(HBM)容量を 3 倍に増やしています。

Cloud TPU v5p のコンセプト

Cloud TPU を初めて使用する場合は、TPU のドキュメント ホームをご覧ください。

Cloud TPU のコンセプト(スライス、ホスト、TensorCore など)と、すべての Cloud TPU バージョンに対する Cloud TPU システム アーキテクチャについては、Cloud TPU システム アーキテクチャ ページをご覧ください。

各 Cloud TPU のバージョンでは、トレーニングまたは推論のために特定のアクセラレータ タイプが必要です。これらのアクセラレータ タイプについては、v5p 構成をご覧ください。

TPU リソースを管理する

TPU VM の管理に使用できるすべてのコマンドについては、TPU の管理またはキューに入れられたリソースの管理に関するキューに入れられたリソース ユーザーガイドをご覧ください。

フレームワークの設定

このセクションでは、TPU v5p で JAX または PyTorch を使用したモデル トレーニングの一般的な設定プロセスについて説明します。

JAX を設定する

スライス形状が 4 チップを超える場合、1 つのスライスに複数の VM があります。この場合、--worker=all フラグを使用して、1 つのコマンドですべての TPU VM にインストールする必要があります。

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

次のコマンドを実行して、デバイスの数を確認できます(ここに示した出力は、v5p-32 スライスで生成されています)。このコードは、JAX から Cloud TPU TensorCore が検出され、基本的なオペレーションを実行できることを確認することによって、すべてが正しくインストールされていることをテストします。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の 1 つの VM によってアクセス可能なチップの数を示します。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

--node=all を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

このドキュメントの JAX チュートリアルを試して、JAX を使用した v5p トレーニングを開始してください。

PyTorch を設定する

PJRT ランタイムは v5p でサポートされている唯一のランタイムで、PyTorch 2.1+ ではすべての TPU バージョンのデフォルト ランタイムとして PJRT が使用されます。このセクションでは、すべてのワーカーに PyTorch/XLA 2.2.0 を使用して v5p Pod で PJRT を使用する方法について説明します。

依存関係をインストールする

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip3 install numpy
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
'

PJRT で Python スクリプトを使用して、インストールを検証し、使用可能な TPU デバイスを表示します(ここに示す出力は v5p-32 スライスで生成されています)。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']

--node=all を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'

このドキュメントの PyTorch チュートリアルを試して、PyTorch を使用した v5p トレーニングを開始します。

モニタリングとプロファイル

Cloud TPU v5p は、前世代の Cloud TPU と同じ方法でモニタリングとプロファイリングをサポートします。プロファイリングの詳細については、Cloud TPU ツールでモデルをプロファイリングするをご覧ください。モニタリングの詳細については、Cloud TPU VM をモニタリングするをご覧ください。

トレーニング チュートリアル

このセクションでは、単一スライスのトレーニング チュートリアルに焦点を当てます。これらのチュートリアルをマルチスライス トレーニングに適応させるには、SSH コマンドに --node=all フラグを追加します。詳細とベスト プラクティスについては、マルチスライスの概要をご覧ください。

JAX チュートリアル

Diffusion 2.1 をトレーニングする

このチュートリアルでは、Cloud TPU v5p で Pokémon データセットを使用して、HuggingFace から Stable Diffusion モデルをトレーニングする方法について説明します。

Stable Diffusion モデルは、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。詳しくは、次のリソースをご覧ください。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    

    コマンドフラグの説明

    変数 説明
    PROJECT_ID Google Cloud プロジェクト名
    ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。
    ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
    RUNTIME_VERSION v5p の場合、RUNTIME_VERSION に v2-alpha-tpuv5 を使用します。
    SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM] -> サービス アカウント で確認できるサービス アカウントのアドレスです。 例: tpu-service-account@myprojectID.iam.gserviceaccount.com
    TPU_NAME キューに格納されたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。
    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当て ID。キューに格納されたリソースについては、キューに格納されたリソースのドキュメントをご覧ください。
    QUOTA_TYPE reservedbest-effort のいずれかを設定できます。どちらも指定されていない場合、QUOTA_TYPE はデフォルトで on-demand になります。Cloud TPU でサポートされている割り当てのさまざまなタイプについては、割り当てをご覧ください。
    VALID_UNTIL_DURATION リクエストが有効である期間。さまざまな有効期間の詳細については、キューに入れられたリソースをご覧ください。
  2. TPU リソースを作成します。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_UNTIL_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    キューに格納されたリソースが ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。 次のコマンドを実行して、キューに入れられたリソースの状態を確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}
    

    キューに入れられたリソースが ACTIVE 状態の場合、出力は次のようになります。

    state: ACTIVE
    
  3. JAX とその依存関係をインストールします。

    # compatible with v5p: only jax version 0.4.19 and later \
    # jax 0.4.19 requires py 3.10 \
    
    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
    --command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. HuggingFace のリポジトリをダウンロードし、要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone https://github.com/huggingface/diffusers.git && cd diffusers && pip install . && pip install tensorflow clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
    
  5. モデルのトレーニング

    事前にマッピングされたバッファ(4GB)を使用してモデルをトレーニングします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='export PATH=$PATH:$HOME/.local/bin && cd diffusers/examples/text_to_image && JAX_PLATFORMS=tpu,cpu python3 train_text_to_image_flax.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1 --dataset_name=lambdalabs/pokemon-blip-captions --resolution=256 --center_crop --random_flip --train_batch_size=1 --mixed_precision=bf16 --max_train_steps=150 --learning_rate=1e-05 --max_grad_norm=1 --output_dir=sd-pokemon-model --from_pt'
    

クリーンアップ

セッションの終了時に TPU とキューに入れられたリソース リクエストを削除するか、「FAILED」状態のキューに入れられたリソース リクエストを削除します。キューに入れられたリソースを削除するには、2 ステップで、スライスを削除した後、キューに入れられたリソース リクエストを削除します。

   gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID}
   --zone=${ZONE} --quiet
   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
   --project ${PROJECT_ID} --zone ${ZONE} --quiet

または、--force を使用して、1 ステップでスライスとキューに入れられたリソース リクエストを削除します。

# With --force
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
--project ${PROJECT_ID} --zone ${ZONE} --quiet --force

ベンチマークの結果

Stable Diffusion のトレーニング スクリプトは、v5p-8、v5p-32、v5p-128 で動作しました。次の表では、スループットを示します。

v5p-8

v5p-32

v5p-128

トレーニング ステップ

150

150

150

グローバル バッチサイズ

32

64

64

スループット(例/秒)

12.10

18.08

19.10

MaxText

このチュートリアルでは、Cloud TPU で合成データセットを使用して MaxText モデルをトレーニングする方法について説明します。

MaxText は、Cloud TPU をターゲットとする、純粋な Python/JAX で記述された、高性能で、任意にスケーラブルなオープンソースの LLM です。MaxText は、自然言語処理(NLP)の研究開発のフロンティアを前進させるためのアクセスしやすく適応可能なツールによって、研究者やデベロッパーを支援します。

このチュートリアルを実行する前に、Cloud TPU 環境を設定する必要があります。

  1. 環境変数を設定する

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name # user defined TPU name
    export ACCELERATOR_TYPE=v5p-256
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export RUN_NAME=your_experiment_run_name # user defined name for this run
    export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs://
    export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path
    export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
    

    コマンドフラグの説明

    変数 説明
    PROJECT_ID Google Cloud プロジェクト名
    TPU_NAME TPU のユーザー定義の名前。
    ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。
    ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
    RUNTIME_VERSION v5p の場合は、ランタイム バージョンに v2-alpha-tpuv5 を使用します。
    RUN_NAME ユーザーが指定したテスト実行名。

    マルチスライスに推奨されるオプションの設定:

    export NETWORK_NAME=your_network_name
    export FIREWALL_RULE_NAME=your_firewall_rule_name
    

    マルチスライス ワークロードを実行していて、最適なネットワーク パフォーマンスが必要な場合は、最大伝送単位(MTU)が 8,896 バイトの専用ネットワークを作成して、適切なファイアウォール ルールを構成することを検討してください。このステップは省略できますが、特にデータセンター ネットワーク(DCN)でスライス数をスケールアップする場合に、パフォーマンスを大幅に改善できます。なお、ネットワークを作成するには、プロジェクトに compute.networks.create 権限が必要です。次の例では、専用ネットワークとファイアウォール ルールを作成する方法を示します。

    専用ネットワークを作成します。

    gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional
    

    ファイアウォール ルールを作成します。

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
    
  2. MaxText リポジトリのクローンを作成します

    git clone https://github.com/google/maxtext.git
    
  3. モデルのトレーニング

    以降のセクションでは、MaxText をトレーニングするための 2 つのオプションについて説明します。

    オプション 1

    Cloud TPU のプロビジョニングと依存関係のインストールから、モデルの実行とリソースの削除まで、ワークフロー全体をスクリプトで管理する場合は、multihost_job.py を使用できます。

    cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \
    --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \
    --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} #user defined run name \
    --BUCKET_NAME=${GCS_BUCKET_NAME} \ #used to store logs and configs
    --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
    

    スクリプトを開始すると、ログに次のようなメッセージが表示されます。ログの場所は出力メッセージで参照されます。TPU のプロビジョニングが完了したら、最初のリンクをクリックしてすべてのワーカーのログにアクセスします。

    ------------------------------------
    
    multihost_job finished running, TPUs are starting up to run your job remotely.
    
    Logs for your job are displayed here:
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22_log%22%2529;?project=PROJECT_ID
    
    To see the output of a single host, you may edit the slice and worker
    number in the `log_file_path` property here:
    
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID
    
    When your job is finished, the main command log is in your Cloud Storage
    bucket:
    https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID
    
    View the status of the created TPUs using:
    gcloud compute tpus queued-resources list --filter=RUN_NAME
    --zone=ZONE --project=PROJECT_ID
    
オプション 2

プロビジョニングされた Cloud TPU でトレーニング スクリプトを複数回実行するには、multihost_runner.py スクリプトを使ってリソースを使用します。

  1. 変数を設定して TPU を作成します。

    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export VALID_DURATION=1d
    export QUOTA_TYPE=quota_type
    
    --node-count ${NODE_COUNT} \
    --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
    
  2. TPU リソースを作成します。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    QueuedResourceACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスを確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}
    --project ${PROJECT_ID} --zone ${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. SSH を使用して TPU に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. 依存関係をインストールする

    export TPU_NAME=your_tpu_name
    export MAXTEXT_OUTPUT_PATH=output-path
    
    cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND='bash setup.sh'
    
  5. さまざまな構成スクリプト(32b.sh、64b.sh など)を使用してモデルを実行します。TPU VM からスクリプトを実行する場合は、--INTERNAL_IP=true フラグを追加する必要があります。

    python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME}
    OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
    

クリーンアップ

TPU とキューに格納されたリソースを削除します

ベンチマークの結果

MaxText トレーニング スクリプトは、32 ~ 1,160 バイトを bf16 適合率で実行しました。これらの実行の結果を次の表に示します。

パラメータの数

アクセラレータ タイプ

TFLOP/チップ/秒

モデルの FLOPS 使用率

(MFU)

32B

v5p-128

3.28E+02

71.47%

64B

v5p-128

3.23E+02

70.31%

128B

v5p-256

3.15E+02

68.68%

128B

v5p-512

3.15E+02

68.53%

256B

v5p-1024

3.16E+02

68.82%

512B

v5p-1024

2.94E+02

63.99%

1024B

v5p-2048

2.49E+02

64.05%

1024B

v5p-4096

2.97E+02

64.80%

1160B

v5p-7680

2.95E+02

64.27%

1160B

v5p-12288

3.04E+02

66.23%

256B パラメータ モデルは、bf16 と int8 の両方の重み付けを使用して、v5p-512 と v5p-1024 でテストされています。次の表では、これらのテスト結果を示します。

v5p-512

v5p-512

v5p-1024

v5p-1024

グローバル バッチサイズ

(トークン)

5.24E+05

5.24E+05

1.05E+06

1.05E+06

適合率

bf16

int8

bf16

int8

TFLOP/チップ/秒

307

408

308

414

モデルの FLOPS 使用率

(MFU)

66.98%

88.85%

67.09%

90.23%

TensorFlow チュートリアル

単一ホスト v5p で ResNet をトレーニングする

このチュートリアルでは、架空のデータセットを使用して v5p-8 TPU で ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.16.1-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    このチュートリアルでは、ACCELERATOR_TYPE として v5p-8 を使用します。

  2. TPU リソースを作成します。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    キューに入れられたリソースが ACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。キューに格納されたリソースの状態を確認するには、次のコマンドを使用します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. SSH を使用して TPU に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. いくつかの環境変数を設定します

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
    
  5. モデル リポジトリのディレクトリに移動し、要件をインストールします。

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt
    

モデルのトレーニング

  1. トレーニング スクリプトを実行します。

    python3 official/vision/train.py \
      --tpu=local \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
    

クリーンアップ

TPU とキューに格納されたリソースを削除します

マルチホスト v5p で ResNet をトレーニングする

このチュートリアルでは、架空のデータセットを使用して v5p-16 以上での ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。

  1. 環境変数を作成します。

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name
    export ZONE=us-east1-c
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=tpu-vm-tf-2.16.1-pod-pjrt
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    ACCELERATOR_TYPEv5p-16 以上に設定できます。

  2. TPU リソースを作成します。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    キューに入れられたリソースが ACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスをクエリします。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. SSH を使用して TPU(ワーカーゼロ)に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. いくつかの環境変数を設定します

    export TPU_NAME=your_tpu_name
    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
    
  5. モデル リポジトリのディレクトリに移動し、要件をインストールします。

    cd $MODELS_REPO && git checkout r2.15.0
    pip install -r official/requirements.txt
    

モデルのトレーニング

  1. トレーニング スクリプトを実行します。

    python3 official/vision/train.py \
      --tpu=${TPU_NAME} \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
    

クリーンアップ

TPU とキューに格納されたリソースを削除します

PyTorch/XLA

Llama 2

このチュートリアルでは、PyTorch/XLA の HuggingFace リポジトリの fork を使用して Llama 2 7B モデルを v5p でトレーニングする方法を説明します。ML 計算グラフ(GSPMD)用の一般かつスケーラブルな並列化も可能です。

設定

  1. プロジェクト ID、アクセラレータ タイプ、ゾーン、ランタイム バージョン、TPU 名の変数を作成します。

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_DURATION=1d
    
  2. TPU リソースの作成

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    QueuedResourceACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスを確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    

  3. Pytorch/XLA と必要な依存関係をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='
    sudo apt-get update
    sudo apt-get install libopenblas-dev -y
    pip3 install numpy
    pip3 install typing-extensions
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    '
    
  4. HuggingFace リポジトリをダウンロードして、要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git
    cd transformers
    pip3 install git+file://$PWD
    pip3 install datasets accelerate evaluate scikit-learn'
    
  5. 7B モデル構成をダウンロードします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
    
  6. モデルのトレーニング

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    export PJRT_DEVICE=TPU
    export XLA_USE_BF16=1
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    
    export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true
    --xla_tpu_enable_async_collective_fusion_multiple_steps=true
    --xla_tpu_enable_async_collective_fusion=true
    --xla_tpu_overlap_compute_collective_tc=true
    --xla_enable_async_all_gather=true
    --xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
    
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=20000
    export PROFILE_LOGDIR=/tmp/home/
    
    cd transformers
    python examples/pytorch/language-modeling/run_clm.py \
     --tokenizer_name hf-internal-testing/llama-tokenizer \
     --dataset_name wikitext \
     --dataset_config_name wikitext-2-raw-v1 \
     --per_device_train_batch_size 96 \
     --per_device_eval_batch_size 8 \
     --num_train_epochs 1 \
     --do_train \
     --output_dir /tmp/output \
     --overwrite_output_dir \
     --config_name ~/config.json \
     --save_strategy no \
     --logging_strategy no \
     --remove_unused_columns no \
     --optim adafactor \
     --torch_dtype bfloat16 \
     --dataloader_drop_last yes \
     --block_size 2048 \
     --spmd_2d_sharding 1 \
     --spmd_grad_chkpt
    '
    

マルチスライス環境で実行している場合は、フラグ --spmd_dcn_parallelism をスライスの数に設定する必要があります。

SPMD_USER_GUIDE には、HF スクリプトのさまざまな環境変数とトグルを説明する詳細なユーザーガイドが用意されています。なお、LIBTPU_INIT_ARGS は PyTorch/XLA に組み込まれ、今後のリリースではデフォルトでオンになります。

クリーンアップ

TPU とキューに格納されたリソースを削除します

ベンチマークの結果

次の表では、3 つの Llama 2 モデルのサイズすべてのスループットを示します。

v5p-8

v5p-128

v5p-128

モデルのサイズ

70 億人

13B

70B

グローバル バッチサイズ

96

1024

128

メッシュ形状のシャーディング

(4, 1)

(64, 1)

(16, 4)

モデルの FLOPS 使用率

(MFU)

56.67%

55.80%

51.85%

サポートとフィードバック

フィードバックをぜひお寄せください。フィードバックを共有したり、サポートをリクエストしたりするには、Cloud TPU サポートまたはフィードバック フォームにご記入ください。