JAX를 사용하여 Cloud TPU VM에서 계산 실행
이 문서에서는 JAX 및 Cloud TPU 작업에 대해 간략히 안내합니다.
시작하기 전에
이 문서의 명령어를 실행하기 전에 Google Cloud계정을 만들고 Google Cloud CLI를 설치하고 gcloud 명령어를 구성해야 합니다. 자세한 내용은 Cloud TPU 환경 설정을 참조하세요.
gcloud를 사용하여 Cloud TPU VM 만들기
- 명령어를 더 쉽게 사용할 수 있도록 몇 가지 환경 변수를 정의합니다. - export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite - 환경 변수 설명- 변수 - 설명 - PROJECT_ID- Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다. - TPU_NAME- TPU의 이름입니다. - ZONE- TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요. - ACCELERATOR_TYPE- 가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요. - RUNTIME_VERSION- Cloud TPU 소프트웨어 버전입니다. 
- Cloud Shell 또는 Google Cloud CLI가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다. - $ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION 
Cloud TPU VM에 연결
다음 명령어를 사용하여 SSH를 통해 TPU VM에 연결합니다.
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
SSH를 사용하여 TPU VM에 연결할 수 없는 경우 TPU VM에 외부 IP 주소가 없기 때문일 수 있습니다. 외부 IP 주소가 없는 TPU VM에 액세스하려면 공개 IP 주소가 없는 TPU VM에 연결의 안내를 따르세요.
Cloud TPU VM에 JAX 설치
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
시스템 확인
JAX가 TPU에 액세스할 수 있고 기본 작업을 실행할 수 있는지 확인합니다.
- Python 3 인터프리터 시작: - (vm)$ python3- >>> import jax 
- 사용 가능한 TPU 코어 수 표시: - >>> jax.device_count() 
TPU 코어 수가 표시됩니다. 표시되는 코어 수는 사용 중인 TPU 버전에 따라 다릅니다. 자세한 내용은 TPU 버전을 참조하세요.
계산 수행
>>> jax.numpy.add(1, 1)
numpy add의 결과가 표시됩니다.
명령어에서 출력합니다.
Array(2, dtype=int32, weak_type=True)
Python 인터프리터 종료
>>> exit()
TPU VM에서 JAX 코드 실행
이제 원하는 JAX 코드를 실행할 수 있습니다. Flax 예시는 JAX에서 표준 ML 모드 실행을 시작할 수 있는 훌륭한 장소입니다. 예를 들어 기본 MNIST 컨볼루션 네트워크를 학습시키려면 다음 안내를 따르세요.
- Flax 예시 종속 항목 설치 - (vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets 
- Flax를 설치합니다. - (vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax 
- Flax MNIST 학습 스크립트를 실행합니다. - (vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5 
스크립트가 데이터 세트를 다운로드하고 학습을 시작합니다. 스크립트 출력은 다음과 같아야 합니다.
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
삭제
이 페이지에서 사용한 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 다음 단계를 수행합니다.
TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.
- Cloud TPU 인스턴스에서 아직 연결을 해제하지 않았으면 연결을 해제합니다. - (vm)$ exit - 프롬프트가 username@projectname으로 바뀌면 Cloud Shell에 있는 것입니다. 
- Cloud TPU를 삭제합니다. - $ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE 
- 다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 나열되지 않았는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다. - $ gcloud compute tpus tpu-vm list \ --zone=$ZONE 
성능 참고사항
다음은 JAX에서 특히 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.
패딩
TPU에서 성능이 느려지는 가장 일반적인 원인 중 하나는 의도치 않은 패딩의 도입입니다.
- Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
- 행렬 곱셈 단위는 패딩 요구를 최소화하는 대규모 행렬의 쌍에서 효과가 가장 높습니다.
bfloat16 dtype
기본적으로 JAX에서 TPU의 행렬 곱셈에는 float32 누적의 bfloat16이 사용됩니다. 이것은 관련된 jax.numpy 함수 호출에서 정밀도 인수로 제어될 수 있습니다(matmul, dot, einsum 등). 특히 다음 옵션이 지원됩니다.
- precision=jax.lax.Precision.DEFAULT: 혼합 bfloat16 정밀도(가장 빠름) 사용
- precision=jax.lax.Precision.HIGH: 여러 MXU 패스를 사용하여 정밀도 향상
- precision=jax.lax.Precision.HIGHEST: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성
JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 jax.numpy.array(x, dtype=jax.numpy.bfloat16)입니다.
다음 단계
Cloud TPU에 대한 자세한 내용은 다음을 참조하세요.