PyTorch を使用した拡散モデルのトレーニング


このチュートリアルでは、PyTorch Lightning と Pytorch XLA を使用して TPU で拡散モデルをトレーニングする方法について説明します。

目標

  • Cloud TPU の作成
  • PyTorch Lightning をインストールする
  • 拡散リポジトリのクローンを作成する
  • Imagenette データセットを準備する
  • トレーニング スクリプトを実行します

費用

このドキュメントでは、Google Cloud の次の課金対象のコンポーネントを使用します。

  • Compute Engine
  • Cloud TPU

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。 新しい Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。

Cloud TPU の作成

以下の手順は、単一ホスト TPU とマルチホスト TPU の両方で機能します。このチュートリアルでは v4-128 を使用しますが、すべてのアクセラレータ サイズで同様に機能します。

コマンドを使いやすくするために、いくつかの環境変数を設定します。

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Cloud TPU を作成します。

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

必要なソフトウェアのインストール

  1. PyTorch/XLA の最新リリース v2.2.0 とともに必要なパッケージをインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. torch 2.2 以降との互換性を保つようにソースファイルを修正します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py
    sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
  3. Imagenette(より小さなバージョンの Imagenet データセット)をダウンロードして、適切なディレクトリに移動します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. 第 1 段階の事前トレーニング済みモデルをダウンロードします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

モデルのトレーニング

次のコマンドを使用してトレーニングを実行します。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

クリーンアップ

作成したリソースを使用した後、アカウントに不要な請求が発生しないようにクリーンアップを行います。

Google Cloud CLI を使用して Cloud TPU リソースを削除します。

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

次のステップ

次のように PyTorch colabs を試す