使用 Ray 扩缩机器学习工作负载

简介

Cloud TPU Ray 工具结合了 Cloud TPU APIRay Job,旨在改善用户在 Cloud TPU 上的开发体验。本用户指南提供了一个最小示例,说明如何将 Ray 与 Cloud TPU 搭配使用。这些示例仅作说明之用,不应用于生产服务。

此工具包含哪些内容?

为方便起见,该工具提供以下功能:

  • 隐藏常见 TPU 操作的样板的通用抽象
  • 可基于自己的基本工作流程创建分支的玩具示例

具体而言:

  • tpu_api.py:用于使用 Cloud TPU API 执行基本 TPU 操作的 Python 封装容器。
  • tpu_controller.py:TPU 的类表示法。这本质上是 tpu_api.py 的封装容器。
  • ray_tpu_controller.py:具有 Ray 功能的 TPU 控制器。这会抽象化 Ray 集群和 Ray Job 的样板文件。
  • run_basic_jax.py:基本示例,展示了如何对 print(jax.device_count()) 使用 RayTpuController
  • run_hp_search.py:基本示例,展示了如何在 MNIST 上将 Ray Tune 与 JAX/Flax 搭配使用。
  • run_pax_autoresume.py:此示例展示了如何使用 PAX 作为示例工作负载,使用 RayTpuController 进行容错训练。

设置 Ray 集群头节点

您可以将 Ray 与 TPU Pod 搭配使用的一种基本方法是将 TPU Pod 设置为 Ray 集群。自然而然的是创建一个单独的 CPU 虚拟机作为协调器虚拟机。下图显示了 Ray 集群配置的示例:

Ray 集群配置示例

以下命令展示了如何使用 Google Cloud CLI 设置 Ray 集群:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

为方便起见,我们还提供用于创建协调器虚拟机以及将此文件夹的内容部署到协调器虚拟机的基本脚本。如需了解源代码,请参阅 create_cpu.shdeploy.sh

这些脚本会设置一些默认值:

  • create_cpu.sh 将创建一个名为 $USER-admin 的虚拟机,并使用设置为 gcloud config 默认值的任何项目和区域。运行 gcloud config list 可查看这些默认值。
  • create_cpu.sh 默认分配大小为 200GB 的启动磁盘。
  • deploy.sh 假定您的虚拟机名称为 $USER-admin。如果您在 create_cpu.sh 中更改该值,请务必在 deploy.sh 中进行更改。

要使用便捷脚本,请执行以下操作:

  1. 将 GitHub 代码库克隆到本地机器并输入 ray_tpu 文件夹:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. 如果您没有用于管理 TPU 的专用服务帐号(强烈建议),请设置一个:

    $ ./create_tpu_service_account.sh
    
  3. 创建协调器虚拟机:

    $ ./create_cpu.sh
    

    此脚本使用启动脚本在虚拟机上安装依赖项,并自动阻止,直到启动脚本完成。

  4. 将本地代码部署到协调器虚拟机:

    $ ./deploy.sh
    
  5. 通过 SSH 连接到虚拟机:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    此处启用了端口转发,因为 Ray 会自动在端口 8265 启动信息中心。您可以从通过 SSH 连接到协调器虚拟机的机器,通过 http://127.0.0.1:8265/ 访问此信息中心。

  6. 如果您跳过了第 0 步,请在 CPU 虚拟机中设置 gcloud 凭据:

    $ (vm) gcloud auth login --update-adc
    

    此步骤会设置项目 ID 信息并允许 Cloud TPU API 在协调器虚拟机上运行。

  7. 安装要求:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. 在协调器虚拟机上启动 Ray,协调器虚拟机将成为 Ray 集群的头节点:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

用法示例

基本 JAX 示例

run_basic_jax.py 是一个极为简单的示例,演示了如何在具有 TPU 虚拟机的 Ray 集群上使用 Ray Job 和 Ray 运行时环境来运行 JAX 工作负载。

对于与使用多控制器编程模型(例如 JAX 和 PyTorch/XLA PJRT)的 Cloud TPU 兼容的机器学习框架,您必须为每个主机运行至少一个进程。如需了解详情,请参阅多进程编程模型。在实践中,这可能如下所示:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

如果您的主机超过 16 个以上(例如 v4-128),您将遇到 SSH 可伸缩性问题,并且您的命令可能必须更改为:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

如果 my_bug_free_python_code 包含 bug,这可能会妨碍开发者的速度。解决此问题的一种方法是使用 Kubernetes 或 Ray 等 Orchestrator。Ray 包含运行时环境的概念,使用该环境后,该环境会在 Ray 应用运行时部署代码和依赖项。

通过将 Ray 运行时环境与 Ray 集群和 Ray Job 相结合,您可以绕过 SCP/SSH 周期。假设您遵循了上述示例,则可以使用以下代码运行此操作:

$ python3 legacy/run_basic_jax.py

输出类似于以下内容:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

容错训练

以下示例展示了如何使用 RayTpuController 实现容错训练。在此示例中,我们在 v4-16 上的 PAX 上预训练一个简单的 LLM,但请注意,您可以将此 PAX 工作负载替换为任何其他长时间运行的工作负载。如需了解源代码,请参阅 run_pax_autoresume.py

如需运行此示例,请执行以下操作:

  1. paxml 克隆到协调者虚拟机:

    $ git clone https://github.com/google/paxml.git
    

    为了演示 Ray 运行时环境在进行和部署 JAX 更改时提供的易用性,此示例要求您修改 PAX。

  2. 添加新的实验配置:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. 运行 run_pax_autoresume.py

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. 在工作负载运行时,尝试默认删除 TPU(名为 $USER-tpu-ray)后会发生什么情况:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray 将通过以下消息检测到 TPU 已宕机:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    该作业将自动重新创建 TPU 虚拟机并重启训练作业,以便从最新的检查点(本例中为 200 步)继续训练:

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

此示例演示了如何使用 Ray AIR 中的 Ray Tune 通过 JAX/FLAX 进行超参数调优 MNIST。如需了解源代码,请参阅 run_hp_search.py

如需运行此示例,请执行以下操作:

  1. 安装要求:

    $ pip3 install -r src/tune/requirements.txt
    
  2. 运行 run_hp_search.py

    $ python3 src/tune/run_hp_search.py
    

    输出类似于以下内容:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

问题排查

Ray 头节点无法连接

如果您运行的工作负载会创建/删除 TPU 生命周期,有时这不会断开 TPU 主机与 Ray 集群的连接。这可能显示为 gRPC 错误,指示 Ray 头节点无法连接到一组 IP 地址。

因此,您可能需要终止光线会话 (ray stop) 并重启 (ray start --head --port=6379 --num-cpus=0)。

Ray 作业直接失败,没有任何日志输出

PAX 目前处于实验阶段,此示例可能会因为 pip 依赖项而中断。如果发生这种情况,您可能会看到如下内容:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

如需查看错误的根本原因,请转到 http://127.0.0.1:8265/ 并查看正在运行/失败作业的信息中心,其中会提供更多信息。runtime_env_agent.log 会显示与 Runtime_env 设置相关的所有错误信息,例如:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts