Como treinar modelos de difusão com o PyTorch


Neste tutorial, mostramos como treinar modelos de difusão em TPUs usando o PyTorch Lightning e Pytorch XLA.

Objetivos

  • Criar uma Cloud TPU
  • Instalar o PyTorch Lightning
  • Clonar o repositório de difusão
  • Preparar o conjunto de dados Imagenette
  • Executar o script de treinamento

Custos

Neste documento, você usará os seguintes componentes faturáveis do Google Cloud:

  • Compute Engine
  • Cloud TPU

Para gerar uma estimativa de custo baseada na projeção de uso deste tutorial, use a calculadora de preços. Novos usuários do Google Cloud podem estar qualificados para uma avaliação gratuita.

Antes de começar

Antes de começar o tutorial, verifique se o projeto do Google Cloud foi configurado corretamente.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. Este tutorial usa componentes faturáveis do Google Cloud. Consulte a página de preços da Cloud TPU para fazer uma estimativa dos custos. Para evitar cobranças desnecessárias, não se esqueça de apagar os recursos criados ao terminar de usá-los.

Criar uma Cloud TPU

Este tutorial usa um v4-8, mas funciona de maneira semelhante em todos os tamanhos de acelerador em um único host.

Configure algumas variáveis de ambiente para facilitar o uso dos comandos.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Crie uma Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Instalar o software necessário

  1. Instale os pacotes necessários com a versão v2.4.0 mais recente do PyTorch/XLA.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -r requirements.txt
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    pip install clip
    pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Os arquivos de origem foram corrigidos para serem compatíveis com a Toch 2.2 e versões mais recentes.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py
    sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
  3. Faça o download do Imagenette (uma versão menor do conjunto de dados do Imagenet) e movê-lo para o diretório apropriado.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. Faça o download do modelo pré-treinado da primeira fase.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

Treine o modelo

Execute o treinamento com o comando a seguir. O processo de treinamento leva cerca de 30 minutos na v4-8.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Limpar

Execute uma limpeza para evitar cobranças desnecessárias na sua conta depois de usar os recursos criados:

Use a CLI do Google Cloud para excluir o recurso do Cloud TPU.

  $  gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
  

A seguir