Addestramento di modelli di diffusione con PyTorch


Questo tutorial mostra come addestrare modelli di diffusione sulle TPU utilizzando PyTorch Lightning e Pytorch XLA.

Obiettivi

  • Crea una Cloud TPU
  • Installa PyTorch Lightning
  • Clona il repository di diffusione
  • Prepara il set di dati Imagenette
  • Esegui lo script di addestramento

Costi

In questo documento utilizzi i seguenti componenti fatturabili di Google Cloud:

  • Compute Engine
  • Cloud TPU

Per generare una stima dei costi basata sull'utilizzo previsto, utilizza il Calcolatore prezzi. I nuovi utenti di Google Cloud potrebbero essere idonei per una prova gratuita.

Prima di iniziare

Prima di iniziare questo tutorial, controlla che il progetto Google Cloud sia configurato correttamente.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. Questa procedura dettagliata utilizza i componenti fatturabili di Google Cloud. Consulta la pagina Prezzi per Cloud TPU per stimare i costi. Assicurati di pulire che crei una volta terminato per evitare inutili addebiti.

Crea una Cloud TPU

Questo tutorial utilizza una versione 4-8, ma funziona in modo simile su tutte le dimensioni dell'acceleratore in un singolo host.

Configura alcune variabili di ambiente per semplificare l'utilizzo dei comandi.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Creare una Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Installa il software richiesto

  1. Installa i pacchetti richiesti insieme alla release più recente di PyTorch/XLA v2.4.0.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -r requirements.txt
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    pip install clip
    pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Correggi i file sorgente affinché siano compatibili con Torch 2.2 e versioni successive.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py
    sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
  3. Scarica Imagenette (una versione più piccola del set di dati Imagenet) e lo spostiamo nella directory appropriata.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. Scarica il modello preaddestrato della prima fase.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

Addestra il modello

Esegui l'addestramento con il seguente comando. Tieni presente che il processo di addestramento dovrebbe richiedere circa 30 minuti nella versione 4-8.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Esegui la pulizia

Esegui una pulizia per evitare che al tuo account vengano addebitati costi inutili dopo l'utilizzo le risorse che hai creato:

Utilizza Google Cloud CLI per eliminare la risorsa Cloud TPU.

  $  gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
  

Passaggi successivi