Cloud TPU での PyTorch を使用した Wav2Vec2 の事前トレーニング

このチュートリアルでは、PyTorch を使用して Cloud TPU デバイスで FairSeq の Wav2Vec2 モデルを事前トレーニングする方法を説明します。PyTorch と ImageNet データセットを使用する、TPU 用に最適化されたその他のイメージ分類モデルにも、同じパターンを適用できます。

このチュートリアルのモデルは、wav2vec 2.0: Framework for Self-Supervised Learning of Speech Representations の論文に基づいています。

目標

  • PyTorch 環境を作成して構成します。
  • オープンソースの LibriSpeech データをダウンロードします。
  • トレーニング ジョブを実行します。

料金

このチュートリアルでは、Google Cloud の課金対象となる以下のコンポーネントを使用します。

  • Compute Engine
  • Cloud TPU

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを出すことができます。新しい Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。

  1. Google Cloud Console のプロジェクト セレクタ ページで、Google Cloud プロジェクトを選択または作成します。 注: この手順で作成するリソースを保持し続けない場合は、既存のプロジェクトを選択するのではなく、新しいプロジェクトを作成してください。チュートリアルが終了した後、そのプロジェクトを削除して、プロジェクトに関連するすべてのリソースを削除できます。
  2. プロジェクトを選択するページに移動し、Cloud プロジェクトに対して課金が有効になっていることを確認します。プロジェクトに対して課金が有効になっていることを確認する方法を確かめる
  1. Google Cloud アカウントにログインします。Google Cloud を初めて使用する場合は、アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  2. Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  3. Cloud プロジェクトに対して課金が有効になっていることを確認します。プロジェクトに対して課金が有効になっていることを確認する方法を学習する

  4. このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。

Compute Engine インスタンスを設定する

  1. Cloud Shell ウィンドウを開きます。

    Cloud Shell を開く

  2. プロジェクト ID の変数を作成します。

    export PROJECT_ID=project-id
    
  3. Cloud TPU を作成するプロジェクトを使用するように gcloud コマンドライン ツールを構成します。

    gcloud config set project ${PROJECT_ID}
    

    このコマンドを新しい Cloud Shell VM で初めて実行すると、Authorize Cloud Shell ページが表示されます。ページの下部にある [Authorize] をクリックして、gcloud に認証情報を使用した GCP API の呼び出しを許可します。

  4. Cloud Shell から、このチュートリアルに必要な Compute Engine リソースを起動します。

    gcloud compute instances create wav2vec2-tutorial \
      --zone=us-central1-a \
      --machine-type=n1-standard-64 \
      --image-family=torch-xla \
      --image-project=ml-images  \
      --boot-disk-size=200GB \
      --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. 新しい Compute Engine インスタンスに接続します。

    gcloud compute ssh wav2vec2-tutorial --zone=us-central1-a
    

Cloud TPU リソースを起動する

  1. Compute Engine 仮想マシンで、PyTorch のバージョンを設定します。

    (vm) $ export PYTORCH_VERSION=1.8.1
    
  2. 次のコマンドを使用して Cloud TPU リソースを起動します。

    (vm) $ gcloud compute tpus create w2v2-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.8 \
    --accelerator-type=v3-8
    
  3. Cloud TPU リソースの IP アドレスを識別します。

    (vm) $ gcloud compute tpus list --zone=us-central1-a
    

PyTorch 環境を作成および構成する

  1. conda 環境を開始します。

    (vm) $ conda activate torch-xla-1.8.1
    
  2. Cloud TPU リソースの環境変数を構成します。

    (vm) $ export TPU_IP_ADDRESS=ip-address
    
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    

データをダウンロードして準備する

このタスクで使用できる代替データセットについては、OpenSLR ウェブサイトをご覧ください。このチュートリアルでは、前処理時間が最も短いことから、dev-clean.tar.gz を使用します。

  1. Wav2Vec2 には、必要な依存関係がいくつかあります。ここでは、それをインストールします。

    (vm) $ pip install omegaconf hydra-core soundfile
    (vm) $ sudo apt-get install libsndfile-dev
    
  2. データセットをダウンロードします。

    (vm) $ curl https://www.openslr.org/resources/12/dev-clean.tar.gz --output dev-clean.tar.gz
    
  3. 圧縮ファイルを展開します。ファイルは、LibriSpeech フォルダに保存されます。

    (vm) $ tar xf dev-clean.tar.gz
    
  4. 最新の fairseq モデルをダウンロードしてインストールします。

    (vm) $ git clone --recursive https://github.com/pytorch/fairseq.git
    (vm) $ cd fairseq
    (vm) $ pip install --editable .
    (vm) $ cd -
  5. データセットを準備します。 このスクリプトでは、元データ(LibriSpeech/ の下)へのポインタを持つ manifest/ という名前のフォルダを作成します。

    (vm) $ python fairseq/examples/wav2vec/wav2vec_manifest.py LibriSpeech/ --dest manifest/

トレーニング ジョブを実行する

  1. LibriSpeech データでモデルを実行します。このスクリプトの実行には、約 2 時間かかります。

    (vm) $ OMP_NUM_THREADS=1 python fairseq/train.py \
     manifest/ \
     --num-batch-buckets 3 \
     --tpu \
     --max-sentences 4 \
     --max-sentences-valid 4 \
     --required-batch-size-multiple 4 \
     --distributed-world-size 8 \
     --distributed-port 12597 \
     --update-freq 1 \
     --enable-padding \
     --log-interval 5 \
     --num-workers 6 \
     --task audio_pretraining \
     --criterion wav2vec \
     --arch wav2vec2 \
     --log-keys  "['prob_perplexity','code_perplexity','temp']" \
     --quantize-targets \
     --extractor-mode default \
     --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' \
     --final-dim 256 \
     --latent-vars 320 \
     --latent-groups 2 \
     --latent-temp '(2,0.5,0.999995)' \
     --infonce \
     --optimizer adam \
     --adam-betas '(0.9,0.98)' \
     --adam-eps 1e-06 \
     --lr-scheduler polynomial_decay \
     --total-num-update 400000 \
     --lr 0.0005 \
     --warmup-updates 32000 \
     --mask-length 10 \
     --mask-prob 0.65 \
     --mask-selection static \
     --mask-other 0 \
     --mask-channel-prob 0.1 \
     --encoder-layerdrop 0 \
     --dropout-input 0.0 \
     --dropout-features 0.0 \
     --feature-grad-mult 0.1 \
     --loss-weights '[0.1, 10]' \
     --conv-pos 128 \
     --conv-pos-groups 16 \
     --num-negatives 100 \
     --cross-sample-negatives 0 \
     --max-sample-size 250000 \
     --min-sample-size 32000 \
     --dropout 0.0 \
     --attention-dropout 0.0 \
     --weight-decay 0.01 \
     --max-tokens 1400000 \
     --max-epoch 10 \
     --save-interval 2 \
     --skip-invalid-size-inputs-valid-test \
     --ddp-backend no_c10d \
     --log-format simple

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを保存して個々のリソースを削除します。

  1. Compute Engine インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit
    

    プロンプトが user@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud Shell で、gcloud コマンドライン ツールを使用して、Compute Engine VM インスタンスと TPU を削除します。

    $ gcloud compute tpus execution-groups delete w2v2-tutorial --zone=us-central1-a
    

次のステップ

Cloud TPU Pod へのスケーリング

このチュートリアルの事前トレーニング タスクを強力な Cloud TPU Pod にスケールするには、Cloud TPU Pod での PyTorch モデルのトレーニング チュートリアルをご覧ください。

次のように PyTorch colabs を試す