Ray를 사용하여 ML 워크로드 확장

소개

Cloud TPU Ray 도구는 Cloud TPU에서 사용자의 개발 환경을 개선하기 위해 Cloud TPU APIRay Jobs를 결합합니다. 이 사용자 가이드에서는 Cloud TPU에서 Ray를 사용하는 방법에 대한 최소한의 예시를 제공합니다. 이 예시는 프로덕션 서비스에서 사용할 수 없으며 참고용입니다.

이 도구에 포함된 기능

편의를 위해 도구에서 다음 기능을 제공합니다.

  • 일반적인 TPU 작업을 위해 상용구를 숨기는 일반 추상화
  • 자체 기본 워크플로에 포크할 수 있는 실습 예시

구체적으로는 다음과 같습니다.

  • tpu_api.py: Cloud TPU API를 사용하는 기본 TPU 작업을 위한 Python 래퍼입니다.
  • tpu_controller.py: TPU의 클래스 표현입니다. 기본적으로 tpu_api.py의 래퍼입니다.
  • ray_tpu_controller.py: Ray 기능이 있는 TPU 컨트롤러입니다. 이 기능을 사용하면 Ray 클러스터 및 Ray Jobs의 상용구가 추상화됩니다.
  • run_basic_jax.py: print(jax.device_count())RayTpuController를 사용하는 방법을 보여주는 기본 예시입니다.
  • run_hp_search.py: MNIST에서 JAX/Flax와 함께 Ray Tune을 사용하는 방법을 보여주는 기본 예시입니다.
  • run_pax_autoresume.py: 예시 워크로드로 PAX를 사용하여 내결함성 학습에 RayTpuController를 사용하는 방법을 보여주는 예시입니다.

Ray 클러스터 헤드 노드 설정

TPU Pod에서 Ray를 사용하는 기본 방법 중 하나는 TPU Pod를 Ray 클러스터로 설정하는 것입니다. 별도의 CPU VM을 조정자 VM으로 만드는 것이 자연스러운 방법입니다. 다음 그래픽은 Ray 클러스터 구성의 예시를 보여줍니다.

Ray 클러스터 구성 예시

다음 명령어는 Google Cloud CLI를 사용하여 Ray 클러스터를 설정하는 방법을 보여줍니다.

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

편의를 위해 조정자 VM을 만들고 이 폴더의 콘텐츠를 조정자 VM에 배포하기 위한 기본 스크립트도 제공합니다. 소스 코드는 create_cpu.shdeploy.sh를 참조하세요.

이러한 스크립트는 몇 가지 기본값을 설정합니다.

  • create_cpu.sh$USER-admin이라는 VM을 만들고 gcloud config 기본값으로 설정된 모든 프로젝트 및 영역을 활용합니다. 이러한 기본값을 보려면 gcloud config list를 실행하세요.
  • create_cpu.sh는 기본적으로 부팅 디스크 크기를 200GB로 할당합니다.
  • deploy.sh은 VM 이름이 $USER-admin이라고 가정합니다. create_cpu.sh에서 이 값을 변경하는 경우 deploy.sh에서 변경해야 합니다.

제공된 스크립트를 사용하려면 다음 안내를 따르세요.

  1. GitHub 저장소를 로컬 머신에 클론하고 ray_tpu 폴더를 입력합니다.

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. TPU 관리를 위한 전용 서비스 계정(적극 권장)이 없으면 다음과 같이 계정을 설정합니다.

    $ ./create_tpu_service_account.sh
    
  3. 조정자 VM을 만듭니다.

    $ ./create_cpu.sh
    

    이 스크립트는 시작 스크립트를 사용하여 VM에 종속 항목을 설치하고 시작 스크립트가 완료될 때까지 자동으로 차단합니다.

  4. 조정자 VM에 로컬 코드를 배포합니다.

    $ ./deploy.sh
    
  5. SSH를 통해 VM에 연결합니다.

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Ray가 포트 8265에서 대시보드를 자동으로 시작하므로 포트 전달은 여기에서 사용 설정됩니다. SSH를 통해 조정자 VM에 있는 머신에서 http://127.0.0.1:8265/의 이 대시보드에 액세스할 수 있습니다.

  6. 0단계를 건너뛰었으면 CPU VM 내에 gcloud 사용자 인증 정보를 설정합니다.

    $ (vm) gcloud auth login --update-adc
    

    이 단계에서는 프로젝트 ID 정보를 설정하고 조정자 VM에서 Cloud TPU API를 실행할 수 있도록 합니다.

  7. 설치 요구사항:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. 조정자 VM에서 Ray를 시작하고 조정자 VM은 Ray 클러스터의 헤드 노드가 됩니다.

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

