PyTorch를 사용한 분산 모델 학습


이 튜토리얼에서는 PyTorch Lightning 및 Pytorch XLA를 사용하여 TPU에서 분산 모델을 학습시키는 방법을 보여줍니다.

목표

  • Cloud TPU 만들기
  • PyTorch Lightning 설치
  • 분산 저장소 클론
  • Imagenette 데이터 세트 준비
  • 학습 스크립트 실행

비용

이 문서에서는 비용이 청구될 수 있는 다음과 같은 Google Cloud 구성요소를 사용합니다.

  • Compute Engine
  • Cloud TPU

프로젝트 사용량을 기준으로 예상 비용을 산출하려면 가격 계산기를 사용하세요. Google Cloud를 처음 사용하는 사용자는 무료 체험판을 사용할 수 있습니다.

시작하기 전에

이 튜토리얼을 시작하기 전에 Google Cloud 프로젝트가 올바르게 설정되었는지 확인하세요.

  1. Google Cloud 계정에 로그인합니다. Google Cloud를 처음 사용하는 경우 계정을 만들고 Google 제품의 실제 성능을 평가해 보세요. 신규 고객에게는 워크로드를 실행, 테스트, 배포하는 데 사용할 수 있는 $300의 무료 크레딧이 제공됩니다.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Google Cloud 프로젝트에 결제가 사용 설정되어 있는지 확인합니다.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Google Cloud 프로젝트에 결제가 사용 설정되어 있는지 확인합니다.

  6. 이 둘러보기에서는 비용이 청구될 수 있는 Google Cloud 구성요소를 사용합니다. 예상 비용은 Cloud TPU 가격 책정 페이지에서 확인하세요. 리소스 사용을 마쳤으면 불필요한 비용이 청구되지 않도록 생성한 리소스를 삭제하세요.

Cloud TPU 만들기

이 안내는 단일 호스트 및 멀티 호스트 TPU 모두에 적용됩니다. 튜토리얼에서는 v4-128을 사용하지만 모든 가속기 크기에서 유사하게 작동합니다.

명령어를 더 쉽게 사용할 수 있도록 몇 가지 환경 변수를 설정합니다.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-128
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Cloud TPU를 만듭니다.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

필수 소프트웨어 설치

  1. PyTorch/XLA 최신 출시 버전 v2.2.0과 함께 필수 패키지를 설치합니다.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install clip
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. torch 2.2 이상과 호환되도록 소스 파일을 수정합니다.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    sed -i \'s/from torch._six import string_classes/string_classes = (str, bytes)/g\' src/taming-transformers/taming/data/utils.py
    sed -i \'s/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g\' main_tpu.py"
  3. Imagenette(Imagenet 데이터 세트의 더 작은 버전)를 다운로드하고 적절한 디렉터리로 이동합니다.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. 1단계 사전 학습된 모델을 다운로드합니다.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone us-central2-b \
    --worker=all \
    --command="cd ~/stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

모델 학습

다음 명령어를 사용하여 학습을 실행합니다.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone us-central2-b \
--worker=all \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

삭제

만든 리소스를 사용한 후에는 계정에 불필요한 비용이 청구되지 않도록 삭제를 수행하세요.

Google Cloud CLI를 사용하여 Cloud TPU 리소스를 삭제합니다.

  $  gcloud compute tpus delete diffusion-tutorial --zone=us-central2-b
  

다음 단계

PyTorch Colab 사용: