Questo tutorial mostra come addestrare i modelli di diffusione sulle TPU utilizzando PyTorch Lightning e Pytorch XLA.
Obiettivi
- Crea una Cloud TPU
- Installa PyTorch Lightning
- Clona il repository di diffusione
- Prepara il set di dati Imagenette
- Esegui lo script di addestramento
Costi
In questo documento utilizzi i seguenti componenti fatturabili di Google Cloud:
- Compute Engine
- Cloud TPU
Per generare una stima dei costi basata sull'utilizzo previsto,
utilizza il Calcolatore prezzi.
Prima di iniziare
Prima di iniziare questo tutorial, controlla che il progetto Google Cloud sia configurato correttamente.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
Questa procedura dettagliata utilizza i componenti fatturabili di Google Cloud. Consulta la pagina Prezzi per Cloud TPU per stimare i costi. Una volta terminato il loro utilizzo, pulisci le risorse create per evitare addebiti non necessari.
Crea una Cloud TPU
Questo tutorial utilizza una versione 4-8, ma funziona in modo simile su tutte le dimensioni dell'acceleratore in un singolo host.
Configura alcune variabili di ambiente per semplificare l'utilizzo dei comandi.
export ZONE=us-central2-b export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v4-8 export RUNTIME_VERSION=tpu-ubuntu2204-base export TPU_NAME=your_tpu_name
Crea una Cloud TPU.
gcloud compute tpus tpu-vm create ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION} \ --subnetwork=tpusubnet
Installa il software necessario
Installa i pacchetti richiesti insieme alla release più recente di PyTorch/XLA 2.5.0.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="sudo apt-get update -y && sudo apt-get install libgl1 -y git clone https://github.com/pytorch-tpu/stable-diffusion.git cd stable-diffusion pip install -r requirements.txt pip install -e . pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers pip install clip pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
Correggi i file di origine in modo che siano compatibili con Torch 2.2 e versioni successive.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
Scarica ImageNet (una versione più piccola del set di dati ImageNet) e spostalo nella directory appropriata.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xf imagenette2.tgz mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
Scarica il modello preaddestrato della prima fase.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip cd models/first_stage_models/vq-f8/ unzip -o model.zip"
Addestra il modello
Esegui l'addestramento con il seguente comando. Tieni presente che la procedura di addestramento dovrebbe richiedere circa 30 minuti nella versione 4-8.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"
Esegui la pulizia
Esegui una pulizia per evitare addebiti non necessari sul tuo account dopo aver utilizzato le risorse che hai creato:
Utilizza Google Cloud CLI per eliminare la risorsa Cloud TPU.
$ gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
Passaggi successivi
- Addestramento di Resnet50 su Cloud TPU con PyTorch
- Risoluzione dei problemi di PyTorch su TPU
- Documentazione di Pytorch/XLA