JAX를 사용하여 Cloud TPU VM에서 계산 실행

이 문서에서는 JAX 및 Cloud TPU 작업에 대해 간략히 안내합니다.

이 빠른 시작을 수행하기 전에 Google Cloud Platform 계정을 만들고 Google Cloud CLI를 설치하고 gcloud 명령어를 구성해야 합니다. 자세한 내용은 계정 및 Cloud TPU 프로젝트 설정을 참조하세요.

Google Cloud CLI 설치

Google Cloud CLI에는 Google Cloud 제품 및 서비스와 상호작용하기 위한 도구 및 라이브러리가 포함되어 있습니다. 자세한 내용은 Google Cloud CLI 설치를 참조하세요.

gcloud 명령어 구성

다음 명령어를 실행하여 Google Cloud 프로젝트를 사용하도록 gcloud를 구성하고 TPU VM 미리보기에 필요한 구성요소를 설치합니다.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Cloud TPU API 사용 설정

  1. Cloud Shell에서 다음 gcloud 명령어를 사용하여 Cloud TPU API를 사용 설정합니다. Google Cloud 콘솔에서도 사용 설정할 수 있습니다.

    $ gcloud services enable tpu.googleapis.com
    
  2. 다음 명령어를 실행하여 서비스 계정을 만듭니다.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

gcloud를 사용하여 Cloud TPU VM 만들기

Cloud TPU VMs에서는 모델 및 코드가 TPU 호스트 머신에서 직접 실행됩니다. TPU 호스트에 직접 SSH로 연결합니다. TPU 호스트에서 직접 임의 코드를 실행하고, 패키지를 설치하고, 로그를 보고, 코드를 디버깅할 수 있습니다.

  1. Cloud Shell 또는 Google Cloud CLI가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    필수 입력란

    zone
    Cloud TPU를 만들려는 영역입니다.
    accelerator-type
    가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요.
    version
    Cloud TPU 소프트웨어 버전입니다. 모든 TPU 유형에 tpu-ubuntu2204-base를 사용합니다.

Cloud TPU VM에 연결

다음 명령어를 사용하여 TPU VM에 SSH로 연결합니다.

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

필수 입력란

tpu_name
연결하려는 TPU VM의 이름입니다.
zone
Cloud TPU를 만든 영역입니다.

Cloud TPU VM에 JAX 설치

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

시스템 확인

JAX가 TPU에 액세스할 수 있고 기본 작업을 실행할 수 있는지 확인합니다.

Python 3 인터프리터 시작:

(vm)$ python3
>>> import jax

사용 가능한 TPU 코어 수 표시:

>>> jax.device_count()

TPU 코어 수가 표시됩니다. v4 TPU를 사용하는 경우에는 4여야 합니다. v2 또는 v3 TPU를 사용하는 경우에는 8이어야 합니다.

간단한 계산 수행:

>>> jax.numpy.add(1, 1)

numpy add의 결과가 표시됩니다.

명령어에서 출력합니다.

Array(2, dtype=int32, weak_type=true)

Python 인터프리터 종료:

>>> exit()

TPU VM에서 JAX 코드 실행

이제 원하는 JAX 코드를 실행할 수 있습니다. flax 예시는 JAX에서 표준 ML 모드 실행을 시작할 수 있는 훌륭한 장소입니다. 예를 들어 기본 MNIST 컨볼루션 네트워크를 학습시키려면 다음 안내를 따르세요.

  1. Flax 예시 종속 항목 설치

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. FLAX를 설치합니다.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. FLAX MNIST 학습 스크립트 실행

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

스크립트가 데이터 세트를 다운로드하고 학습을 시작합니다. 스크립트 출력은 다음과 같아야 합니다.

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

삭제

TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.

  1. Compute Engine 인스턴스에서 연결을 해제합니다.

    (vm)$ exit
    
  2. Cloud TPU를 삭제합니다.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. 다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 나열되지 않았는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

성능 참고사항

다음은 JAX에서 특히 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.

패딩

TPU에서 성능이 느려지는 가장 일반적인 원인 중 하나는 의도치 않은 패딩의 도입입니다.

  • Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
  • 행렬 곱셈 단위는 패딩 요구를 최소화하는 대규모 행렬의 쌍에서 효과가 가장 높습니다.

bfloat16 dtype

기본적으로 JAX에서 TPU의 행렬 곱셈에는 float32 누적의 bfloat16이 사용됩니다. 이것은 관련된 jax.numpy 함수 호출에서 정밀도 인수로 제어될 수 있습니다(matmul, dot, einsum 등). 특히 다음 내용이 해당됩니다.

  • precision=jax.lax.Precision.DEFAULT: 혼합 bfloat16 정밀도 사용(가장 빠름)
  • precision=jax.lax.Precision.HIGH: 여러 MXU 패스를 사용하여 정밀도 향상
  • precision=jax.lax.Precision.HIGHEST: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성

JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 jax.numpy.array(x, dtype=jax.numpy.bfloat16)입니다.

Colab에서 JAX 실행

Colab 노트북에서 JAX 코드를 실행할 때 Colab은 레거시 TPU 노드를 자동으로 만듭니다. TPU 노드에는 다른 아키텍처가 사용됩니다. 자세한 내용은 시스템 아키텍처를 참조하세요.

다음 단계

Cloud TPU에 대한 자세한 내용은 다음을 참조하세요.