Cloud TPU PyTorch/XLA ユーザーガイド

PyTorch/XLA を使用して ML ワークロードを実行する

このガイドでは、PyTorch を使用して v4 TPU で簡単な計算を行う方法について説明します。

基本設定

  1. Pytorch 2.0 の TPU VM ランタイムを実行している v4 TPU を使用して TPU VM を作成します。

      gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8 \
      --version=tpu-vm-v4-pt-2.0
  2. SSH を使用して TPU VM に接続します。

      gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8
  3. PJRT または XRT TPU デバイス構成を設定します。

    PJRT

        (vm)$ export PJRT_DEVICE=TPU
     

    XRT

        (vm)$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
     

  4. Cloud TPU v4 でトレーニングする場合は、次の環境変数も設定します。

      (vm)$ export TPU_NUM_DEVICES=4

簡単な計算を実行する

  1. TPU VM で Python インタープリタを起動します。

    (vm)$ python3
  2. 次の PyTorch パッケージをインポートします。

    import torch
    import torch_xla.core.xla_model as xm
  3. 次のスクリプトを入力します。

    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)

    次の出力が表示されます。

    tensor([[-0.2121,  1.5589, -0.6951],
           [-0.7886, -0.2022,  0.9242],
           [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

単一デバイス TPU での Resnet の実行

この時点で、任意の PyTorch / XLA コードを実行できます。たとえば、架空のデータを使用して ResNet モデルを実行できます。

(vm)$ git clone --recursive https://github.com/pytorch/xla.git
(vm)$ python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

ResNet サンプルでは 1 エポックのトレーニングを行い、所要時間は約 7 分です。この場合、次のような出力が返されます。

Epoch 1 test end 20:57:52, Accuracy=100.00 Max Accuracy: 100.00%

ResNet トレーニングが終了したら、TPU VM を削除します。

(vm)$ exit
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=zone

削除には数分かかることがあります。gcloud compute tpus list --zone=${ZONE} を実行して、リソースが削除されたことを確認します。

高度なトピック

サイズが大きく、頻繁な割り当てがあるモデルの場合、tcmalloc は C/C++ ランタイム関数 malloc と比較してパフォーマンスを向上させます。TPU VM で使用されるデフォルトの malloctcmalloc です。LD_PRELOAD 環境変数の設定を解除すると、TPU VM ソフトウェアで標準の malloc を使用できます。

   (vm)$ unset LD_PRELOAD

上記の例(簡単な計算と ResNet50)では、PyTorch/XLA プログラムが Python インタープリタと同じプロセスでローカル XRT サーバーを起動します。XRT ローカル サービスを別のプロセスで開始することもできます。

(vm)$ python3 -m torch_xla.core.xrt_run_server --port 51011 --restart

このアプローチの利点は、トレーニングの実行中、コンパイル キャッシュが保持されることです。XLA サーバーを別のプロセスで実行する場合、サーバー側のロギング情報は /tmp/xrt_server_log に書き込まれます。

(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log

TPU VM パフォーマンスのプロファイリング

TPU VM でのモデルのプロファイリングの詳細については、PyTorch XLA パフォーマンス プロファイリングをご覧ください。

PyTorch/XLA TPU Pod の例

TPU VM Pod 上で PyTorch/XLA を実行するための設定情報と例については、PyTorch TPU VM Pod をご覧ください。

TPU VM 上の Docker

このセクションでは、PyTorch/XLA がプリインストールされた TPU VM 上で Docker を実行する方法について説明します。

利用可能な Docker イメージ

使用可能なすべての TPU VM Docker イメージを確認するには、GitHub の README をご覧ください。

TPU VM で Docker イメージを実行する

(tpuvm): sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker run --privileged  --shm-size 16G --name tpuvm_docker -it -d  gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker exec --privileged -it tpuvm_docker /bin/bash
(pytorch) root:/#

libtpu を確認する

libtpu がインストールされていることを確認するには、次のコマンドを実行します。

(pytorch) root:/# ls /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/ | grep libtpu
その結果、次のような出力が生成されます。
libtpu
libtpu_nightly-0.1.dev20220518.dist-info

結果が表示されない場合は、次のコマンドを使用して、対応する libtpu を手動でインストールできます。

(pytorch) root:/# pip install torch_xla[tpuvm]

tcmalloc の所有権を証明してください

tcmalloc は、TPU VM で使用するデフォルトの Malloc です。詳細については、このセクションをご覧ください。このライブラリは、新しい TPU VM Docker イメージにプリインストールする必要がありますが、手動で確認することをおすすめします。次のコマンドを実行して、ライブラリがインストールされていることを確認できます。

(pytorch) root:/# echo $LD_PRELOAD
その結果、次のような出力が生成されます。
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

LD_PRELOAD が設定されていない場合は、手動で実行できます。

(pytorch) root:/# sudo apt-get install -y google-perftools
(pytorch) root:/# export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"

デバイスを確認する

次のコマンドを実行して、TPU VM デバイスが使用可能であることを確認します。

(pytorch) root:/# ls /dev | grep accel
次の結果が生成されます。
accel0
accel1
accel2
accel3

結果が表示されない場合は、コンテナを --privileged フラグで開始していない可能性があります。

モデルを実行する

次のコマンドを実行して、TPU VM デバイスが使用可能かどうかを確認できます。

(pytorch) root:/# export XRT_TPU_CONFIG="localservice;0;localhost:51011"
(pytorch) root:/# python3 pytorch/xla/test/test_train_mp_imagenet.py --fake_data --num_epochs 1