사용 예시

기본 JAX 예시

run_basic_jax.py는 TPU VM이 있는 Ray 클러스터에서 Ray Jobs 및 Ray 런타임 환경을 사용하여 JAX 워크로드를 실행하는 방법을 보여주는 최소한의 예시입니다.

JAX 및 PyTorch/XLA PJRT와 같은 멀티 컨트롤러 프로그래밍 모델을 사용하는 Cloud TPU와 호환되는 ML 프레임워크의 경우 호스트당 프로세스를 최소 한 개 이상 실행해야 합니다. 자세한 내용은 다중 프로세스 프로그래밍 모델을 참조하세요. 실제로는 다음과 같이 표시됩니다.

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

v4-128과 같이 16개 이상의 호스트가 있는 경우 SSH 확장성 문제가 발생하고 명령어를 다음과 같이 변경해야 할 수 있습니다.

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

my_bug_free_python_code에 버그가 있으면 개발자 속도에 방해가 될 수 있습니다. 이 문제를 해결하는 방법 중 하나는 Kubernetes 또는 Ray와 같은 조정자를 사용하는 것입니다. Ray에는 적용 시 Ray 애플리케이션이 실행될 때 코드 및 종속 항목을 배포하는 런타임 환경이라는 개념이 포함되어 있습니다.

Ray 런타임 환경을 Ray 클러스터 및 Ray Jobs와 결합하면 SCP/SSH 주기를 우회할 수 있습니다. 위의 예시를 따라했다고 가정하면 다음과 같이 실행할 수 있습니다.

$ python3 legacy/run_basic_jax.py

출력은 다음과 비슷합니다.

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

내결함성 학습

이 예시에서는 RayTpuController를 사용하여 내결함성 학습을 구현하는 방법을 보여줍니다. 이 예시에서는 v4-16의 PAX에서 간단한 LLM을 사전 학습하지만 이 PAX 워크로드를 다른 장기 실행 워크로드로 바꿀 수 있습니다. 소스 코드는 run_pax_autoresume.py를 참조하세요.

이 예시를 실행하려면 다음 안내를 따르세요.

  1. 조정자 VM에 paxml을 클론합니다.

    $ git clone https://github.com/google/paxml.git
    

    Ray 런타임 환경에서 JAX 변경 및 배포를 위해 제공하는 사용 편의성을 보여주기 위해 이 예시에서는 PAX를 수정해야 합니다.

  2. 새 실험 구성을 추가합니다.

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. run_pax_autoresume.py를 실행합니다.

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. 워크로드가 실행되면 기본적으로 $USER-tpu-ray라는 TPU를 삭제하면 어떻게 되는지 실험합니다.

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray는 다음 메시지와 함께 TPU가 다운되었음을 감지합니다.

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    그리고 작업이 자동으로 TPU VM을 다시 만들고 최신 체크포인트(이 예시에서는 200단계)에서 학습을 재개할 수 있도록 학습 작업을 다시 시작합니다.

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

이 예시에서는 Ray AIR의 Ray Tune을 사용하여 JAX/FLAX의 초매개변수 조정 MNIST를 사용하는 방법을 보여줍니다. 소스 코드는 run_hp_search.py를 참조하세요.

이 예시를 실행하려면 다음 안내를 따르세요.

  1. 요구사항을 설치합니다.

    $ pip3 install -r src/tune/requirements.txt
    
  2. run_hp_search.py를 실행합니다.

    $ python3 src/tune/run_hp_search.py
    

    출력은 다음과 비슷합니다.

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

문제 해결

Ray 헤드 노드에 연결할 수 없음

TPU 수명 주기를 생성/삭제하는 워크로드를 실행하면 간혹 Ray 클러스터에서 TPU 호스트 연결이 해제되지 않습니다. 이는 Ray 헤드 노드가 IP 주소 집합에 연결할 수 없음을 나타내는 gRPC 오류로 표시될 수 있습니다.

따라서 Ray 세션(ray stop)을 종료하고 다시 시작해야 할 수 있습니다(ray start --head --port=6379 --num-cpus=0).

Ray Job이 로그 출력 없이 직접 실패함

PAX는 실험용이며 이 예시는 pip 종속 항목으로 인해 중단될 수 있습니다. 이 경우 다음과 같이 표시될 수 있습니다.

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

오류의 근본 원인을 보려면 http://127.0.0.1:8265/로 이동하여 실행/실패한 작업에 대한 자세한 정보를 제공하는 대시보드를 확인합니다. runtime_env_agent.log는 runtime_env 설정과 관련된 모든 오류 정보를 표시합니다. 예를 들면 다음과 같습니다.

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts