Pax를 사용하여 단일 호스트 TPU에서 학습


이 문서에서는 단일 호스트 TPU(v2-8, v3-8, v4-8)에서 Pax를 사용하는 방법을 간략히 소개합니다.

Pax는 JAX를 기반으로 머신러닝 실험을 구성하고 실행하는 프레임워크입니다. Pax는 기존 ML 프레임워크와 인프라 구성요소를 공유하고 모듈성을 위한 Praxis 모델링 라이브러리를 활용하여 대규모 ML을 간소화하는 데 중점을 둡니다.

목표

  • 학습용 TPU 리소스 설정
  • 단일 호스트 TPU에 Pax 설치
  • Pax를 사용한 변환기 기반 SPMD 모델 학습

시작하기 전에

다음 명령어를 실행하여 Cloud TPU 프로젝트를 사용하도록 gcloud를 구성하고 단일 호스트 TPU에서 Pax를 실행하는 모델을 학습시키는 데 필요한 구성요소를 설치합니다.

Google Cloud CLI 설치

Google Cloud CLI에는 Google Cloud CLI 제품 및 서비스와 상호작용하기 위한 도구 및 라이브러리가 포함되어 있습니다. 이전에 설치하지 않았으면 지금 Google Cloud CLI 설치의 안내에 따라 설치합니다.

gcloud 명령어 구성

(gcloud auth list를 실행하여 사용 가능한 계정을 확인합니다.)

$ gcloud config set account account

$ gcloud config set project project-id

Cloud TPU API 사용 설정

Cloud Shell에서 다음 gcloud 명령어를 사용하여 Cloud TPU API를 사용 설정합니다. Google Cloud 콘솔에서도 사용 설정할 수 있습니다.

$ gcloud services enable tpu.googleapis.com

다음 명령어를 실행하여 서비스 ID(서비스 계정)를 만듭니다.

$ gcloud beta services identity create --service tpu.googleapis.com

TPU VM 만들기

Cloud TPU VMs에서는 모델 및 코드가 TPU VM에서 직접 실행됩니다. TPU VM에 직접 SSH로 연결합니다. TPU VM에서 직접 임의 코드를 실행하고, 패키지를 설치하고, 로그를 보고, 코드를 디버깅할 수 있습니다.

Cloud Shell 또는 Google Cloud CLI가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.

계약의 가용성에 따라 zone을 설정합니다. 필요한 경우 TPU 리전 및 영역을 참조하세요.

accelerator-type 변수를 v2-8, v3-8 또는 v4-8로 설정합니다.

v2 및 v3 TPU 버전의 경우 version 변수를 tpu-vm-base로 설정하거나 v4 TPU의 경우 tpu-vm-v4-base로 설정합니다.

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

Google Cloud TPU VM에 연결

다음 명령어를 사용하여 TPU VM에 SSH로 연결합니다.

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

VM에 로그인하면 셸 프롬프트가 username@projectname에서 username@vm-name으로 변경됩니다.

Google Cloud TPU VM에 Pax 설치

다음 명령어를 사용하여 TPU VM에 Pax, JAX, libtpu를 설치합니다.

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

시스템 확인

JAX에 TPU 코어가 표시되는지 확인하여 모든 것이 올바르게 설치되었는지 테스트합니다.

(vm)$ python3 -c "import jax; print(jax.device_count())"

TPU 코어 수는 v2-8 또는 v3-8을 사용하는 경우 8, v4-8을 사용하는 경우 4여야 합니다.

TPU VM에서 Pax 코드 실행

이제 원하는 Pax 코드를 실행할 수 있습니다. lm_cloud 예시는 Pax에서 모델을 실행하기에 좋은 장소입니다. 예를 들어 다음 명령어는 합성 데이터에 대해 2B 매개변수 변환기 기반 SPMD 언어 모델을 학습시킵니다.

다음 명령어는 SPMD 언어 모델의 학습 출력을 보여줍니다. 약 20분 동안 300단계를 학습합니다.

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

v4-8 슬라이스의 경우 출력에 다음이 포함되어야 합니다.

손실 및 단계 시간

단계의 요약 텐서=step_# loss = loss
단계의 요약 텐서=step_#초당 x 단계

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.

Compute Engine 인스턴스에서 연결을 해제합니다.

(vm)$ exit

Cloud TPU를 삭제합니다.

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

다음 단계

Cloud TPU에 대한 자세한 내용은 다음을 참조하세요.