Cloud TPU VM JAX 빠른 시작

이 문서에서는 JAX 및 Cloud TPU 작업에 대해 간략히 안내합니다.

이 빠른 시작을 수행하기 전에 Google Cloud Platform 계정을 만들고 Google Cloud Platform SDK를 설치하고 gcloud 명령어를 구성해야 합니다. 자세한 내용은 계정 및 Cloud TPU 프로젝트 설정을 참조하세요.

Google Cloud SDK 설치

Google Cloud SDK에는 Google Cloud 제품 및 서비스와 상호작용하기 위한 도구 및 라이브러리가 포함되어 있습니다. 자세한 내용은 Google Cloud SDK 설치를 참조하세요.

gcloud 명령어 구성

다음 명령어를 실행하여 gcloud GCP 프로젝트를 사용하도록 구성하고 TPU VM 미리보기에 필요한 구성요소를 설치합니다.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Cloud TPU API 사용 설정

  1. Cloud Shell에서 다음 gcloud 명령어를 사용하여 Cloud TPU API를 사용 설정합니다. Google Cloud Console에서도 사용 설정할 수 있습니다.

    $ gcloud services enable tpu.googleapis.com
    
  2. 다음 명령어를 실행하여 서비스 계정을 만듭니다.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

gcloud를 사용하여 Cloud TPU VM 만들기

Cloud TPU VMs에서는 모델 및 코드가 TPU 호스트 머신에서 직접 실행됩니다. TPU 호스트에 직접 SSH로 연결합니다. TPU 호스트에서 직접 임의 코드를 실행하고, 패키지를 설치하고, 로그를 보고, 코드를 디버깅할 수 있습니다.

  1. GCP Cloud Shell 또는 Google Cloud SDK가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    필수 입력란

    zone
    Cloud TPU를 만들려는 영역입니다.
    accelerator-type
    생성할 Cloud TPU의 유형입니다.
    version
    Cloud TPU 런타임 버전입니다. 단일 TPU 기기, Pod 슬라이스, 전체 Pod에서 JAX를 사용할 때 이를 'v2-alpha'로 설정합니다.

Cloud TPU VM에 연결

다음 명령어를 사용하여 TPU VM에 SSH로 연결합니다.

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

필수 입력란

tpu_name
연결하려는 TPU VM의 이름입니다.
zone
Cloud TPU를 만든 영역입니다.

Cloud TPU VM에 JAX 설치

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

시스템 확인

JAX에 Cloud TPU가 표시되고 기본 작업을 실행할 수 있는지 확인하여 모든 것이 올바르게 설치되었는지 테스트합니다.

Python 3 인터프리터 시작:

(vm)$ python3
>>> import jax

사용 가능한 TPU 코어 수 표시:

>>> jax.device_count()

TPU 코어의 수가 표시됩니다. 8이어야 합니다.

간단한 계산 수행:

>>> jax.numpy.add(1, 1)

numpy add의 결과가 표시됩니다.

명령어에서 출력합니다.

DeviceArray(2, dtype=int32)

Python 인터프리터 종료:

>>> exit()

TPU VM에서 JAX 코드 실행

이제 원하는 JAX 코드를 실행할 수 있습니다. flax 예시는 JAX에서 표준 ML 모드 실행을 시작할 수 있는 훌륭한 장소입니다. 예를 들어 기본 MNIST 컨볼루션 네트워크를 학습시키려면 다음 안내를 따르세요.

  1. TensorFlow 데이터 세트 설치

    (vm)$ pip install --upgrade clu
    
  2. FLAX를 설치합니다.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. FLAX MNIST 학습 스크립트 실행

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    스크립트 출력은 다음과 같아야 합니다.

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

삭제

TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.

  1. Compute Engine 인스턴스에서 연결을 해제합니다.

    (vm)$ exit
    
  2. Cloud TPU를 삭제합니다.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. 다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 나열되지 않았는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다.

성능 참고사항

다음은 JAX에서 특히 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.

패딩

TPU에서 성능이 느려지는 가장 일반적인 원인 중 하나는 의도치 않은 패딩의 도입입니다.

  • Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
  • 행렬 곱셈 단위는 패딩 요구를 최소화하는 대규모 행렬의 쌍에서 효과가 가장 높습니다.

bfloat16 dtype

기본적으로 JAX에서 TPU의 행렬 곱셈에는 float32 누적의 bfloat16이 사용됩니다. 이것은 관련된 jax.numpy 함수 호출에서 정밀도 인수로 제어될 수 있습니다(matmul, dot, einsum 등). 특히 다음 내용이 해당됩니다.

  • precision=jax.lax.Precision.DEFAULT: 혼합 bfloat16 정밀도(가장 빠름) 사용
  • precision=jax.lax.Precision.HIGH: 여러 MXU 패스를 사용하여 정밀도 향상
  • precision=jax.lax.Precision.HIGHEST: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성

JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 다음과 같습니다. jax.numpy.array(x, dtype=jax.numpy.bfloat16)

Colab에서 JAX 실행

Colab 노트북에서 JAX 코드를 실행할 때 Colab은 레거시 TPU 노드를 자동으로 만듭니다. TPU 노드에는 다른 아키텍처가 사용됩니다. 자세한 내용은 시스템 아키텍처를 참조하세요.