PyTorch コードを TPU Pod スライスで実行する

PyTorch / XLA では、すべての TPU VM がモデルのコードとデータにアクセスできる必要があります。起動スクリプトを使用して、モデルデータをすべての TPU VM に分散させるのに必要なソフトウェアをダウンロードできます。

TPU VM を Virtual Private Cloud(VPC)に接続する場合は、ポート 8470~8479 の上り(内向き)を許可するファイアウォール ルールをプロジェクトに追加する必要があります。ファイアウォール ルールの追加方法については、ファイアウォール ルールの使用をご覧ください。

環境の設定

  1. Cloud Shell で次のコマンドを実行して、gcloud の最新バージョンを実行していることを確認します。

    $ gcloud components update

    gcloud をインストールする必要がある場合は、次のコマンドを使用します。

    $ sudo apt install -y google-cloud-sdk
  2. いくつかの環境変数を作成します。

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32

TPU VM を作成します。

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

トレーニング スクリプトを構成して実行する

  1. プロジェクトに SSH 証明書を追加します。

    ssh-add ~/.ssh/google_compute_engine
  2. すべての TPU VM ワーカーに PyTorch/XLA をインストールする

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  3. すべての TPU VM ワーカーで XLA のクローンを作成する

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.4 https://github.com/pytorch/xla.git"
  4. すべてのワーカーでトレーニング スクリプトを実行する

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    トレーニングには約 5 分間を要します。完了すると、次のようなメッセージが表示されます。

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

クリーンアップ

TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。

  1. Compute Engine から切断します。

    (vm)$ exit
  2. 次のコマンドを実行して、リソースが削除されたことを確認します。TPU がリストに表示されないことを確認します。削除には数分かかることがあります。

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a