Melatih model difusi dengan PyTorch


Tutorial ini menunjukkan cara melatih model difusi pada TPU menggunakan PyTorch Lightning dan Pytorch XLA.

Tujuan

  • Buat Cloud TPU
  • Instal PyTorch Lightning
  • Meng-clone repositori difusi
  • Menyiapkan set data Imagenette
  • Menjalankan skrip pelatihan

Biaya

Dalam dokumen ini, Anda menggunakan komponen Google Cloud yang dapat ditagih berikut:

  • Compute Engine
  • Cloud TPU

Untuk membuat perkiraan biaya berdasarkan proyeksi penggunaan Anda, gunakan kalkulator harga. Pengguna baru Google Cloud mungkin memenuhi syarat untuk mendapatkan uji coba gratis.

Sebelum memulai

Sebelum memulai tutorial ini, pastikan project Google Cloud Anda sudah disiapkan dengan benar.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. Di konsol Google Cloud, pada halaman pemilih project, pilih atau buat project Google Cloud.

    Buka pemilih project

  3. Make sure that billing is enabled for your Google Cloud project.

  4. Di konsol Google Cloud, pada halaman pemilih project, pilih atau buat project Google Cloud.

    Buka pemilih project

  5. Make sure that billing is enabled for your Google Cloud project.

  6. Panduan ini menggunakan komponen Google Cloud yang dapat ditagih. Lihat halaman harga Cloud TPU untuk memperkirakan biaya Anda. Pastikan untuk membersihkan resource yang Anda buat setelah selesai menggunakannya untuk menghindari biaya yang tidak perlu.

Buat Cloud TPU

Petunjuk ini berfungsi pada TPU host tunggal dan multi-host. Tutorial ini menggunakan v4-128, tetapi cara kerjanya mirip dengan semua ukuran akselerator.

Siapkan beberapa variabel lingkungan untuk mempermudah penggunaan perintah.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Membuat Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Instal software yang diperlukan

  1. Instal paket yang diperlukan beserta rilis terbaru PyTorch/XLA v2.2.0.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Memperbaiki file sumber agar kompatibel dengan flash 2.2 dan yang lebih baru.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py
    sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
  3. Download Imagenette (versi set data Imagenet yang lebih kecil) lalu pindahkan ke direktori yang sesuai.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. Download model terlatih tahap pertama.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

Melatih model

Jalankan pelatihan dengan perintah berikut:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Pembersihan

Lakukan pembersihan untuk menghindari timbulnya biaya yang tidak perlu pada akun Anda setelah menggunakan resource yang Anda buat:

Gunakan Google Cloud CLI untuk menghapus resource Cloud TPU.

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

Langkah selanjutnya

Coba kolaborasi PyTorch: