Cloud TPU PyTorch/XLA 사용자 가이드

PyTorch/XLA를 사용하여 ML 워크로드 실행

이 가이드에서는 PyTorch를 사용하여 v4 TPU에서 간단한 계산을 수행하는 방법을 안내합니다.

기본 설정

  1. Pytorch 2.0용 TPU VM 런타임을 실행하는 v4 TPU로 TPU VM을 만듭니다.

      gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8 \
      --version=tpu-vm-v4-pt-2.0
    
  2. SSH를 사용하여 TPU VM에 연결:

      gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8
    
  3. PJRT 또는 XRT TPU 기기 구성을 설정합니다.

    PJRT

        (vm)$ export PJRT_DEVICE=TPU
     

    XRT

        (vm)$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
     

  4. Cloud TPU v4로 학습시키는 경우 다음 환경 변수도 설정합니다.

      (vm)$ export TPU_NUM_DEVICES=4
    

간단한 계산 수행

  1. TPU VM에서 Python 인터프리터를 시작합니다.

    (vm)$ python3
    
  2. 다음 PyTorch 패키지를 가져옵니다.

    import torch
    import torch_xla.core.xla_model as xm
    
  3. 다음 스크립트를 입력합니다.

    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    

    다음 출력이 표시됩니다.

    tensor([[-0.2121,  1.5589, -0.6951],
           [-0.7886, -0.2022,  0.9242],
           [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

단일 기기 TPU에서 Resnet 실행

이제 원하는 PyTorch/XLA 코드를 실행할 수 있습니다. 예를 들어 가짜 데이터를 사용해서 ResNet 모델을 실행할 수 있습니다.

(vm)$ git clone --recursive https://github.com/pytorch/xla.git
(vm)$ python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

ResNet 샘플은 1개 에포크에 대해 학습이 진행되고 약 7분 정도가 걸립니다. 다음과 비슷한 출력이 반환됩니다.

Epoch 1 test end 20:57:52, Accuracy=100.00 Max Accuracy: 100.00%

ResNet 학습이 끝나면 TPU VM을 삭제하세요.

(vm)$ exit
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=zone

삭제하는 데 몇 분 정도 걸릴 수 있습니다. gcloud compute tpus list --zone=${ZONE}를 실행하여 리소스가 삭제되었는지 확인합니다.

심화 주제

상당한 크기의 자주 할당되는 모델의 경우 tcmalloc가 C/C++ 런타임 함수 malloc에 비해 성능을 개선합니다. TPU VM에 사용되는 기본 malloctcmalloc입니다. LD_PRELOAD 환경 변수를 설정 해제하여 TPU VM 소프트웨어가 표준 malloc를 사용하도록 강제할 수 있습니다.

   (vm)$ unset LD_PRELOAD

이전 예시(간단한 계산 및 ResNet50)에서 PyTorch/XLA 프로그램은 Python 인터프리터와 동일한 프로세스로 로컬 XRT 서버를 시작합니다. 별도의 프로세스로 XRT 로컬 서비스를 시작할 수도 있습니다.

(vm)$ python3 -m torch_xla.core.xrt_run_server --port 51011 --restart

이 접근 방식의 장점은 학습을 실행할 때 컴파일 캐시가 지속된다는 것입니다. 개별 프로세스로 XLA 서버를 실행하면 서버 측 로깅 정보가 /tmp/xrt_server_log에 기록됩니다.

(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log

TPU VM 성능 프로파일링

TPU VM에서 모델 프로파일링에 대한 자세한 내용은 PyTorch XLA 성능 프로파일링을 참조하세요.

PyTorch/XLA TPU Pod 예시

TPU VM Pod에서 PyTorch/XLA를 실행하는 설정 정보와 예시는 PyTorch TPU VM Pod를 참조하세요.

TPU VM의 Docker

이 섹션에서는 PyTorch/XLA가 사전 설치된 TPU VM에서 Docker를 실행하는 방법을 보여줍니다.

사용 가능한 Docker 이미지

GitHub 리드미를 참조하여 사용 가능한 모든 TPU VM Docker 이미지를 찾을 수 있습니다.

TPU VM에서 Docker 이미지 실행

(tpuvm): sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker run --privileged  --shm-size 16G --name tpuvm_docker -it -d  gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker exec --privileged -it tpuvm_docker /bin/bash
(pytorch) root:/#

libtpu 확인

libtpu가 설치되어 있는지 확인하려면 다음을 실행합니다.

(pytorch) root:/# ls /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/ | grep libtpu
다음과 비슷한 출력이 생성되어야 합니다.
libtpu
libtpu_nightly-0.1.dev20220518.dist-info

결과가 표시되지 않으면 다음을 사용하여 해당 libtpu를 수동으로 설치할 수 있습니다.

(pytorch) root:/# pip install torch_xla[tpuvm]

tcmalloc 확인하기

tcmalloc은 TPU VM에서 사용되는 기본 malloc입니다. 자세한 내용은 이 섹션을 참조하세요. 이 라이브러리는 최신 TPU VM Docker 이미지에 사전 설치되어 있어야 하지만 항상 수동으로 확인하는 것이 좋습니다. 다음 명령어를 실행하면 라이브러리가 설치되었는지 확인할 수 있습니다.

(pytorch) root:/# echo $LD_PRELOAD
다음과 비슷한 출력이 생성되어야 합니다.
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

LD_PRELOAD가 설정되지 않았으면 수동으로 실행할 수 있습니다.

(pytorch) root:/# sudo apt-get install -y google-perftools
(pytorch) root:/# export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"

기기 확인

다음을 실행하여 TPU VM 기기를 사용할 수 있는지 확인할 수 있습니다.

(pytorch) root:/# ls /dev | grep accel
그러면 다음과 같은 결과가 생성되어야 합니다.
accel0
accel1
accel2
accel3

결과가 표시되지 않으면 --privileged 플래그로 컨테이너를 시작하지 않았을 가능성이 높습니다.

모델 실행

다음을 실행하여 TPU VM 기기를 사용할 수 있는지 확인할 수 있습니다.

(pytorch) root:/# export XRT_TPU_CONFIG="localservice;0;localhost:51011"
(pytorch) root:/# python3 pytorch/xla/test/test_train_mp_imagenet.py --fake_data --num_epochs 1