JAX를 사용하여 Cloud TPU VM에서 계산 실행
이 문서에서는 JAX 및 Cloud TPU 작업에 대해 간략히 안내합니다.
이 빠른 시작을 수행하기 전에 Google Cloud Platform 계정을 만들고 Google Cloud CLI를 설치하고 gcloud
명령어를 구성해야 합니다.
자세한 내용은 계정 및 Cloud TPU 프로젝트 설정을 참조하세요.
Google Cloud CLI 설치
Google Cloud CLI에는 Google Cloud 제품 및 서비스와 상호작용하기 위한 도구 및 라이브러리가 포함되어 있습니다. 자세한 내용은 Google Cloud CLI 설치를 참조하세요.
gcloud
명령어 구성
다음 명령어를 실행하여 Google Cloud 프로젝트를 사용하도록 gcloud
를 구성하고 TPU VM 미리보기에 필요한 구성요소를 설치합니다.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Cloud TPU API 사용 설정
Cloud Shell에서 다음
gcloud
명령어를 사용하여 Cloud TPU API를 사용 설정합니다. Google Cloud 콘솔에서도 사용 설정할 수 있습니다.$ gcloud services enable tpu.googleapis.com
다음 명령어를 실행하여 서비스 계정을 만듭니다.
$ gcloud beta services identity create --service tpu.googleapis.com
gcloud
를 사용하여 Cloud TPU VM 만들기
Cloud TPU VMs에서는 모델 및 코드가 TPU 호스트 머신에서 직접 실행됩니다. TPU 호스트에 직접 SSH로 연결합니다. TPU 호스트에서 직접 임의 코드를 실행하고, 패키지를 설치하고, 로그를 보고, 코드를 디버깅할 수 있습니다.
Cloud Shell 또는 Google Cloud CLI가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-8 \ --version=tpu-ubuntu2204-base
Cloud TPU VM에 연결
다음 명령어를 사용하여 TPU VM에 SSH로 연결합니다.
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b
필수 입력란
tpu_name
- 연결하려는 TPU VM의 이름입니다.
zone
- Cloud TPU를 만든 영역입니다.
Cloud TPU VM에 JAX 설치
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
시스템 확인
JAX가 TPU에 액세스할 수 있고 기본 작업을 실행할 수 있는지 확인합니다.
Python 3 인터프리터 시작:
(vm)$ python3
>>> import jax
사용 가능한 TPU 코어 수 표시:
>>> jax.device_count()
TPU 코어 수가 표시됩니다. v4 TPU를 사용하는 경우에는 4
여야 합니다. v2 또는 v3 TPU를 사용하는 경우에는 8
이어야 합니다.
간단한 계산 수행:
>>> jax.numpy.add(1, 1)
numpy add의 결과가 표시됩니다.
명령어에서 출력합니다.
Array(2, dtype=int32, weak_type=true)
Python 인터프리터 종료:
>>> exit()
TPU VM에서 JAX 코드 실행
이제 원하는 JAX 코드를 실행할 수 있습니다. flax 예시는 JAX에서 표준 ML 모드 실행을 시작할 수 있는 훌륭한 장소입니다. 예를 들어 기본 MNIST 컨볼루션 네트워크를 학습시키려면 다음 안내를 따르세요.
Flax 예시 종속 항목 설치
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
FLAX를 설치합니다.
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
FLAX MNIST 학습 스크립트 실행
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
스크립트가 데이터 세트를 다운로드하고 학습을 시작합니다. 스크립트 출력은 다음과 같아야 합니다.
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
삭제
TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.
Compute Engine 인스턴스에서 연결을 해제합니다.
(vm)$ exit
Cloud TPU를 삭제합니다.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central2-b
다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 나열되지 않았는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다.
$ gcloud compute tpus tpu-vm list \ --zone=us-central2-b
성능 참고사항
다음은 JAX에서 특히 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.
패딩
TPU에서 성능이 느려지는 가장 일반적인 원인 중 하나는 의도치 않은 패딩의 도입입니다.
- Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
- 행렬 곱셈 단위는 패딩 요구를 최소화하는 대규모 행렬의 쌍에서 효과가 가장 높습니다.
bfloat16 dtype
기본적으로 JAX에서 TPU의 행렬 곱셈에는 float32 누적의 bfloat16이 사용됩니다. 이것은 관련된 jax.numpy 함수 호출에서 정밀도 인수로 제어될 수 있습니다(matmul, dot, einsum 등). 특히 다음 내용이 해당됩니다.
precision=jax.lax.Precision.DEFAULT
: 혼합 bfloat16 정밀도 사용(가장 빠름)precision=jax.lax.Precision.HIGH
: 여러 MXU 패스를 사용하여 정밀도 향상precision=jax.lax.Precision.HIGHEST
: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성
JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16
으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 jax.numpy.array(x, dtype=jax.numpy.bfloat16)
입니다.
Colab에서 JAX 실행
Colab 노트북에서 JAX 코드를 실행할 때 Colab은 레거시 TPU 노드를 자동으로 만듭니다. TPU 노드에는 다른 아키텍처가 사용됩니다. 자세한 내용은 시스템 아키텍처를 참조하세요.
다음 단계
Cloud TPU에 대한 자세한 내용은 다음을 참조하세요.