TPU Pod 슬라이스에서 JAX 코드 실행

단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU Pod 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU Pod 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU Pod 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.

데이터 스토리지에 마운트된 NFS를 사용하려면 포드 슬라이스의 모든 TPU VM에 대해 OS 로그인을 설정해야 합니다. 자세한 내용은 데이터 스토리지에 NFS 사용을 참조하세요.

환경 설정하기

  1. Cloud Shell에서 다음 명령어를 실행하여 gcloud의 최신 버전을 실행하세요.

    $ gcloud components update

    gcloud를 설치해야 하는 경우 다음 명령어를 사용합니다.

    $ sudo apt install -y google-cloud-sdk
  2. 다음과 같이 몇 가지 환경 변수를 만듭니다.

    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32

TPU Pod 슬라이스 만들기

이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따르도록 유의하세요. 로컬 머신에서 다음 명령어를 실행합니다.

gcloud 명령어를 사용하여 TPU Pod 슬라이스를 만듭니다. 예를 들어 v4-32 포드 슬라이스를 만들려면 다음 명령어를 사용합니다.

$ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
  --zone=${ZONE} \
  --accelerator-type=${ACCELERATOR_TYPE}  \
  --version=${RUNTIME_VERSION} 

포드 슬라이스에 JAX 설치

TPU Pod 슬라이스를 만든 후 TPU Pod 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all 옵션을 사용하여 단일 명령어로 모든 호스트에 JAX를 설치할 수 있습니다.

  gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

포드 슬라이스에서 JAX 코드 실행

TPU Pod 슬라이스에서 JAX 코드를 실행하려면 TPU Pod 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count() 호출은 포드 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU Pod 슬라이스에서 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

344.0.0 이상 gcloud 버전이 필요합니다(scp 명령어). gcloud --version을 사용하여 gcloud 버전을 확인하고 필요하면 gcloud components upgrade를 실행합니다.

다음 코드를 사용하여 example.py라는 파일을 만듭니다.

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py를 포드 슬라이스의 모든 TPU 워커 VM에 복사

$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \
  --worker=all \
  --zone=${ZONE}

이전에 scp 명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

오류를 해결하려면 오류 메시지에 표시된 대로 ssh-add 명령어를 실행하고 명령어를 다시 실행합니다.

Pod 슬라이스에서 코드 실행

모든 VM에서 example.py 프로그램을 실행합니다.

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="python3 example.py"

출력(v4-32 Pod 슬라이스에서 생성됨):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

삭제

TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.

  1. Compute Engine 인스턴스에서 연결을 해제합니다.

    (vm)$ exit

    프롬프트가 username@projectname으로 바뀌면 Cloud Shell에 있는 것입니다.

  2. Cloud TPU 및 Compute Engine 리소스를 삭제합니다.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE}
  3. gcloud compute tpus execution-groups list를 실행하여 리소스가 삭제되었는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다. 다음 명령어의 출력에는 이 튜토리얼에서 만든 리소스가 포함되어서는 안 됩니다.

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}