TPU Pod 슬라이스에서 JAX 코드 실행

단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU Pod 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU Pod 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU Pod 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.

데이터 스토리지에 마운트된 NFS를 사용하려면 포드 슬라이스의 모든 TPU VM에 대해 OS 로그인을 설정해야 합니다. 자세한 내용은 데이터 스토리지에 NFS 사용을 참조하세요.

TPU Pod 슬라이스 만들기

이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따르도록 유의하세요. 로컬 머신에서 다음 명령어를 실행합니다.

gcloud 명령어를 사용하여 TPU Pod 슬라이스를 만듭니다. 예를 들어 v4-32 포드 슬라이스를 만들려면 다음 명령어를 사용합니다.

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

포드 슬라이스에 JAX 설치

TPU Pod 슬라이스를 만든 후 TPU Pod 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all 옵션을 사용하여 단일 명령어로 모든 호스트에 JAX를 설치할 수 있습니다.

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

포드 슬라이스에서 JAX 코드 실행

TPU Pod 슬라이스에서 JAX 코드를 실행하려면 TPU Pod 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count() 호출은 포드 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU Pod 슬라이스에서 간단한 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

344.0.0 이상 gcloud 버전이 필요합니다(scp 명령어). gcloud --version을 사용하여 gcloud 버전을 확인하고 필요하면 gcloud components upgrade를 실행합니다.

다음 코드를 사용하여 example.py라는 파일을 만듭니다.

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py를 포드 슬라이스의 모든 TPU 워커 VM에 복사

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

이전에 scp 명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

오류를 해결하려면 오류 메시지에 표시된 ssh-add 명령어를 실행하고 명령어를 다시 실행합니다.

Pod 슬라이스에서 코드 실행

모든 VM에서 example.py 프로그램을 실행합니다.

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

출력(v4-32 Pod 슬라이스에서 생성됨):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

삭제

완료되었으면 gcloud 명령어를 사용하여 TPU VM 리소스를 해제할 수 있습니다.

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b