In dieser Anleitung wird gezeigt, wie Sie Diffusionsmodelle auf TPUs mit PyTorch Lightning trainieren und Pytorch XLA.
Lernziele
- Cloud TPU erstellen
- PyTorch Lightning installieren
- Diffusion-Repository klonen
- Imagenette-Dataset vorbereiten
- Trainingsskript ausführen
Kosten
In diesem Dokument verwenden Sie die folgenden kostenpflichtigen Komponenten von Google Cloud:
- Compute Engine
- Cloud TPU
Mit dem Preisrechner können Sie eine Kostenschätzung für Ihre voraussichtliche Nutzung vornehmen.
Hinweise
Bevor Sie mit dieser Anleitung beginnen, prüfen Sie, ob Ihr Google Cloud-Projekt ordnungsgemäß eingerichtet ist.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
In dieser Anleitung werden kostenpflichtige Komponenten der Google Cloud verwendet. Rufen Sie die Seite mit den Cloud TPU-Preisen auf, um Ihre Kosten abzuschätzen. Denken Sie daran, nicht mehr benötigte Ressourcen zu bereinigen, um unnötige Kosten zu vermeiden.
Cloud TPU erstellen
In dieser Anleitung wird ein v4-8 verwendet, aber die Anleitung funktioniert ähnlich für alle Beschleunigergrößen auf einem einzelnen Host.
Richten Sie einige Umgebungsvariablen ein, um die Verwendung der Befehle zu vereinfachen.
export ZONE=us-central2-b export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v4-8 export RUNTIME_VERSION=tpu-ubuntu2204-base export TPU_NAME=your_tpu_name
Erstellen Sie eine Cloud TPU.
gcloud compute tpus tpu-vm create ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION} \ --subnetwork=tpusubnet
Erforderliche Software installieren
Installieren Sie die erforderlichen Pakete zusammen mit dem neuesten Release 2.4.0 von PyTorch/XLA.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="sudo apt-get update -y && sudo apt-get install libgl1 -y git clone https://github.com/pytorch-tpu/stable-diffusion.git cd stable-diffusion pip install -r requirements.txt pip install -e . pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers pip install clip pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
Korrigieren Sie die Quelldateien so, dass sie mit Torch 2.2 und höher kompatibel sind.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
Imagenette herunterladen (eine kleinere Version des Imagenet-Datasets) und in das entsprechende Verzeichnis verschieben.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xf imagenette2.tgz mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
Laden Sie das vortrainierte Modell der ersten Phase herunter.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip cd models/first_stage_models/vq-f8/ unzip -o model.zip"
Modell trainieren
Führen Sie das Training mit dem folgenden Befehl aus. Der Trainingsprozess dauert voraussichtlich etwa 30 Minuten auf Version 4-8.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"
Bereinigen
Führen Sie eine Bereinigung durch, damit Ihr Konto nach der Verwendung der von Ihnen erstellten Ressourcen nicht unnötig belastet wird:
Verwenden Sie die Google Cloud CLI, um die Cloud TPU-Ressource zu löschen.
$ gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
Nächste Schritte
- Resnet50 auf Cloud TPU mit PyTorch trainieren
- Fehlerbehebung bei PyTorch auf TPUs
- Pytorch/XLA-Dokumentation