Cloud TPU Pod での PyTorch モデルのトレーニング

このチュートリアルでは、モデルを単一の Cloud TPU(v2-8 または v3-8)から Cloud TPU Pod にスケールアップする方法を示します。TPU ポッド内の Cloud TPU アクセラレータは超高帯域幅の相互接続によって接続されており、トレーニング ジョブのスケールアップが効率的になります。

Cloud TPU Pods プロダクトの詳細については、 Cloud TPU プロダクト ページまたはこの Cloud TPU プレゼンテーションをご覧ください。

次の図に、分散クラスタの設定の概要を示します。VM のインスタンス グループが TPU Pod に接続されます。8 つの TPU コアのグループごとに 1 つの VM が必要です。VM は TPU コアにデータをフィードし、すべてのトレーニングが TPU Pod で実行されます。

画像

目標

  • PyTorch または XLA を使用したトレーニング用に Compute Engine インスタンス グループと Cloud TPU ポッドを設定する
  • Cloud TPU ポッドで PyTorch トレーニングまたは XLA トレーニングを実行する

始める前に

Cloud TPU Pod の分散トレーニングを開始する前に、モデルが単一の v2-8 または v3-8 Cloud TPU デバイスでトレーニングを受けていることを確認します。単一のデバイス上でモデルに重大なパフォーマンスの問題が生じた場合は、ベスト プラクティストラブルシューティングのガイドをご覧ください。

単一の TPU デバイスのトレーニングが完了したら、次の手順を実行して Cloud TPU Pod を設定します。

  1. gcloud コマンドを構成します。

  2. [省略可] VM ディスク イメージを VM イメージにキャプチャします。

  3. VM イメージからインスタンス テンプレートを作成します。

  4. インスタンス テンプレートからインスタンス グループを作成します。

  5. Compute Engine VM に SSH 接続します。

  6. VM 間通信を許可するファイアウォール ルールを確認します。

  7. Cloud TPU Pod を作成します。

  8. 分散トレーニングをポッド上で実行します。

  9. クリーンアップします。

gcloud コマンドを構成する

gcloud を使用して GCP プロジェクトを構成します。

プロジェクト ID の変数を作成します。

export PROJECT_ID=project-id

プロジェクト ID を gcloud のデフォルト プロジェクトとして設定する

gcloud config set project ${PROJECT_ID}

gcloud を使用してデフォルトのゾーンを構成します。

gcloud config set compute/zone us-central1-a

(省略可)VM ディスク イメージをキャプチャする

(すでにデータセット、パッケージがインストールされているなど)単一の TPU のトレーニングに使用した VM のディスク イメージを使用できます。イメージを作成する前に、gcloud コマンドを使用して VM を停止します。

gcloud compute instances stop vm-name

次に、gcloud コマンドを使用して VM イメージを作成します。

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone us-central1-a \
    --family=torch-xla \
    --storage-location us-central1

VM イメージからインスタンス テンプレートを作成する

デフォルトのインスタンス テンプレートを作成します。インスタンス テンプレートを作成する場合は、上記の手順で作成した VM イメージを使用できます。あるいは、Google が一般に提供している PyTorch または XLA のイメージを使用できます。インスタンス テンプレートを作成するには、gcloud コマンドを使用します。

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

インスタンス テンプレートからインスタンス グループを作成する

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone us-central1-a

Compute Engine VM に SSH 接続する

インスタンス グループを作成したら、インスタンス グループ内のいずれかのインスタンス(VM)に SSH 接続します。次のコマンドを使用して、gcloud コマンドをグループ化したインスタンスの一覧を取得します。

gcloud compute instance-groups list-instances instance-group-name

list-instances コマンドを実行して、いずれかのインスタンスに SSH 接続します。

gcloud compute ssh instance-name --zone=us-central1-a

インスタンス グループ内の VM が相互に通信できることを確認する

nmap コマンドを使用して、インスタンス グループ内の VM が相互に通信できることを確認します。接続先の VM から nmap コマンドを実行します。instance-name はインスタンス グループ内の別の VM のインスタンス名に置き換えます。

(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

STATE フィールドが filtered と表示されていない限り、ファイアウォール ルールは正しく設定されています。

Cloud TPU Pod を作成する

gcloud compute tpus create tpu-name \
    --zone=us-central1-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=1.6

分散トレーニングをポッドで実行する

  1. VM セッション ウィンドウで、Cloud TPU 名をエクスポートし、conda 環境を有効にします。

    (vm)$ export TPU_NAME=tpu-name
    (vm)$ conda activate torch-xla-1.6
    
  2. トレーニング スクリプトを実行します。

    (torch-xla-1.6)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.6 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.6/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data
    

上記のコマンドを実行すると、次のような出力が表示されます(--fake_data を使用していることに注意)。このトレーニングは v3-32 TPU Pod 上で約 1 時間かかります。

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud Platform アカウントに課金されないようにする手順は次のとおりです。

  1. Compute Engine VM との接続を解除します。

    (vm)$ exit
    
  2. インスタンス グループを削除します。

    gcloud compute instance-groups managed delete instance-group-name
    
  3. TPU Pod を削除します。

    gcloud compute tpus delete ${TPU_NAME} --zone=us-central1-a
    

次のステップ

次のように PyTorch colabs を試す