Cloud TPU Pod での PyTorch モデルのトレーニング


このチュートリアルでは、TPU ノード構成を使用して、単一の Cloud TPU(v2-8 または v3-8)から Cloud TPU Pod へのモデルのトレーニングをスケールアップする方法を示します。TPU Pod 内の Cloud TPU アクセラレータは超高帯域幅の相互接続によって接続されており、トレーニング ジョブのスケールアップを効率的に実施できます。

Cloud TPU Pods プロダクトの詳細については、 Cloud TPU プロダクト ページまたはこの Cloud TPU プレゼンテーションをご覧ください。

次の図に、分散クラスタの設定の概要を示します。VM のインスタンス グループは、TPU Pod に接続されます。8 つの TPU コアのグループそれぞれに 1 つずつ VM が必要です。VM は TPU コアにデータをフィードし、すべてのトレーニングは TPU Pod で実行されます。

画像

目標

  • PyTorch または XLA を使用したトレーニング用に Compute Engine インスタンス グループと Cloud TPU ポッドを設定する
  • Cloud TPU ポッドで PyTorch トレーニングまたは XLA トレーニングを実行する

始める前に

Cloud TPU Pod の分散トレーニングを開始する前に、モデルが単一の v2-8 または v3-8 Cloud TPU デバイスでトレーニングを受けていることを確認します。単一のデバイス上でモデルに重大なパフォーマンスの問題が生じた場合は、ベスト プラクティストラブルシューティングのガイドをご覧ください。

単一の TPU デバイスのトレーニングが完了したら、次の手順を実行して Cloud TPU Pod を設定し、トレーニングします。

  1. gcloud コマンドを設定します。

  2. [省略可] VM ディスク イメージを VM イメージにキャプチャします。

  3. VM イメージからインスタンス テンプレートを作成します。

  4. インスタンス テンプレートからインスタンス グループを作成します。

  5. Compute Engine VM に SSH 接続します。

  6. VM 間通信を許可するファイアウォール ルールを確認します。

  7. Cloud TPU Pod を作成します。

  8. 分散トレーニングをポッド上で実行します。

  9. クリーンアップします。

gcloud コマンドを設定する

gcloud を使用して Google Cloud プロジェクトを構成します。

プロジェクト ID の変数を作成します。

export PROJECT_ID=project-id

プロジェクト ID を、gcloud のデフォルト プロジェクトとして設定します。

gcloud config set project ${PROJECT_ID}

このコマンドを新しい Cloud Shell VM で初めて実行すると、Authorize Cloud Shell ページが表示されます。ページの下部にある [Authorize] をクリックして、gcloud に認証情報を使用した API 呼び出しを許可します。

gcloud を使用して、デフォルトのゾーンを構成します。

gcloud config set compute/zone europe-west4-a

(省略可)VM ディスク イメージをキャプチャする

(すでにデータセット、パッケージがインストールされているなど)単一の TPU のトレーニングに使用した VM のディスク イメージを使用して、イメージを作成する前に、gcloud コマンドを使用して VM を停止します。

gcloud compute instances stop vm-name

次に、gcloud コマンドを使用して VM イメージを作成します。

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone europe-west4-a \
    --family=torch-xla \
    --storage-location europe-west4

VM イメージからインスタンス テンプレートを作成する

デフォルトのインスタンス テンプレートを作成します。インスタンス テンプレートを作成する場合は、上記のステップで作成した VM イメージを使用できます。または、Google が一般に提供している PyTorch または XLA のイメージを使用できます。インスタンス テンプレートを作成するには、gcloud コマンドを使用します。

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

インスタンス テンプレートからインスタンス グループを作成する

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone europe-west4-a

Compute Engine VM に SSH 接続する

インスタンス グループを作成したら、インスタンス グループ内のいずれかのインスタンス(VM)に SSH 接続します。次のコマンドを使用して、gcloud コマンドをグループ化してインスタンス内のすべてのインスタンスを一覧表示します。

gcloud compute instance-groups list-instances instance-group-name

list-instances コマンドを実行して、表示されているインスタンスのいずれかに SSH 接続します。

gcloud compute ssh instance-name --zone=europe-west4-a

インスタンス グループ内の VM が相互に通信できることを確認する

nmap コマンドを使用して、インスタンス グループ内の VM が相互に通信できることを確認します。接続先の VM から nmap コマンドを実行します。instance-name は、インスタンス グループ内の別の VM のインスタンス名に置き換えます。

(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

STATE フィールドが filtered と表示されていない限り、ファイアウォール ルールは正しく設定されています。

Cloud TPU Pod を作成する

gcloud compute tpus create tpu-name \
    --zone=europe-west4-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=pytorch-1.13

分散トレーニングをポッドで実行する

  1. VM セッション ウィンドウで、Cloud TPU 名をエクスポートし、conda 環境を有効にします。

    (vm)$ export TPU_NAME=tpu-name
    (vm)$ conda activate torch-xla-1.13
    
  2. トレーニング スクリプトを実行します。

    (torch-xla-1.13)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.13 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.13/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data
    

上記のコマンドを実行すると、次のような出力が表示されます(--fake_data を使用していることに注意)。このトレーニングは v3-32 TPU Pod 上で約 1 時間かかります。

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

  1. Compute Engine VM との接続を解除します。

    (vm)$ exit
    
  2. インスタンス グループを削除します。

    gcloud compute instance-groups managed delete instance-group-name
    
  3. TPU Pod を削除します。

    gcloud compute tpus delete ${TPU_NAME} --zone=europe-west4-a
    
  4. インスタンス グループ テンプレートを削除します。

    gcloud compute instance-templates delete instance-template-name
    
  5. [省略可] VM ディスク イメージを削除します。

    gcloud compute images delete image-name
    

次のステップ

次のように PyTorch colabs を試す