Entraîner des modèles de diffusion avec PyTorch


Ce tutoriel explique comment entraîner des modèles de diffusion sur des TPU à l'aide de PyTorch Lightning et de PyTorch XLA.

Objectifs

  • Créer une instance Cloud TPU
  • Installer PyTorch Lightning
  • Cloner le dépôt diffusion
  • Préparer l'ensemble de données Imagenette
  • Exécuter le script d'entraînement

Coûts

Dans ce document, vous utilisez les composants facturables suivants de Google Cloud :

  • Compute Engine
  • Cloud TPU

Obtenez une estimation des coûts en fonction de votre utilisation prévue à l'aide du simulateur de coût. Les nouveaux utilisateurs de Google Cloud peuvent bénéficier d'un essai gratuit.

Avant de commencer

Avant de commencer ce tutoriel, vérifiez que votre projet Google Cloud est correctement configuré.

  1. Connectez-vous à votre compte Google Cloud. Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $ de crédits gratuits pour exécuter, tester et déployer des charges de travail.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. Ce tutoriel utilise des composants facturables de Google Cloud. Consultez la grille tarifaire de Cloud TPU pour estimer vos coûts. Veillez à nettoyer les ressources que vous avez créées lorsque vous avez terminé, afin d'éviter des frais inutiles.

Créer une instance Cloud TPU

Ce tutoriel utilise la version 4-8, mais fonctionne de la même manière sur tous les accélérateurs. tailles sur un seul hôte.

Configurez des variables d'environnement pour faciliter l'utilisation des commandes.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Créez un Cloud TPU.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

Installer le logiciel requis

  1. Installez les packages requis ainsi que la dernière version de PyTorch/XLA v2.4.0.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -r requirements.txt
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    pip install clip
    pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Correction des fichiers sources pour qu'ils soient compatibles avec Torch 2.2 et les versions ultérieures.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py
    sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
  3. Téléchargez Imagenette (une version plus petite de l'ensemble de données ImageNet) et déplacez-la dans le répertoire approprié.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. Téléchargez le modèle pré-entraîné de la première étape.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

Entraîner le modèle

Exécutez l'entraînement à l'aide de la commande suivante. Notez que le processus d'entraînement devrait prendre environ 30 minutes sur la version 4-8.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

Effectuer un nettoyage

Pour éviter une facturation inutile sur votre compte, effectuez un nettoyage des ressources que vous avez créées :

Utilisez la CLI Google Cloud pour supprimer la ressource Cloud TPU.

  $  gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
  

Étape suivante