Diffusionsmodelle mit PyTorch trainieren


In dieser Anleitung wird gezeigt, wie Sie Diffusionsmodelle auf TPUs mit PyTorch Lightning und Pytorch XLA trainieren.

Lernziele

  • Cloud TPU erstellen
  • PyTorch Lightning installieren
  • Diffusion-Repository klonen
  • Imagenette-Dataset vorbereiten
  • Trainingsskript ausführen

Kosten

In diesem Dokument verwenden Sie die folgenden kostenpflichtigen Komponenten von Google Cloud:

  • Compute Engine
  • Cloud TPU

Mit dem Preisrechner können Sie eine Kostenschätzung für Ihre voraussichtliche Nutzung vornehmen. Neuen Google Cloud-Nutzern steht möglicherweise eine kostenlose Testversion zur Verfügung.

Hinweise

Bevor Sie mit dieser Anleitung beginnen, prüfen Sie, ob Ihr Google Cloud-Projekt ordnungsgemäß eingerichtet ist.

  1. Melden Sie sich bei Ihrem Google Cloud-Konto an. Wenn Sie mit Google Cloud noch nicht vertraut sind, erstellen Sie ein Konto, um die Leistungsfähigkeit unserer Produkte in der Praxis sehen und bewerten zu können. Neukunden erhalten außerdem ein Guthaben von 300 $, um Arbeitslasten auszuführen, zu testen und bereitzustellen.
  2. Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Google Cloud-Projekt aus oder erstellen Sie eines.

    Zur Projektauswahl

  3. Die Abrechnung für das Google Cloud-Projekt muss aktiviert sein.

  4. Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Google Cloud-Projekt aus oder erstellen Sie eines.

    Zur Projektauswahl

  5. Die Abrechnung für das Google Cloud-Projekt muss aktiviert sein.

  6. In dieser Anleitung werden kostenpflichtige Komponenten der Google Cloud verwendet. Rufen Sie die Seite mit den Cloud TPU-Preisen auf, um Ihre Kosten abzuschätzen. Denken Sie daran, nicht mehr benötigte Ressourcen zu bereinigen, um unnötige Kosten zu vermeiden.

Cloud TPU erstellen

Diese Anleitung funktioniert sowohl auf TPUs mit einem einzelnen Host als auch mit mehreren Hosts. In dieser Anleitung wird v4-128 verwendet, die Funktionsweise ist aber für alle Beschleunigergrößen ähnlich.

Richten Sie einige Umgebungsvariablen ein, um die Befehle nutzerfreundlicher zu machen.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Erstellen Sie eine Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Erforderliche Software installieren

  1. Installieren Sie die erforderlichen Pakete zusammen mit dem neuesten Release von PyTorch/XLA v2.2.0.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  2. Korrigieren Sie die Quelldateien, damit sie mit Torch 2.2 und höher kompatibel sind.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py
    sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
    
  3. Laden Sie Imagenette (eine kleinere Version des Imagenet-Datasets) herunter und verschieben Sie es in das entsprechende Verzeichnis.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
    
  4. Laden Sie das vortrainierte Modell der ersten Phase herunter.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"
    

Modell trainieren

Führen Sie das Training mit dem folgenden Befehl aus:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Bereinigen

Führen Sie eine Bereinigung durch, damit Ihr Konto nach der Verwendung der von Ihnen erstellten Ressourcen nicht unnötig belastet wird:

Löschen Sie die Cloud TPU-Ressource mit der Google Cloud CLI.

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

Nächste Schritte

Testen Sie die PyTorch Colabs: