使用 Ray 扩缩机器学习工作负载
本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。将 TPU 与 Ray 搭配使用有两种不同的模式:以设备为中心的模式 (PyTorch/XLA) 和以主机为中心的模式 (JAX)。
本文档假定您已设置 TPU 环境。如需了解详情,请参阅以下资源:
- Cloud TPU:设置 Cloud TPU 环境和管理 TPU 资源
- Google Kubernetes Engine (GKE):在 GKE Autopilot 中部署 TPU 工作负载或在 GKE Standard 中部署 TPU 工作负载
以设备为中心的模式 (PyTorch/XLA)
以设备为中心的模式保留了经典 PyTorch 的大部分程序化样式。在此模式下,您可以添加新的 XLA 设备类型,该类型的工作方式与任何其他 PyTorch 设备一样。每个单独的进程都与一个 XLA 设备进行交互。
如果您已经熟悉带有 GPU 的 PyTorch,并且想要使用类似的编码抽象,则此模式非常适合您。
以下部分介绍了如何在不使用 Ray 的情况下在一个或多个设备上运行 PyTorch/XLA 工作负载,以及如何使用 Ray 在多个主机上运行同一工作负载。
创建 TPU
为 TPU 创建参数创建环境变量。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export RUNTIME_VERSION=v2-alpha-tpuv5
环境变量说明
变量 说明 PROJECT_ID
您的 Google Cloud 项目 ID。使用现有项目或创建新项目。 TPU_NAME
TPU 的名称。 ZONE
要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区。 ACCELERATOR_TYPE
加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本。 RUNTIME_VERSION
Cloud TPU 软件版本。 使用以下命令创建一个具有 8 个核心的 v5p TPU 虚拟机:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
使用以下命令连接到 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
如果您使用的是 GKE,请参阅 KubeRay on GKE 指南,以了解设置信息。
安装要求
在 TPU 虚拟机上运行以下命令以安装所需的依赖项:
将以下内容保存到文件中。例如
requirements.txt
。--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
如需安装所需的依赖项,请运行以下命令:
pip install -r requirements.txt
如果您在 GKE 上运行工作负载,我们建议您创建一个 Dockerfile,用于安装所需的依赖项。如需查看示例,请参阅 GKE 文档中的在 TPU 切片节点上运行工作负载。
在单个设备上运行 PyTorch/XLA 工作负载
以下示例演示了如何在单个设备(即 TPU 芯片)上创建 XLA 张量。这与 PyTorch 处理其他设备类型的方式类似。
将以下代码段保存到文件中。例如
workload.py
。import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
import torch_xla
import 语句会初始化 PyTorch/XLA,而xm.xla_device()
函数会返回当前的 XLA 设备(即 TPU 芯片)。将
PJRT_DEVICE
环境变量设置为 TPU。export PJRT_DEVICE=TPU
运行脚本。
python workload.py
输出类似于以下内容。确保输出指示找到了 XLA 设备。
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
在多个设备上运行 PyTorch/XLA
更新上一部分中的代码段,以便在多个设备上运行。
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
运行脚本。
python workload.py
如果您在 TPU v5p-8 上运行代码段,输出会类似于以下内容:
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
接受两个参数:一个函数和一个参数列表。它会为每个可用的 XLA 设备创建一个进程,并调用在参数中指定的函数。在此示例中,有 4 个可用的 TPU 设备,因此 torch_xla.launch()
会创建 4 个进程,并在每个设备上调用 _mp_fn()
。每个进程只能访问一个设备,因此每个设备都有索引 0,并且系统会为所有进程输出 xla:0
。
使用 Ray 在多个主机上运行 PyTorch/XLA
以下部分介绍了如何在更大的多主机 TPU 切片上运行相同的代码段。如需详细了解多主机 TPU 架构,请参阅系统架构。
在此示例中,您将手动设置 Ray。如果您已熟悉如何 Ray 设置,则可以跳至最后一部分,即运行 Ray 工作负载。如需详细了解如何为生产环境设置 Ray,请参阅以下资源:
创建多主机 TPU 虚拟机
为 TPU 创建参数创建环境变量。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=v2-alpha-tpuv5
环境变量说明
变量 说明 PROJECT_ID
您的 Google Cloud 项目 ID。使用现有项目或创建新项目。 TPU_NAME
TPU 的名称。 ZONE
要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区。 ACCELERATOR_TYPE
加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本。 RUNTIME_VERSION
Cloud TPU 软件版本。 使用以下命令创建一个具有 2 个主机的多主机 TPU v5p(v5p-16,每个主机上有 4 个 TPU 芯片):
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
设置 Ray
TPU v5p-16 有 2 个 TPU 主机,每个主机有 4 个 TPU 芯片。在此示例中,您将在一个主机上启动 Ray 头节点,并将第二个主机作为工作器节点添加到 Ray 集群中。
使用 SSH 连接到第一个主机。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
使用与安装要求部分中相同的要求文件安装依赖项。
pip install -r requirements.txt
启动 Ray 进程。
ray start --head --port=6379
输出类似于以下内容:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
此 TPU 主机现在是 Ray 头节点。请记下展示如何将其他节点添加到 Ray 集群的行,类似于以下内容:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
您将在后面的步骤中用到此命令。
检查 Ray 集群状态:
ray status
输出类似于以下内容:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
该集群仅包含 4 个 TPU (
0.0/4.0 TPU
),因为到目前为止您只添加了头节点。现在,头节点已在运行,您可以将第二个主机添加到集群。
使用 SSH 连接到第二个主机。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
使用与安装要求部分中相同的要求文件安装依赖项。
pip install -r requirements.txt
启动 Ray 进程。如需将此节点添加到现有的 Ray 集群,请使用
ray start
命令输出中的命令。请务必替换以下命令中的 IP 地址和端口:ray start --address='10.130.0.76:6379'
输出类似于以下内容:
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
再次检查 Ray 状态:
ray status
输出类似于以下内容:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
第二个 TPU 主机现在是集群中的节点。可用资源列表现在显示 8 个 TPU (
0.0/8.0 TPU
)。
运行 Ray 工作负载
更新要在 Ray 集群上运行的代码段:
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host. LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster. NUM_OF_HOSTS = 4 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip. @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster. tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
在 Ray 头节点上运行脚本。请将 ray-workload.py 替换为脚本的路径。
python ray-workload.py
输出类似于以下内容:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
输出表明,系统已成功在多主机 TPU 切片中的每个 XLA 设备(在此示例中为 8 个设备)上调用该函数。
以主机为中心的模式 (JAX)
以下部分介绍了使用 JAX 的以主机为中心的模式。JAX 使用函数式编程范式,并支持更高级别的单程序、多数据 (SPMD) 语义。JAX 代码旨在并发在单个主机上的多个设备中运行,而不是让每个进程与单个 XLA 设备进行交互。
JAX 专为高性能计算而设计,可高效利用 TPU 进行大规模训练和推理。如果您熟悉函数式编程概念,则此模式非常适合您,以便充分利用 JAX 的潜力。
以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照使用 TPU 为 KubeRay 启动 Google Cloud GKE 集群中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用。
在单主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在具有单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您有多主机 TPU,则由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载。
为 TPU 创建参数创建环境变量。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e
环境变量说明
变量 说明 PROJECT_ID
您的 Google Cloud 项目 ID。使用现有项目或创建新项目。 TPU_NAME
TPU 的名称。 ZONE
要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区。 ACCELERATOR_TYPE
加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本。 RUNTIME_VERSION
Cloud TPU 软件版本。 使用以下命令创建一个具有 4 个核心的 v6e TPU 虚拟机:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
使用以下命令连接到 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
在 TPU 上安装 JAX 和 Ray。
pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
将以下代码保存到文件中。例如
ray-jax-single-host.py
。import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() h = my_function.remote() print(ray.get(h)) # => 4
如果您习惯使用 GPU 运行 Ray,则在使用 TPU 时会有一些主要差异:
- 请勿设置
num_gpus
,而是将TPU
指定为自定义资源并设置 TPU 芯片数。 - 使用每个 Ray 工作器节点的芯片数指定 TPU。例如,如果您使用的是 v6e-4,则运行远程函数并将
TPU
设置为 4 会占用整个 TPU 主机。 - 这与 GPU 的常规运行方式不同,后者是每个主机一个进程。不建议将
TPU
设置为非 4 的数字。- 例外情况:如果您使用的是单主机
v6e-8
或v5litepod-8
,则应将此值设置为 8。
- 例外情况:如果您使用的是单主机
- 请勿设置
运行脚本。
python ray-jax-single-host.py
在多主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。
为 TPU 创建参数创建环境变量。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-16 export RUNTIME_VERSION=v2-alpha-tpuv6e
环境变量说明
变量 说明 PROJECT_ID
您的 Google Cloud 项目 ID。使用现有项目或创建新项目。 TPU_NAME
TPU 的名称。 ZONE
要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区。 ACCELERATOR_TYPE
加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本。 RUNTIME_VERSION
Cloud TPU 软件版本。 使用以下命令创建一个具有 16 个核心的 v6e TPU 虚拟机:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
在所有 TPU 工作器上安装 JAX 和 Ray。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
将以下代码保存到文件中。例如
ray-jax-multi-host.py
。import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() ray.init() num_tpus = ray.available_resources()["TPU"] num_hosts = int(num_tpus) # 4 h = [my_function.remote() for _ in range(num_hosts)] print(ray.get(h)) # [16, 16, 16, 16]
如果您习惯使用 GPU 运行 Ray,则在使用 TPU 时会有一些主要差异:
- 与 GPU 上的 PyTorch 工作负载类似:
- TPU 上的 JAX 工作负载以多控制器、单程序、多数据 (SPMD) 方式运行。
- 设备之间的集合由机器学习框架处理。
- 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。
- 与 GPU 上的 PyTorch 工作负载类似:
将脚本复制到所有 TPU 工作器。
gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
运行脚本。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="python ray-jax-multi-host.py"
运行多切片 JAX 工作负载
借助多切片,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以通过数据中心网络在多个 Pod 中运行工作负载。
您可以使用 ray-tpu
软件包来简化 Ray 与 TPU 切片的交互。
使用 pip
安装 ray-tpu
。
pip install ray-tpu
如需详细了解如何使用 ray-tpu
软件包,请参阅 GitHub 代码库中的使用入门。如需查看使用多切片的示例,请参阅在多切片上运行。
使用 Ray 和 MaxText 编排工作负载
如需详细了解如何将 Ray 与 MaxText 搭配使用,请参阅使用 MaxText 运行训练作业。
TPU 和 Ray 资源
Ray 对 TPU 的处理方式与对 GPU 的处理方式不同,以适应使用方面的差异。在以下示例中,共有九个 Ray 节点:
- Ray 头节点在
n1-standard-16
虚拟机上运行。 - Ray 工作器节点在两个
v6e-16
TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
资源用量字段说明:
CPU
:集群中可用的 CPU 总数。TPU
:集群中的 TPU 芯片数。TPU-v6e-16-head
:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU 切片非常重要。memory
:应用使用的工作器堆内存。object_store_memory
:应用使用ray.put
在对象存储区中创建对象以及从远程函数返回值时使用的内存。tpu-group-0
和tpu-group-1
:各个 TPU 切片的唯一标识符。这对于在切片上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。