使用 Ray 扩缩机器学习工作负载
本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。
以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照启动Google Cloud 适用于 KubeRay 的 GKE 集群(含 TPU)中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用。
在单主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在包含单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您使用的是多主机 TPU,由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载。
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:
- 您可以将
TPU
指定为自定义资源,而不是设置num_gpus
,并设置 TPU 芯片的数量。 - 您可以使用每个 Ray 工作器节点的芯片数来指定 TPU。例如,如果您使用的是 v6e-4,则将
TPU
设置为 4 并运行远程函数会占用整个 TPU 主机。- 这与 GPU 的典型运行方式不同,后者是每个主机一个进程。不建议将
TPU
设置为非 4 的数字。 - 例外情况:如果您使用的是单主机
v6e-8
或v5litepod-8
,则应将此值设置为 8。
- 这与 GPU 的典型运行方式不同,后者是每个主机一个进程。不建议将
在多主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。
如果您想在包含多个 TPU 切片的集群上运行工作负载,请参阅控制单个 TPU 切片。
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:
- 与 GPU 上的 PyTorch 工作负载类似:
- TPU 上的 JAX 工作负载以多控制器、单程序、多数据 (SPMD) 方式运行。
- 设备之间的集合由机器学习框架处理。
- 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。
运行多切片 JAX 工作负载
借助 Multislice,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以在数据中心网络上的多个 Pod 中运行工作负载。
为方便起见,您可以使用实验性 ray-tpu
软件包来简化 Ray 与 TPU slice 的互动。使用 pip
安装 ray-tpu
:
pip install ray-tpu
以下示例脚本展示了如何使用 ray-tpu
软件包使用 Ray 演员或任务运行多片工作负载:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
TPU 和 Ray 资源
Ray 对 TPU 的处理方式与 GPU 不同,以适应使用方面的差异。在以下示例中,共有 9 个 Ray 节点:
- Ray 主节点在
n1-standard-16
虚拟机上运行。 - Ray 工作器节点在两个
v6e-16
TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
资源使用情况字段说明:
CPU
:集群中可用的 CPU 总数。TPU
:集群中的 TPU 芯片数量。TPU-v6e-16-head
:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU slice 至关重要。memory
:应用使用的 Worker 堆内存。object_store_memory
:应用使用ray.put
在对象存储区中创建对象以及从远程函数返回值时所用的内存。tpu-group-0
和tpu-group-1
:各个 TPU slice 的唯一标识符。这对于在 slice 上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。
控制单个 TPU 切片
在使用 Ray 和 TPU 时,常见的做法是在同一 TPU 切片中运行多个工作负载,例如在超参数调优或服务中。
在使用 Ray 进行预配和作业调度时,需要特别注意 TPU slice。
运行单切片工作负载
当 Ray 进程在 TPU 分片(运行 ray start
)上启动时,该进程会自动检测分片的相关信息。例如,拓扑、切片中的工作器数量,以及进程是否在工作器 0 上运行。
在名称为“my-tpu”的 TPU v6e-16 上运行 ray status
时,输出类似于以下内容:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
是 slice 的工作器 0 的资源标签。"TPU": 4
表示每个工作器有 4 个芯片。"my-tpu"
是 TPU 的名称。您可以使用这些值在同一 slice 中的 TPU 上运行工作负载,如以下示例所示。
假设您想在 slice 中的所有工作器上运行以下函数:
@ray.remote()
def my_function():
return jax.device_count()
您需要定位到该 slice 的工作器 0,然后告知工作器 0 如何将 my_function
广播到 slice 中的每个工作器:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4
该示例会执行以下步骤:
@ray.remote(resources={"TPU-v6e-16-head": 1})
:run_on_pod
函数在具有资源标签TPU-v6e-16-head
的工作器上运行,该标签以任意工作器 0 为目标。tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
:获取 TPU 名称。num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
:获取 slice 中的工作器数量。remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
:将包含 TPU 名称和"TPU": 4
资源要求的资源标签添加到函数my_function
。- 由于 TPU 切片中的每个工作器都有其所属切片的自定义资源标签,因此 Ray 只会在同一 TPU 切片中的工作器上调度工作负载。
- 这还会为远程函数预留 4 个 TPU 工作器,因此 Ray 不会在该 Ray Pod 上调度其他 TPU 工作负载。
- 由于
run_on_pod
仅使用TPU-v6e-16-head
逻辑资源,因此my_function
也将在 worker 0 上运行,但在不同的进程中运行。
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
:调用修改后的my_function
函数的次数等于工作器数量,并返回结果。h = run_on_pod(my_function).remote()
:run_on_pod
将异步执行,不会阻塞主进程。
TPU 切片自动扩缩
TPU 上的 Ray 支持按 TPU 切片的粒度进行自动扩缩。您可以使用 GKE 节点自动预配 (NAP) 功能启用此功能。您可以使用 Ray 自动扩缩器和 KubeRay 执行此功能。主资源类型用于向 Ray 发出自动扩缩信号,例如 TPU-v6e-32-head
。