使用 PyTorch 训练扩散模型


本教程介绍如何使用 PyTorch Lightning 和 Pytorch XLA 在 TPU 上训练扩散模型。

目标

  • 创建 Cloud TPU
  • 安装 PyTorch Lightning
  • 克隆扩散代码库
  • 准备 Imagenette 数据集
  • 运行训练脚本

费用

在本文档中,您将使用 Google Cloud 的以下收费组件:

  • Compute Engine
  • Cloud TPU

您可使用价格计算器根据您的预计使用情况来估算费用。 Google Cloud 新用户可能有资格申请免费试用

准备工作

在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。

  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud 新手,请创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
  2. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  3. 确保您的 Google Cloud 项目已启用结算功能

  4. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  5. 确保您的 Google Cloud 项目已启用结算功能

  6. 本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。

创建 Cloud TPU

以下说明适用于单主机和多主机 TPU。本教程使用的是 v4-128,但它适用于所有加速器大小。

设置一些环境变量,以使命令更易于使用。

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

创建 Cloud TPU。

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

安装所需的软件

  1. 安装所需的软件包以及 PyTorch/XLA 最新版本 v2.2.0。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  2. 修复了源文件,以便与 Torch 2.2 及更高版本兼容。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py
    sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
    
  3. 下载 Imagenette(较小版本的 Imagenet 数据集)并将其移至适当的目录。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
    
  4. 下载第一阶段的预训练模型。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"
    

训练模型

使用以下命令运行训练:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

清理

使用您创建的资源后,请进行清理,以免您的账号产生不必要的费用:

使用 Google Cloud CLI 删除 Cloud TPU 资源。

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

后续步骤

试用 PyTorch Colab: