使用 Ray 扩缩机器学习工作负载
简介
Cloud TPU Ray 工具结合了 Cloud TPU API 和 Ray 作业,旨在改进用户在 Cloud TPU 上的开发体验。本用户指南提供了一个最小示例,展示了如何将 Ray 与 Cloud TPU 搭配使用。这些示例不适合在生产环境中使用 服务,仅作说明之用。
此工具包含哪些内容?
为方便起见,该工具提供了以下功能:
- 用于隐藏常用 TPU 操作的样板的通用抽象
- 可用于分叉以构建您自己的基本工作流的玩具示例
具体而言:
tpu_api.py
:使用 Cloud TPU API 执行基本 TPU 操作的 Python 封装容器。tpu_controller.py
:TPU 的类表示法。这实际上是tpu_api.py
的封装容器。ray_tpu_controller.py
:具有 Ray 功能的 TPU 控制器。这会提取 Ray 集群和 Ray 作业的样板代码。run_basic_jax.py
:基本示例,展示了如何对print(jax.device_count())
使用RayTpuController
。run_hp_search.py
: 基本示例:展示了如何在 MNIST 上将 Ray Tune 与 JAX/Flax 搭配使用。run_pax_autoresume.py
: 展示如何使用RayTpuController
处理故障的示例 使用 PAX 作为示例工作负载的容忍训练。
正在设置 Ray 集群头节点
Ray 与 TPU Pod 一起使用的基本方法之一是设置 TPU Pod 作为射线星团创建一个单独的 CPU 虚拟机作为协调器虚拟机是实现此目的的自然方式。下图显示了 Ray 集群配置的示例:
以下命令展示了如何使用 Google Cloud CLI 设置 Ray 集群:
$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ... $ gcloud compute ssh my_tpu_admin $ (vm) pip3 install ray[default] $ (vm) ray start --head --port=6379 --num-cpus=0 ... # (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP) $ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"
为方便起见,我们还提供了用于创建协调者虚拟机并将此文件夹中的内容部署到协调者虚拟机的基本脚本。如需了解源代码,请参阅
create_cpu.sh
和
deploy.sh
。
这些脚本会设置一些默认值:
create_cpu.sh
将创建一个名为$USER-admin
的虚拟机,并使用设置为gcloud config
默认值的项目和可用区。运行gcloud config list
即可查看这些默认值。create_cpu.sh
默认分配的启动磁盘大小为 200GB。deploy.sh
假定您的虚拟机名称为$USER-admin
。如果您在create_cpu.sh
中更改该值,请务必在deploy.sh
中更改该值。
如需使用便捷脚本,请执行以下操作:
将 GitHub 代码库克隆到本地机器,然后进入
ray_tpu
文件夹:$ git clone https://github.com/tensorflow/tpu.git $ cd tpu/tools/ray_tpu/
如果您没有专用于 TPU 管理的服务账号(强烈建议),请设置一个:
$ ./create_tpu_service_account.sh
创建协调器虚拟机:
$ ./create_cpu.sh
此脚本使用启动脚本在虚拟机上安装依赖项,并会在启动脚本完成之前自动阻塞。
将本地代码部署到协调器虚拟机:
$ ./deploy.sh
通过 SSH 连接到虚拟机:
$ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
此处启用了端口转发,因为 Ray 会自动在端口 8265 上启动信息中心。从您通过 SSH 连接到协调器虚拟机的机器, 可通过以下网址访问此信息中心: http://127.0.0.1:8265/.
如果您跳过了第 0 步,请在 CPU 虚拟机中设置 gcloud 凭据:
$ (vm) gcloud auth login --update-adc
此步骤设置项目 ID 信息,并允许 Cloud TPU API 在 协调器虚拟机。
安装要求:
$ (vm) pip3 install -r src/requirements.txt
在协调者虚拟机上启动 Ray,协调者虚拟机将成为 Ray 集群的主节点:
$ (vm) ray start --head --port=6379 --num-cpus=0
用法示例
基本 JAX 示例
run_basic_jax.py
是一个最小示例,演示了如何在包含 TPU 虚拟机的 Ray 集群上使用 Ray Jobs 和 Ray 运行时环境运行 JAX 工作负载。
对于与 Cloud TPU 兼容且使用多控制器编程模型的机器学习框架(例如 JAX 和 PyTorch/XLA PJRT),您必须在每个主机上运行至少一个进程。如需了解详情,请参阅多进程编程模型。在实际操作中,这可能如下所示:
$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all $ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"
如果您有大约 16 台以上的主机(例如 v4-128),则会遇到 SSH 可伸缩性问题,并且您的命令可能必须更改为:
$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8 $ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8
如果 my_bug_free_python_code
,这可能会妨碍开发者开发速度
包含 bug。解决此问题的方法之一就是使用
Kubernetes 或 Ray 等编排系统Ray 包含运行时环境的概念,在应用时,该环境会在运行 Ray 应用时部署代码和依赖项。
将 Ray 运行时环境与 Ray 集群以及 Ray Jobs 结合使用,您可以绕过 SCP/SSH 周期假设您按照上述示例操作,则可以使用以下命令运行此脚本:
$ python3 legacy/run_basic_jax.py
输出类似于以下内容:
2023-03-01 22:12:10,065 INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379... 2023-03-01 22:12:10,072 INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265 W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu... Creating TPU: $USER-ray-test Request: {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}} Create TPU operation still running... ... Create TPU operation complete. I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster... … I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host. 2023-03-01 22:15:18,010 INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip. 2023-03-01 22:15:18,010 INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'. 2023-03-01 22:15:18,080 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload. I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs. ... I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089 21286 credentials_generic.cc:35] Could not get HOME environment variable. 8 I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499 8561 credentials_generic.cc:35] Could not get HOME environment variable. 8
容错型训练
此示例展示了如何使用 RayTpuController
实现故障
容忍训练。在本例中,我们在 v4-16 上基于 PAX 预训练了一个简单的 LLM,但请注意,您可以将此 PAX 工作负载替换为任何其他长时间运行的工作负载。如需了解源代码,请参阅
run_pax_autoresume.py
。
如需运行此示例,请执行以下操作:
将
paxml
克隆到协调方虚拟机:$ git clone https://github.com/google/paxml.git
为了演示 Ray 运行时环境在进行和部署 JAX 更改方面的易用性,此示例要求您修改 PAX。
添加新的实验配置:
$ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py @experiment_registry.register class TestModel(LmCloudSpmd2BLimitSteps): ICI_MESH_SHAPE = [1, 4, 2] CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ def task(self) -> tasks_lib.SingleTask.HParams: task_p = super().task() task_p.train.num_train_steps = 1000 task_p.train.save_interval_steps = 100 return task_p EOT
运行
run_pax_autoresume.py
:$ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
在工作负载运行时,对删除 TPU 后会发生什么情况进行实验, 默认为
$USER-tpu-ray
:$ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
Ray 会检测 TPU 是否已关闭,并显示以下消息:
I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata. W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu... 2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, preempted node, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload. 2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, preempted node, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload.
该作业会自动重新创建 TPU 虚拟机并重启训练作业 以便它可以从最新的检查点(本单元中的 200 步) 示例):
I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting... I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`... I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`... I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
超参数搜索
此示例展示了如何使用 Ray AIR 中的 Ray Tune 对 JAX/FLAX 中的 MNIST 进行超参数调优。如需查看源代码,请参阅 run_hp_search.py
。
如需运行此示例,请执行以下操作:
安装要求:
$ pip3 install -r src/tune/requirements.txt
运行
run_hp_search.py
:$ python3 src/tune/run_hp_search.py
输出类似于以下内容:
Number of trials: 3/3 (3 TERMINATED) +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+ | Trial name | status | loc | learning_rate | momentum | acc | iter | total time (s) | |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------| | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 | 1.15258e-09 | 0.897988 | 0.0982 | 3 | 82.4525 | | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 | 0.000219523 | 0.825463 | 0.1009 | 3 | 73.1168 | | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 | 1.08035e-08 | 0.660416 | 0.098 | 3 | 71.6813 | +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+ 2023-03-02 21:50:47,378 INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop). ...
问题排查
Ray 头节点无法连接
如果您运行的工作负载会创建/删除 TPU 生命周期,有时这并不会将 TPU 主机与 Ray 集群断开连接。这可能会 显示为表示 Ray 头节点无法连接的 gRPC 错误 一组 IP 地址
因此,您可能需要终止 Ray 会话 (ray stop
) 并重启它 (ray start --head --port=6379 --num-cpus=0
)。
Ray Job 直接失败,没有任何日志输出
PAX 处于实验阶段,此示例可能会因 pip 依赖项而损坏。如果发生这种情况,您可能会看到如下内容:
I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs. I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures. I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2}) I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED. 2023-03-03 20:51:38,641 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload. 2023-03-03 20:51:38,706 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
如需查看错误的根本原因,您可以前往 http://127.0.0.1:8265/,查看正在运行/失败作业的信息中心,其中会提供更多信息。runtime_env_agent.log
会显示与 runtime_env 设置相关的所有错误信息,例如:
60 INFO: pip is looking at multiple versions ofto determine which version is compatible with other requirements. This could take a while. 61 INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while. 62 ERROR: Cannot install paxml because these package versions have conflicting dependencies. 63 64 The conflict is caused by: 65 praxis 0.3.0 depends on t5x 66 praxis 0.2.1 depends on t5x 67 praxis 0.2.0 depends on t5x 68 praxis 0.1 depends on t5x 69 70 To fix this you could try to: 71 1. loosen the range of package versions you've specified 72 2. remove package versions to allow pip attempt to solve the dependency conflict 73 74 ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts