En este instructivo, se muestra cómo entrenar modelos de difusión en TPU con PyTorch Lightning y Pytorch XLA.
Objetivos
- Crear una Cloud TPU
- Instala PyTorch Lightning
- Clona el repositorio de difusión
- Prepara el conjunto de datos Imagenette
- Ejecuta la secuencia de comandos de entrenamiento
Costos
En este documento, usarás los siguientes componentes facturables de Google Cloud:
- Compute Engine
- Cloud TPU
Para generar una estimación de costos en función del uso previsto, usa la calculadora de precios.
Antes de comenzar
Antes de comenzar este instructivo, verifica que tu proyecto de Google Cloud esté configurado correctamente.
- Accede a tu cuenta de Google Cloud. Si eres nuevo en Google Cloud, crea una cuenta para evaluar el rendimiento de nuestros productos en situaciones reales. Los clientes nuevos también obtienen $300 en créditos gratuitos para ejecutar, probar y, además, implementar cargas de trabajo.
-
En la página del selector de proyectos de la consola de Google Cloud, selecciona o crea un proyecto de Google Cloud.
-
Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.
-
En la página del selector de proyectos de la consola de Google Cloud, selecciona o crea un proyecto de Google Cloud.
-
Asegúrate de que la facturación esté habilitada para tu proyecto de Google Cloud.
En esta explicación, se usan componentes facturables de Google Cloud. Consulta la página de precios de Cloud TPU para calcular los costos. Asegúrate de limpiar los recursos que crees cuando hayas terminado de usarlos para evitar cargos innecesarios.
Crear una Cloud TPU
Estas instrucciones funcionan en TPU de host único y de varios hosts. En este instructivo, se usa una versión v4-128, pero funciona de manera similar en todos los tamaños de acelerador.
Configura algunas variables de entorno para facilitar el uso de los comandos.
export ZONE=us-central2-b export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v4-128 export RUNTIME_VERSION=tpu-ubuntu2204-base export TPU_NAME=your_tpu_name
Crea una Cloud TPU.
gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION} \ --subnetwork=tpusubnet
Instala el software requerido
Instala los paquetes necesarios junto con la última versión v2.2.0 de PyTorch/XLA.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=us-central2-b \ --worker=all \ --command="sudo apt-get update -y && sudo apt-get install libgl1 -y git clone https://github.com/pytorch-tpu/stable-diffusion.git cd stable-diffusion pip install -e . pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U pip install clip pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
Se corrigieron los archivos de origen para que sean compatibles con la linterna 2.2 y versiones posteriores.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=us-central2-b \ --worker=all \ --command="cd ~/stable-diffusion/ sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
Descarga Imagenette (una versión más pequeña del conjunto de datos de Imagenet) y muévela al directorio correspondiente.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xf imagenette2.tgz mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
Descarga el modelo previamente entrenado de la primera etapa.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="cd ~/stable-diffusion/ wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip cd models/first_stage_models/vq-f8/ unzip -o model.zip"
Entrenar el modelo
Ejecuta el entrenamiento con el siguiente comando:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone us-central2-b \ --worker=all \ --command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"
Limpia
Realiza una limpieza para evitar incurrir en cargos innecesarios en tu cuenta después de usar los recursos que creaste:
Usa Google Cloud CLI para borrar el recurso de Cloud TPU.
$ gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
¿Qué sigue?
Prueba los siguientes colaboradores de PyTorch:
- Comienza a usar PyTorch en Cloud TPU
- Entrenamiento de MNIST en TPU
- Entrenamiento de ResNet18 en TPU con el conjunto de datos de Cifar10
- Inferencia con el modelo ResNet50 previamente entrenado
- Transferencia de estilo neuronal rápida
- Capacitación de varios núcleos de AlexCre en Fashion MNIST
- Capacitación de núcleo único de AlexNet en Fashion MNIST