Cloud TPU での PyTorch を使用した FairSeq Transformer のトレーニング

このチュートリアルでは、FairSeq バージョンの Transformer と英語からドイツ語に翻訳する WMT 18 翻訳タスクを中心に取り上げます。

目標

  • データセットを準備します。
  • トレーニング ジョブを実行します。
  • 出力結果を確認します。

費用

このチュートリアルでは、Google Cloud の課金対象となる以下のコンポーネントを使用します。

  • Compute Engine
  • Cloud TPU

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを出すことができます。新しい Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。

  1. Google アカウントにログインします。

    Google アカウントをまだお持ちでない場合は、新しいアカウントを登録します。

  2. Cloud Console のプロジェクト セレクタページで、Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタのページに移動

  3. Google Cloud プロジェクトに対して課金が有効になっていることを確認します。プロジェクトに対して課金が有効になっていることを確認する方法を学習する

  4. このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。

Compute Engine インスタンスを設定する

  1. Cloud Shell ウィンドウを開きます。

    Cloud Shell を開く

  2. プロジェクト ID の変数を作成します。

    export PROJECT_ID=project-id
    
  3. Cloud TPU を作成するプロジェクトを使用するように gcloud コマンドライン ツールを構成します。

    gcloud config set project ${PROJECT_ID}
    
  4. Cloud Shell から、このチュートリアルで必要となる Compute Engine リソースを起動します。

    gcloud compute --project=${PROJECT_ID} instances create transformer-tutorial \
    --zone=us-central1-a  \
    --machine-type=n1-standard-16  \
    --image-family=torch-xla \
    --image-project=ml-images  \
    --boot-disk-size=200GB \
    --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. 新しい Compute Engine インスタンスに接続します。

    gcloud compute ssh transformer-tutorial --zone=us-central1-a
    

Cloud TPU リソースを起動する

  1. Compute Engine 仮想マシンから、次のコマンドを使用して Cloud TPU リソースを起動します。

    (vm) $ gcloud compute tpus create transformer-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.6 \
    --accelerator-type=v3-8
    
  2. Cloud TPU リソースの IP アドレスを識別します。

    (vm) $ gcloud compute tpus list --zone=us-central1-a
    

    IP アドレスは NETWORK_ENDPOINTS 列の下に表示されます。この IP アドレスは、PyTorch 環境を作成して構成するときに必要になります。

データのダウンロード

  1. モデルデータを格納する pytorch-tutorial-data ディレクトリを作成します。

    (vm) $ mkdir $HOME/pytorch-tutorial-data
    
  2. pytorch-tutorial-data ディレクトリに移動します。

    (vm) $ cd $HOME/pytorch-tutorial-data
    
  3. モデルデータをダウンロードします。

    (vm) $ wget https://dl.fbaipublicfiles.com/fairseq/data/wmt18_en_de_bpej32k.zip
    
  4. データを抽出します。

    (vm) $ sudo apt-get install unzip && \
    unzip wmt18_en_de_bpej32k.zip
    

PyTorch 環境を作成および構成する

  1. conda 環境を開始します。

    (vm) $ conda activate torch-xla-1.6
    
  2. Cloud TPU リソースの環境変数を構成します。

    (vm) $ export TPU_IP_ADDRESS=ip-address; \
    export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    

モデルのトレーニング

モデルをトレーニングするには、次のスクリプトを実行します。

(vm) $ python /usr/share/torch-xla-1.6/tpu-examples/deps/fairseq/train.py \
  $HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --input_shapes 128x64 \
  --num_cores=8 \
  --metrics_debug \
  --log_steps=100

出力結果を確認する

トレーニング ジョブが完了の後、モデル チェックポイントが次のディレクトリに置かれます。

$HOME/checkpoints

クリーンアップ

作成したリソースを使用した後、アカウントに不要な請求が発生しないようにクリーンアップを行います。

  1. Compute Engine インスタンスとの接続を切断していない場合は切断します。

    (vm) $ exit
    

    プロンプトが user@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud Shell で、gcloud コマンドライン ツールを使用して、Compute Engine インスタンスを削除します。

    $  gcloud compute instances delete transformer-tutorial  --zone=us-central1-a
    
  3. gcloud コマンドライン ツールを使用して、Cloud TPU リソースを削除します。

    $  gcloud compute tpus delete transformer-tutorial --zone=us-central1-a
    

次のステップ

次のように PyTorch colabs を試す