このチュートリアルでは、PyTorch Lightning と Pytorch XLA を使用して TPU で拡散モデルをトレーニングする方法について説明します。
目標
- Cloud TPU の作成
- PyTorch Lightning をインストールする
- 拡散リポジトリのクローンを作成する
- Imagenette データセットを準備する
- トレーニング スクリプトを実行します
費用
このドキュメントでは、Google Cloud の次の課金対象のコンポーネントを使用します。
- Compute Engine
- Cloud TPU
料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。
始める前に
このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。
Cloud TPU の作成
以下の手順は、単一ホスト TPU とマルチホスト TPU の両方で機能します。このチュートリアルでは v4-128 を使用しますが、すべてのアクセラレータ サイズで同様に機能します。
コマンドを使いやすくするために、いくつかの環境変数を設定します。
export ZONE=us-central2-b export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v4-128 export RUNTIME_VERSION=tpu-ubuntu2204-base export TPU_NAME=your_tpu_name
Cloud TPU を作成します。
gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION} \ --subnetwork=tpusubnet
必要なソフトウェアのインストール
PyTorch/XLA の最新リリース v2.2.0 とともに必要なパッケージをインストールします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=us-central2-b \ --worker=all \ --command="sudo apt-get update -y && sudo apt-get install libgl1 -y git clone https://github.com/pytorch-tpu/stable-diffusion.git cd stable-diffusion pip install -e . pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U pip install clip pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
torch 2.2 以降との互換性を保つようにソースファイルを修正します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=us-central2-b \ --worker=all \ --command="cd ~/stable-diffusion/ sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
Imagenette(より小さなバージョンの Imagenet データセット)をダウンロードして、適切なディレクトリに移動します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xf imagenette2.tgz mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
第 1 段階の事前トレーニング済みモデルをダウンロードします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="cd ~/stable-diffusion/ wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip cd models/first_stage_models/vq-f8/ unzip -o model.zip"
モデルのトレーニング
次のコマンドを使用してトレーニングを実行します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"
クリーンアップ
作成したリソースを使用した後、アカウントに不要な請求が発生しないようにクリーンアップを行います。
Google Cloud CLI を使用して Cloud TPU リソースを削除します。
$ gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
次のステップ
次のように PyTorch colabs を試す
- Cloud TPU での PyTorch のスタートガイド
- TPU 上で MNIST をトレーニングする
- Cifar10 データセットを使用して TPU 上で ResNet18 をトレーニングする
- Pretrained ResNet50 モデルを使用して推論する
- 高速なニューラルスタイル変換
- Fashion MNIST を使用した AlexNet のマルチコア トレーニング
- Fashion MNIST を使用した AlexNet のシングルコア トレーニング