This tutorial shows how to train diffusion models on TPUs using PyTorch Lightning and PyTorch/XLA.
Objectives
- Create a Cloud TPU
- Install PyTorch Lightning
- Clone the diffusion repo
- Prepare the Imagenette dataset
- Run the training script
Costs
In this document, you use the following billable components of Google Cloud:
- Compute Engine
- Cloud TPU
To generate a cost estimate based on your projected usage, use the pricing calculator.
Before you begin
Before starting this tutorial, check that your Google Cloud project is correctly set up.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
- Make sure that billing is enabled for your Google Cloud project.
This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.
Create a Cloud TPU
These instructions work on both single-host and multi-host TPUs. This tutorial uses a v4-128, but the workflow is the same for all accelerator sizes.
Set up some environment variables to make the commands easier to use.
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-vm-v4-pt-2.0
export TPU_NAME=your_tpu_name
Create a Cloud TPU.
gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet
Install required software
Install PyTorch Lightning from source. This is necessary because the stable release is not compatible with PyTorch/XLA 2.0.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="git clone https://github.com/Lightning-AI/lightning.git
cd lightning
pip install -e .
"
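Optionally, verify that PyTorch/XLA can see the TPU before moving on. The short Python snippet below is not part of any repository used in this tutorial; it assumes the torch and torch_xla packages that ship with the TPU VM runtime and simply runs a small matmul on an XLA device.

# sanity_check_xla.py -- optional helper, not part of any repo in this tutorial.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()           # first XLA (TPU) device visible to this process
print("XLA device:", device)

x = torch.randn(2, 2, device=device)
print((x @ x).cpu())               # run a small matmul on the TPU and copy it back

Run it on the TPU VM with PJRT_DEVICE=TPU python3 sanity_check_xla.py; if it prints a 2x2 tensor, the installation is working.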
Clone the diffusion repository adapted for TPUs.
In this version, logging output is reduced and the progress bar is disabled to minimize device-host communication, which is critical for TPU performance.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="git clone https://github.com/pytorch-tpu/stable-diffusion.git
cd stable-diffusion
pip install -e .
pip install -r requirements.txt
sudo apt-get update && sudo apt-get install libgl1 -y
"
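For reference, the kind of PyTorch Lightning Trainer settings that cut down device-host traffic look roughly like the sketch below. The flag names are standard Trainer arguments, but the adapted repository may wire these options up differently, so treat this as an illustration rather than the repository's actual code.

import pytorch_lightning as pl

# Illustrative Trainer configuration for TPU runs (not copied from the repo).
trainer = pl.Trainer(
    accelerator="tpu",           # run on the XLA/TPU backend
    devices="auto",              # use all TPU cores available to the process
    enable_progress_bar=False,   # the progress bar forces frequent host syncs
    log_every_n_steps=50,        # sparser logging means fewer device-host transfers
)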
Install the CLIP and taming-transformers packages.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
"
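If you want to confirm the editable installs succeeded, a quick import check like the following is enough; run it on a TPU VM worker. The print statements are only for inspection.

# Optional import check for the packages installed above.
import clip     # provided by openai/CLIP
import taming   # provided by CompVis/taming-transformers

print("clip installed at:", clip.__file__)
print("taming installed at:", taming.__file__)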
Replace the torch._six import in the taming-transformers package so that it is compatible with torch 2.0 and later.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/' src/taming-transformers/taming/data/utils.py
"
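The sed command is equivalent to hand-editing src/taming-transformers/taming/data/utils.py as shown below: the torch._six module was removed in torch 2.0, and the tuple (str, bytes) is a drop-in replacement for the string_classes it used to export.

# taming/data/utils.py, before the edit (breaks on torch >= 2.0):
#     from torch._six import string_classes
# after the edit, the same name is defined directly:
string_classes = (str, bytes)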
Download Imagenette (a smaller version of the ImageNet dataset) and move it to the directories the training configuration expects.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xf imagenette2.tgz
mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data
mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data
rm -r imagenette2/
"
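As an optional sanity check, you can count the images that landed in the directories the training configuration reads from. The snippet below is a hypothetical helper, not part of the repository; it assumes the Imagenette images keep their .JPEG extension.

# check_dataset.py -- optional, hypothetical helper for verifying the data layout.
from pathlib import Path

for split in ("ILSVRC2012_train", "ILSVRC2012_validation"):
    data_dir = Path.home() / ".cache" / "autoencoders" / "data" / split / "data"
    n_images = sum(1 for _ in data_dir.rglob("*.JPEG"))
    print(f"{split}: {n_images} images under {data_dir}")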
Download the pretrained first-stage model.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="wget -O ~/stable-diffusion/models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
cd ~/stable-diffusion/models/first_stage_models/vq-f8/
unzip -o model.zip
rm model.zip
"
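You can also check that the checkpoint loads cleanly before starting training. The filename model.ckpt below is an assumption about what the vq-f8 archive unpacks to; adjust the path if the extracted files are named differently.

# Optional: load the first-stage checkpoint on CPU and report its size.
from pathlib import Path
import torch

ckpt_path = Path.home() / "stable-diffusion" / "models" / "first_stage_models" / "vq-f8" / "model.ckpt"
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)   # latent-diffusion checkpoints usually wrap weights in "state_dict"
print(f"{len(state_dict)} tensors in the first-stage checkpoint")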
Train the model
Run the training with the following command:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="cd stable-diffusion
PJRT_DEVICE=TPU python3 main_tpu.py --train --no-test --base configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml 2>&1 | tee logs_lighting_git.txt
"
If you would like to use NUMA, first install numactl by running the following command.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone us-central2-b --worker=all --command="sudo apt-get update
sudo apt-get install numactl
"
When running the training script, add numactl --cpunodebind=0 before the Python command:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone us-central2-b --worker=all --command="cd stable-diffusion
PJRT_DEVICE=TPU numactl --cpunodebind=0 python3 main_tpu.py --train --no-test --base configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml
"
Clean up
When you have finished with the resources you created, clean them up to avoid incurring unnecessary charges to your account:
Use the Google Cloud CLI to delete the Cloud TPU resource.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--zone=us-central2-b
What's next
Try the PyTorch colabs:
- Getting Started with PyTorch on Cloud TPUs
- Training MNIST on TPUs
- Training ResNet18 on TPUs with Cifar10 dataset
- Inference with Pretrained ResNet50 Model
- Fast Neural Style Transfer
- MultiCore Training AlexNet on Fashion MNIST
- Single Core Training AlexNet on Fashion MNIST