Entrenamiento de modelos de difusión con PyTorch


En este instructivo, se muestra cómo entrenar modelos de difusión en TPU con PyTorch Lightning y Pytorch XLA.

Objetivos

  • Crear una Cloud TPU
  • Instala PyTorch Lightning
  • Clona el repositorio de difusión
  • Prepara el conjunto de datos de Imagenette
  • Ejecuta la secuencia de comandos de entrenamiento

Costos

En este documento, usarás los siguientes componentes facturables de Google Cloud:

  • Compute Engine
  • Cloud TPU

Para generar una estimación de costos en función del uso previsto, usa la calculadora de precios. Es posible que los usuarios nuevos de Google Cloud califiquen para obtener una prueba gratuita.

Antes de comenzar

Antes de comenzar este instructivo, verifica que tu proyecto de Google Cloud esté configurado correctamente.

  1. Accede a tu cuenta de Google Cloud. Si eres nuevo en Google Cloud, crea una cuenta para evaluar el rendimiento de nuestros productos en situaciones reales. Los clientes nuevos también obtienen $300 en créditos gratuitos para ejecutar, probar y, además, implementar cargas de trabajo.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.

  6. En esta explicación, se usan componentes facturables de Google Cloud. Consulta la página de precios de Cloud TPU para calcular los costos. Asegúrate de limpiar los recursos que crees cuando hayas terminado de usarlos para evitar cargos innecesarios.

Crear una Cloud TPU

En este instructivo, se usa una v4-8, pero funciona de manera similar en todos los tamaños de acelerador en un solo host.

Configura algunas variables de entorno para que los comandos sean más fáciles de usar.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Crea una Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Instala el software obligatorio

  1. Instala los paquetes necesarios junto con la versión más reciente de PyTorch/XLA v2.2.0.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -r requirements.txt
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  2. Se corrigieron los archivos de origen para que sean compatibles con Torch 2.2 y versiones posteriores.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py
    sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py" 
    
  3. Descarga Imagenette (una versión más pequeña del conjunto de datos de Imagenet) y muévela al directorio correspondiente.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
    
  4. Descarga la primera etapa del modelo previamente entrenado.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"
    

Entrenar el modelo

Ejecuta el entrenamiento con el siguiente comando. Ten en cuenta que se espera que el proceso de entrenamiento demore alrededor de 30 minutos en las versiones 4-8.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Limpia

Realiza una limpieza para evitar incurrir en cargos innecesarios en tu cuenta después de usar los recursos que creaste:

Usa Google Cloud CLI para borrar el recurso de Cloud TPU.

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

¿Qué sigue?

Prueba los siguientes colaboradores de PyTorch: