使用 Ray 扩缩机器学习工作负载

本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。

以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照启动Google Cloud 适用于 KubeRay 的 GKE 集群(含 TPU)中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用

在单主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在包含单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您使用的是多主机 TPU,由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:

  • 您可以将 TPU 指定为自定义资源,而不是设置 num_gpus,并设置 TPU 芯片的数量。
  • 您可以使用每个 Ray 工作器节点的芯片数来指定 TPU。例如,如果您使用的是 v6e-4,则将 TPU 设置为 4 并运行远程函数会占用整个 TPU 主机。
    • 这与 GPU 的典型运行方式不同,后者是每个主机一个进程。不建议将 TPU 设置为非 4 的数字。
    • 例外情况:如果您使用的是单主机 v6e-8v5litepod-8,则应将此值设置为 8。

在多主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。

如果您想在包含多个 TPU 切片的集群上运行工作负载,请参阅控制单个 TPU 切片

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:

  • 与 GPU 上的 PyTorch 工作负载类似:
  • 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。

运行多切片 JAX 工作负载

借助 Multislice,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以在数据中心网络上的多个 Pod 中运行工作负载。

为方便起见,您可以使用实验性 ray-tpu 软件包来简化 Ray 与 TPU slice 的互动。使用 pip 安装 ray-tpu

pip install ray-tpu

以下示例脚本展示了如何使用 ray-tpu 软件包使用 Ray 演员或任务运行多片工作负载:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

TPU 和 Ray 资源

Ray 对 TPU 的处理方式与 GPU 不同,以适应使用方面的差异。在以下示例中,共有 9 个 Ray 节点:

  • Ray 主节点在 n1-standard-16 虚拟机上运行。
  • Ray 工作器节点在两个 v6e-16 TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

资源使用情况字段说明:

  • CPU:集群中可用的 CPU 总数。
  • TPU:集群中的 TPU 芯片数量。
  • TPU-v6e-16-head:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU slice 至关重要。
  • memory:应用使用的 Worker 堆内存。
  • object_store_memory:应用使用 ray.put 在对象存储区中创建对象以及从远程函数返回值时所用的内存。
  • tpu-group-0tpu-group-1:各个 TPU slice 的唯一标识符。这对于在 slice 上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。

控制单个 TPU 切片

在使用 Ray 和 TPU 时,常见的做法是在同一 TPU 切片中运行多个工作负载,例如在超参数调优或服务中。

在使用 Ray 进行预配和作业调度时,需要特别注意 TPU slice。

运行单切片工作负载

当 Ray 进程在 TPU 分片(运行 ray start)上启动时,该进程会自动检测分片的相关信息。例如,拓扑、切片中的工作器数量,以及进程是否在工作器 0 上运行。

在名称为“my-tpu”的 TPU v6e-16 上运行 ray status 时,输出类似于以下内容:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" 是 slice 的工作器 0 的资源标签。"TPU": 4 表示每个工作器有 4 个芯片。"my-tpu" 是 TPU 的名称。您可以使用这些值在同一 slice 中的 TPU 上运行工作负载,如以下示例所示。

假设您想在 slice 中的所有工作器上运行以下函数:

@ray.remote()
def my_function():
    return jax.device_count()

您需要定位到该 slice 的工作器 0,然后告知工作器 0 如何将 my_function 广播到 slice 中的每个工作器:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4

该示例会执行以下步骤:

  • @ray.remote(resources={"TPU-v6e-16-head": 1})run_on_pod 函数在具有资源标签 TPU-v6e-16-head 的工作器上运行,该标签以任意工作器 0 为目标。
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name():获取 TPU 名称。
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count():获取 slice 中的工作器数量。
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}):将包含 TPU 名称和 "TPU": 4 资源要求的资源标签添加到函数 my_function
    • 由于 TPU 切片中的每个工作器都有其所属切片的自定义资源标签,因此 Ray 只会在同一 TPU 切片中的工作器上调度工作负载。
    • 这还会为远程函数预留 4 个 TPU 工作器,因此 Ray 不会在该 Ray Pod 上调度其他 TPU 工作负载。
    • 由于 run_on_pod 仅使用 TPU-v6e-16-head 逻辑资源,因此 my_function 也将在 worker 0 上运行,但在不同的进程中运行。
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]):调用修改后的 my_function 函数的次数等于工作器数量,并返回结果。
  • h = run_on_pod(my_function).remote()run_on_pod 将异步执行,不会阻塞主进程。

TPU 切片自动扩缩

TPU 上的 Ray 支持按 TPU 切片的粒度进行自动扩缩。您可以使用 GKE 节点自动预配 (NAP) 功能启用此功能。您可以使用 Ray 自动扩缩器和 KubeRay 执行此功能。主资源类型用于向 Ray 发出自动扩缩信号,例如 TPU-v6e-32-head