이 튜토리얼에서는 PyTorch Lightning 및 Pytorch XLA를 사용하여 TPU에서 분산 모델을 학습시키는 방법을 보여줍니다.
목표
- Cloud TPU 만들기
- PyTorch Lightning 설치
- 분산 저장소 클론
- Imagenette 데이터 세트 준비
- 학습 스크립트 실행
비용
이 문서에서는 비용이 청구될 수 있는 다음과 같은 Google Cloud 구성요소를 사용합니다.
- Compute Engine
- Cloud TPU
프로젝트 사용량을 기준으로 예상 비용을 산출하려면 가격 계산기를 사용하세요.
시작하기 전에
이 튜토리얼을 시작하기 전에 Google Cloud 프로젝트가 올바르게 설정되었는지 확인하세요.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
이 둘러보기에서는 비용이 청구될 수 있는 Google Cloud 구성요소를 사용합니다. 예상 비용은 Cloud TPU 가격 책정 페이지에서 확인하세요. 리소스 사용을 마쳤으면 불필요한 비용이 청구되지 않도록 생성한 리소스를 삭제하세요.
Cloud TPU 만들기
이 튜토리얼에서는 v4-8을 사용하지만 단일 호스트의 모든 가속기 크기에서 유사하게 작동합니다.
명령어를 더 쉽게 사용할 수 있도록 몇 가지 환경 변수를 설정합니다.
export ZONE=us-central2-b export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v4-8 export RUNTIME_VERSION=tpu-ubuntu2204-base export TPU_NAME=your_tpu_name
Cloud TPU를 만듭니다.
gcloud compute tpus tpu-vm create ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION} \ --subnetwork=tpusubnet
필수 소프트웨어 설치
PyTorch/XLA 최신 출시 버전 v2.4.0과 함께 필수 패키지를 설치합니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="sudo apt-get update -y && sudo apt-get install libgl1 -y git clone https://github.com/pytorch-tpu/stable-diffusion.git cd stable-diffusion pip install -r requirements.txt pip install -e . pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers pip install clip pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
torch 2.2 이상과 호환되도록 소스 파일을 수정합니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
Imagenette(Imagenet 데이터 세트의 더 작은 버전)를 다운로드하고 적절한 디렉터리로 이동합니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xf imagenette2.tgz mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data mv imagenette2/train/* ~/.cache/autoencoders/data/ILSVRC2012_train/data mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
1단계 사전 학습된 모델을 다운로드합니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="cd stable-diffusion/ wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip cd models/first_stage_models/vq-f8/ unzip -o model.zip"
모델 학습
다음 명령어를 사용하여 학습을 실행합니다. v4-8에서는 학습 프로세스가 약 30분 정도 걸릴 것으로 예상됩니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"
삭제
만든 리소스를 사용한 후에는 계정에 불필요한 비용이 청구되지 않도록 삭제를 수행하세요.
Google Cloud CLI를 사용하여 Cloud TPU 리소스를 삭제합니다.
$ gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
다음 단계