TPU Pod 슬라이스에서 JAX 코드 실행

이 문서에서는 TPU Pod 슬라이스에서 JAX 코드를 실행하는 방법을 설명합니다. TPU Pod 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다.

TPU Pod 슬라이스 만들기

gcloud 명령어를 사용하여 TPU Pod 슬라이스를 만듭니다. 예를 들어 v2-32 Pod 슬라이스를 만들려면 다음 명령어를 사용합니다.

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

Pod 슬라이스에 JAX 설치

TPU Pod 슬라이스를 만든 후에는 TPU Pod 슬라이스의 모든 호스트에 JAX를 설치해야 합니다. --worker=all 옵션을 사용하면 명령어 하나로 모든 호스트에 JAX를 설치할 수 있습니다.

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install --upgrade jax jaxlib"

Pod 슬라이스에서 JAX 코드 실행

TPU Pod 슬라이스에서 JAX 코드를 실행하려면 TPU Pod 슬라이스의 각 호스트에서 코드를 실행해야 합니다. 즉, 각 호스트에 SSH로 연결하고 각 호스트에서 JAX 코드를 실행해야 합니다. 다음 Python 코드는 gcloud 명령어의 --worker=all 옵션을 사용하여 TPU Pod 슬라이스에서 간단한 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

$ read -r -d '' PYTHON_CMD << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
EOF

Pod 슬라이스에서 코드 실행

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 -c \"$PYTHON_CMD\""

출력(v2-32 Pod 슬라이스에서 생성됨):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

이것은 각 호스트에서 JAX Python 코드를 실행하는 한 가지 방법일 뿐 무엇이든 원하는 메서드를 사용하면 됩니다. 그러나 이 명령어를 실행하면 위의 jax.device_count() 호출이 Pod 슬라이스의 각 호스트에서 호출될 때까지 대기하게 되는데, 이는 TPU 런타임을 초기화하려면 모든 호스트가 있어야 하기 때문입니다.

삭제

작업이 완료되면 gcloud 명령어를 사용하여 TPU VM 리소스를 해제할 수 있습니다.

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a