Cloud TPU VM JAX 빠른 시작

이 문서에서는 JAX 및 Cloud TPU 작업을 간략히 소개합니다.

Google 계정으로 로그인합니다. 아직 계정이 없으면 새 계정을 등록하세요. Google Cloud Console의 프로젝트 선택기 페이지에서 클라우드 프로젝트를 선택하거나 만듭니다. 프로젝트에 결제가 사용 설정되어 있는지 확인합니다.

Google Cloud SDK 설치

Google Cloud SDK에는 Google Cloud 제품 및 서비스와 상호작용하기 위한 도구 및 라이브러리가 포함되어 있습니다. 자세한 내용은 Google Cloud SDK 설치를 참조하세요.

gcloud 명령어 구성

다음 명령어를 실행하여 gcloud GCP 프로젝트를 사용하도록 구성하고 TPU VM 미리보기에 필요한 구성요소를 설치합니다.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Cloud TPU API 사용 설정

  1. Cloud Shell에서 다음 gcloud 명령어를 사용하여 Cloud TPU API를 사용 설정합니다. Google Cloud Console에서도 사용 설정할 수 있습니다.

    $ gcloud services enable tpu.googleapis.com
    
  2. 다음 명령어를 실행하여 서비스 계정을 만듭니다.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

gcloud를 사용하여 Cloud TPU VM 만들기

Cloud TPU VM을 사용하면 모델과 코드가 TPU 호스트 머신에서 직접 실행됩니다. TPU 호스트에 SSH를 통해 연결합니다. TPU 호스트에서 직접 임의의 코드를 실행하고, 패키지를 설치하며, 로그를 확인하고, 코드를 디버깅할 수 있습니다.

  1. GCP Cloud Shell 또는 Google Cloud SDK가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    필수 입력란

    zone
    Cloud TPU를 만들려는 영역입니다.
    accelerator-type
    생성할 Cloud TPU의 유형입니다.
    version
    Cloud TPU 런타임 버전입니다. 단일 TPU 기기, Pod 슬라이스 또는 전체 Pod에서 JAX를 사용하는 경우 이를 'v2-alpha'로 설정합니다.

Cloud TPU VM에 연결

다음 명령어를 사용하여 TPU VM에 SSH를 통해 연결합니다.

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

필수 입력란

tpu_name
연결하려는 TPU VM의 이름입니다.
zone
Cloud TPU를 만든 영역입니다.

Cloud TPU VM에 JAX 설치

(vm)$ pip3 install --upgrade jax jaxlib

시스템 확인

JAX에 Cloud TPU 코어가 있는지 확인하고 기본 작업을 실행할 수 있는지 확인하여 모든 항목이 올바르게 설치되었는지 테스트합니다.

Python 3 인터프리터 시작:

(vm)$ python3
>>> import jax

사용 가능한 TPU 코어 수 표시:

>>> jax.device_count()

TPU 코어의 수가 표시됩니다. 8이어야 합니다.

간단한 계산 수행:

>>> jax.numpy.add(1, 1)

Numpy Add 결과가 표시됩니다.

명령어의 출력:

DeviceArray(2, dtype=int32)

Python 인터프리터 종료:

>>> exit()

TPU VM에서 JAX 코드 실행

이제 원하는 JAX 코드를 실행할 수 있습니다. flax 예시는 JAX에서 표준 ML 모델 실행을 시작하기에 좋은 출발점입니다. 예를 들어 기본 MNIST 컨볼루셔널 네트워크를 학습시키려면 다음을 실행하세요.

  1. TensorFlow 데이터 세트 설치

    (vm)$ pip install --upgrade clu
    
  2. FLAX를 설치합니다.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. FLAX MNIST 학습 스크립트 실행

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    스크립트 출력은 다음과 같아야 합니다.

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

삭제

TPU VM 작업을 마쳤으면 다음 단계에 따라 리소스를 삭제합니다.

  1. Compute Engine 인스턴스에서 연결을 해제합니다.

    (vm)$ exit
    
  2. Cloud TPU를 삭제합니다.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. 다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 표시되지 않는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다.

성능 참고사항

다음은 JAX에서 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.

패딩

TPU에서 성능이 저하되는 가장 일반적인 원인 중 하나는 부적절한 패딩을 도입하는 것입니다.

  • Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
  • 행렬 곱셈 단위는 패딩의 필요성을 최소화하는 큰 행렬 쌍에서 가장 잘 작동합니다.

bfloat16 dtype

기본적으로 TPU에서 JAX의 행렬 곱셈은 float32 누적으로 bfloat16을 사용합니다. 이는 관련 jax.numpy 함수 호출(matmul, dot, einsum 등)의 정밀도 인수로 제어할 수 있습니다. 특히 다음 내용이 해당됩니다.

  • precision=jax.lax.Precision.DEFAULT: 혼합 bfloat16 정밀도(가장 빠름) 사용
  • precision=jax.lax.Precision.HIGH: 여러 MXU 패스를 사용하여 정밀도 향상
  • precision=jax.lax.Precision.HIGHEST: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성

JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 다음과 같습니다. jax.numpy.array(x, dtype=jax.numpy.bfloat16)

Colab에서 JAX 실행

Colab 노트북에서 JAX 코드를 실행하면 Colab은 자동으로 레거시 TPU 노드를 만듭니다. TPU 노드의 아키텍처가 다릅니다. 자세한 내용은 시스템 아키텍처를 참조하세요.