TPU Pod 슬라이스에서 JAX 코드 실행

단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU Pod 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU Pod 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU Pod 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.

데이터 스토리지에 마운트된 NFS를 사용하려면 포드 슬라이스에서 모든 TPU VM에 대해 OS 로그인을 설정해야 합니다. 자세한 내용은 데이터 스토리지에 NFS 사용을 참조하세요.

TPU Pod 슬라이스 만들기

이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따랐는지 확인합니다. 로컬 머신에서 다음 명령어를 실행합니다.

gcloud 명령어를 사용하여 TPU Pod 슬라이스를 만듭니다. 예를 들어 v2-32 포드 슬라이스를 만들려면 다음 명령어를 사용합니다.

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

포드 슬라이스에 JAX 설치

TPU Pod 슬라이스를 만든 후 TPU Pod 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all 옵션을 사용하여 단일 명령어로 모든 호스트에 JAX를 설치할 수 있습니다.

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Pod 슬라이스에서 JAX 코드 실행

TPU Pod 슬라이스에서 JAX 코드를 실행하려면 TPU Pod 슬라이스의 각 호스트에서 코드를 실행해야 합니다. 각 호스트에 SSH로 연결하고 각 호스트에서 JAX 코드를 실행해야 합니다. 다음 Python 코드는 gcloud 명령어의 --worker=all 옵션을 사용하여 TPU Pod 슬라이스에서 간단한 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

344.0.0 이상 gcloud 버전이 필요합니다(scp 명령어). gcloud --version을 사용하여 gcloud 버전을 확인하고 필요하면 gcloud components upgrade를 실행합니다.

로컬 머신에 example.py를 씁니다.

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py를 포드 슬라이스의 모든 VM에 복사합니다.

$ gcloud alpha compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

이전에 scp 명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent.
Please run "ssh-add /.../.ssh/google_compute_engine" to add it, and try again.

오류 메시지에 표시된 ssh-add 명령어를 실행하고 명령어를 다시 실행하여 오류를 해결합니다.

pod 슬라이스에서 코드 실행

모든 VM에서 example.py 프로그램을 실행합니다.

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

출력(v2-32 pod 슬라이스에서 생성됨):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

이것은 각 호스트에서 JAX Python 코드를 실행하는 한 가지 방법일 뿐 무엇이든 원하는 메서드를 사용하면 됩니다. 그러나 이 명령어를 실행하면 jax.device_count() 호출이 포드 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. TPU 런타임을 초기화하려면 모든 호스트가 실행 중이어야 하기 때문입니다.

삭제

완료되었으면 gcloud 명령어를 사용하여 TPU VM 리소스를 해제할 수 있습니다.

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a