PyTorch を使用した Cloud TPU での DLRM のトレーニング

このチュートリアルでは、Cloud TPU で Facebook Research DLRM をトレーニングする方法について説明します。

目標

  • PyTorch 環境を作成および構成する
  • 架空データでトレーニング ジョブを実行する
  • (省略可)Criteo Kaggle データセットでトレーニングする

料金

このチュートリアルでは、Google Cloud の課金対象となる以下のコンポーネントを使用します。

  • Compute Engine
  • Cloud TPU

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを出すことができます。新しい Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。

  1. Google Cloud アカウントにログインします。Google Cloud を初めて使用する場合は、アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  2. Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  3. Cloud プロジェクトに対して課金が有効になっていることを確認します。プロジェクトに対して課金が有効になっていることを確認する方法を学習する

  4. Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  5. Cloud プロジェクトに対して課金が有効になっていることを確認します。プロジェクトに対して課金が有効になっていることを確認する方法を学習する

  6. このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。

Compute Engine インスタンスを設定する

  1. Cloud Shell ウィンドウを開きます。

    Cloud Shell を開く

  2. プロジェクト ID の変数を作成します。

    export PROJECT_ID=project-id
    
  3. Cloud TPU を作成するプロジェクトを使用するように gcloud コマンドライン ツールを構成します。

    gcloud config set project ${PROJECT_ID}
    

    このコマンドを新しい Cloud Shell VM で初めて実行すると、Authorize Cloud Shell ページが表示されます。ページの下部にある [Authorize] をクリックして、gcloud に認証情報を使用した GCP API の呼び出しを許可します。

  4. Cloud Shell から、このチュートリアルで必要となる Compute Engine リソースを起動します。 注: Criteo Kaggle データセットでトレーニングを行う場合は、n1-highmem-96 machine-type を使用することをおすすめします。

    gcloud compute instances create dlrm-tutorial \
    --zone=us-central1-a \
    --machine-type=n1-standard-64 \
    --image-family=torch-xla \
    --image-project=ml-images  \
    --boot-disk-size=200GB \
    --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. 新しい Compute Engine インスタンスに接続します。

    gcloud compute ssh dlrm-tutorial --zone=us-central1-a
    

Cloud TPU リソースを起動する

  1. Compute Engine 仮想マシンから、次のコマンドを使用して Cloud TPU リソースを起動します。

    (vm) $ gcloud compute tpus create dlrm-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.9  \
    --accelerator-type=v3-8
    
  2. Cloud TPU リソースの IP アドレスを識別します。

    (vm) $ gcloud compute tpus list --zone=us-central1-a
    

PyTorch 環境を作成および構成する

  1. conda 環境を開始します。

    (vm) $ conda activate torch-xla-1.9
    
  2. Cloud TPU リソースの環境変数を構成します。

    (vm) $ export TPU_IP_ADDRESS=ip-address
    
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    

架空データでトレーニング ジョブを実行する

  1. 依存関係をインストールします。

    (vm) $ pip install onnx
    
  2. モデルをランダムデータで実行します。これには 5~10 分かかります。

    (vm) $ python /usr/share/torch-xla-1.9/tpu-examples/deps/dlrm/dlrm_tpu_runner.py \
        --arch-embedding-size=1000000-1000000-1000000-1000000-1000000-1000000-1000000-1000000 \
        --arch-sparse-feature-size=64 \
        --arch-mlp-bot=512-512-64 \
        --arch-mlp-top=1024-1024-1024-1 \
        --arch-interaction-op=dot \
        --lr-num-warmup-steps=10 \
        --lr-decay-start-step=10 \
        --mini-batch-size=2048 \
        --num-batches=1000 \
        --data-generation='random' \
        --numpy-rand-seed=727 \
        --print-time \
        --print-freq=100 \
        --num-indices-per-lookup=100 \
        --use-tpu \
        --num-workers 7 \
        --num-indices-per-lookup-fixed \
        --tpu-model-parallel-group-len=8 \
        --tpu-metrics-debug \
        --tpu-cores=8
    

(省略可)Criteo Kaggle データセットでトレーニングする

この手順は省略可能です。Criteo Kaggle データセットでトレーニングを行う場合にのみ実行してください。

  1. データセットをダウンロードします。

    こちらの手順に沿って、Criteo Kaggle データセットからデータセットをダウンロードします。ダウンロードが完了したら、./criteo-kaggle/ という名前のディレクトリに dac.tar.gz ファイルをコピーします。tar -xzvf コマンドを使用して、./critero-kaggle ディレクトリにある tar.gz ファイルの内容を抽出します。

     (vm) $ mkdir criteo-kaggle
     (vm) $ cd criteo-kaggle
     (vm) $ # Download dataset from above link here.
     (vm) $ tar -xzvf dac.tar.gz
     (vm) $ cd ..
    
  2. データセットの前処理を行います。

    次のスクリプトを起動して Criteo データセットを前処理します。このスクリプトでは、kaggleAdDisplayChallenge_processed.npz という名前のファイルが生成されます。また、データセットの前処理には 3 時間以上かかります。

    (vm) $ python /usr/share/torch-xla-1.9/tpu-examples/deps/dlrm/dlrm_data_pytorch.py \
        --data-generation=dataset \
        --data-set=kaggle \
        --raw-data-file=criteo-kaggle/train.txt \
        --mini-batch-size=128 \
        --memory-map \
        --test-mini-batch-size=16384 \
        --test-num-workers=4
    
  3. 前処理が成功したことを確認します。

    criteo-kaggle ディレクトリに、kaggleAdDisplayChallenge_processed.npz ファイルがあることを確認します。

  4. 前処理済みの Criteo Kaggle データセットで、トレーニング スクリプトを実行します。

    (vm) $ python /usr/share/torch-xla-1.9/tpu-examples/deps/dlrm/dlrm_tpu_runner.py \
        --arch-sparse-feature-size=16 \
        --arch-mlp-bot="13-512-256-64-16" \
        --arch-mlp-top="512-256-1" \
        --data-generation=dataset \
        --data-set=kaggle \
        --raw-data-file=criteo-kaggle/train.txt \
        --processed-data-file=criteo-kaggle/kaggleAdDisplayChallenge_processed.npz \
        --loss-function=bce \
        --round-targets=True \
        --learning-rate=0.1 \
        --mini-batch-size=128 \
        --print-freq=1024 \
        --print-time \
        --test-mini-batch-size=16384 \
        --test-num-workers=4 \
        --memory-map \
        --test-freq=101376 \
        --use-tpu \
        --num-indices-per-lookup=1 \
        --num-indices-per-lookup-fixed \
        --tpu-model-parallel-group-len 8 \
        --tpu-metrics-debug \
        --tpu-cores=8
    

    トレーニングは 2 時間以上で、精度は 78.75% 以上で完了するはずです。

クリーンアップ

作成したリソースを使用した後、アカウントに不要な請求が発生しないようにクリーンアップを行います。

  1. Compute Engine インスタンスとの接続を切断していない場合は切断します。

    (vm) $ exit
    

    プロンプトが user@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud Shell で、gcloud コマンドライン ツールを使用して、Compute Engine インスタンスを削除します。

    $ gcloud compute instances delete dlrm-tutorial --zone=us-central1-a
    
  3. gcloud コマンドライン ツールを使用して、Cloud TPU リソースを削除します。

    $ gcloud compute tpus delete dlrm-tutorial --zone=us-central1-a
    

次のステップ

次のように PyTorch colabs を試す