TPU Pod 슬라이스에서 JAX 코드 실행
단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU Pod 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU Pod 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU Pod 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.
데이터 스토리지에 마운트된 NFS를 사용하려면 포드 슬라이스의 모든 TPU VM에 대해 OS 로그인을 설정해야 합니다. 자세한 내용은 데이터 스토리지에 NFS 사용을 참조하세요.TPU Pod 슬라이스 만들기
이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따르도록 유의하세요. 로컬 머신에서 다음 명령어를 실행합니다.
gcloud
명령어를 사용하여 TPU Pod 슬라이스를 만듭니다. 예를 들어 v4-32 포드 슬라이스를 만들려면 다음 명령어를 사용합니다.
$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-32 \ --version=tpu-ubuntu2204-base
포드 슬라이스에 JAX 설치
TPU Pod 슬라이스를 만든 후 TPU Pod 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all
옵션을 사용하여 단일 명령어로 모든 호스트에 JAX를 설치할 수 있습니다.
gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b --worker=all --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
포드 슬라이스에서 JAX 코드 실행
TPU Pod 슬라이스에서 JAX 코드를 실행하려면 TPU Pod 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count()
호출은 포드 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU Pod 슬라이스에서 간단한 JAX 계산을 실행하는 방법을 보여줍니다.
코드 준비
344.0.0 이상의 gcloud
버전이 필요합니다(scp
명령어).
gcloud --version
을 사용하여 gcloud
버전을 확인하고 필요하면 gcloud components upgrade
를 실행합니다.
다음 코드를 사용하여 example.py
라는 파일을 만듭니다.
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
example.py
를 포드 슬라이스의 모든 TPU 워커 VM에 복사
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \ --worker=all \ --zone=us-central2-b
이전에 scp
명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
오류를 해결하려면 오류 메시지에 표시된 대로 ssh-add
명령어를 실행하고 명령어를 다시 실행합니다.
Pod 슬라이스에서 코드 실행
모든 VM에서 example.py
프로그램을 실행합니다.
$ gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b \ --worker=all \ --command="python3 example.py"
출력(v4-32 Pod 슬라이스에서 생성됨):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
삭제
완료되었으면 gcloud
명령어를 사용하여 TPU VM 리소스를 해제할 수 있습니다.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central2-b