使用 Ray 扩缩机器学习工作负载

本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。将 TPU 与 Ray 搭配使用有两种不同的模式:以设备为中心的模式 (PyTorch/XLA)以主机为中心的模式 (JAX)

本文档假定您已设置 TPU 环境。如需了解详情,请参阅以下资源:

以设备为中心的模式 (PyTorch/XLA)

以设备为中心的模式保留了经典 PyTorch 的大部分程序化样式。在此模式下,您可以添加新的 XLA 设备类型,该类型的工作方式与任何其他 PyTorch 设备一样。每个单独的进程都与一个 XLA 设备进行交互。

如果您已经熟悉带有 GPU 的 PyTorch,并且想要使用类似的编码抽象,则此模式非常适合您。

以下部分介绍了如何在不使用 Ray 的情况下在一个或多个设备上运行 PyTorch/XLA 工作负载,以及如何使用 Ray 在多个主机上运行同一工作负载。

创建 TPU

  1. 为 TPU 创建参数创建环境变量。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export RUNTIME_VERSION=v2-alpha-tpuv5

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本

  2. 使用以下命令创建一个具有 8 个核心的 v5p TPU 虚拟机:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用以下命令连接到 TPU 虚拟机:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

如果您使用的是 GKE,请参阅 KubeRay on GKE 指南,以了解设置信息。

安装要求

在 TPU 虚拟机上运行以下命令以安装所需的依赖项:

  1. 将以下内容保存到文件中。例如 requirements.txt

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. 如需安装所需的依赖项,请运行以下命令:

    pip install -r requirements.txt
    

如果您在 GKE 上运行工作负载,我们建议您创建一个 Dockerfile,用于安装所需的依赖项。如需查看示例,请参阅 GKE 文档中的在 TPU 切片节点上运行工作负载

在单个设备上运行 PyTorch/XLA 工作负载

以下示例演示了如何在单个设备(即 TPU 芯片)上创建 XLA 张量。这与 PyTorch 处理其他设备类型的方式类似。

  1. 将以下代码段保存到文件中。例如 workload.py

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    import torch_xla import 语句会初始化 PyTorch/XLA,而 xm.xla_device() 函数会返回当前的 XLA 设备(即 TPU 芯片)。

  2. PJRT_DEVICE 环境变量设置为 TPU。

    export PJRT_DEVICE=TPU
    
  3. 运行脚本。

    python workload.py
    

    输出类似于以下内容。确保输出指示找到了 XLA 设备。

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

在多个设备上运行 PyTorch/XLA

  1. 更新上一部分中的代码段,以便在多个设备上运行。

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. 运行脚本。

    python workload.py
    

    如果您在 TPU v5p-8 上运行代码段,输出会类似于以下内容:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() 接受两个参数:一个函数和一个参数列表。它会为每个可用的 XLA 设备创建一个进程,并调用在参数中指定的函数。在此示例中,有 4 个可用的 TPU 设备,因此 torch_xla.launch() 会创建 4 个进程,并在每个设备上调用 _mp_fn()。每个进程只能访问一个设备,因此每个设备都有索引 0,并且系统会为所有进程输出 xla:0

使用 Ray 在多个主机上运行 PyTorch/XLA

以下部分介绍了如何在更大的多主机 TPU 切片上运行相同的代码段。如需详细了解多主机 TPU 架构,请参阅系统架构

在此示例中,您将手动设置 Ray。如果您已熟悉如何 Ray 设置,则可以跳至最后一部分,即运行 Ray 工作负载。如需详细了解如何为生产环境设置 Ray,请参阅以下资源:

创建多主机 TPU 虚拟机

  1. 为 TPU 创建参数创建环境变量。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=v2-alpha-tpuv5

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本

  2. 使用以下命令创建一个具有 2 个主机的多主机 TPU v5p(v5p-16,每个主机上有 4 个 TPU 芯片):

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE \
       --version=$RUNTIME_VERSION

设置 Ray

TPU v5p-16 有 2 个 TPU 主机,每个主机有 4 个 TPU 芯片。在此示例中,您将在一个主机上启动 Ray 头节点,并将第二个主机作为工作器节点添加到 Ray 集群中。

  1. 使用 SSH 连接到第一个主机。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. 使用与安装要求部分中相同的要求文件安装依赖项。

    pip install -r requirements.txt
    
  3. 启动 Ray 进程。

    ray start --head --port=6379
    

    输出类似于以下内容:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    此 TPU 主机现在是 Ray 头节点。请记下展示如何将其他节点添加到 Ray 集群的行,类似于以下内容:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    您将在后面的步骤中用到此命令。

  4. 检查 Ray 集群状态:

    ray status
    

    输出类似于以下内容:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    该集群仅包含 4 个 TPU (0.0/4.0 TPU),因为到目前为止您只添加了头节点。

    现在,头节点已在运行,您可以将第二个主机添加到集群。

  5. 使用 SSH 连接到第二个主机。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. 使用与安装要求部分中相同的要求文件安装依赖项。

    pip install -r requirements.txt
    
  7. 启动 Ray 进程。如需将此节点添加到现有的 Ray 集群,请使用 ray start 命令输出中的命令。请务必替换以下命令中的 IP 地址和端口:

    ray start --address='10.130.0.76:6379'

    输出类似于以下内容:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  8. 再次检查 Ray 状态:

    ray status
    

    输出类似于以下内容:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    第二个 TPU 主机现在是集群中的节点。可用资源列表现在显示 8 个 TPU (0.0/8.0 TPU)。

运行 Ray 工作负载

  1. 更新要在 Ray 集群上运行的代码段:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host.
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster.
    NUM_OF_HOSTS = 4
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster.
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
    ray.get(tasks)
    
    ray.shutdown()
    
  2. 在 Ray 头节点上运行脚本。请将 ray-workload.py 替换为脚本的路径。

    python ray-workload.py

    输出类似于以下内容:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    输出表明,系统已成功在多主机 TPU 切片中的每个 XLA 设备(在此示例中为 8 个设备)上调用该函数。

以主机为中心的模式 (JAX)

以下部分介绍了使用 JAX 的以主机为中心的模式。JAX 使用函数式编程范式,并支持更高级别的单程序、多数据 (SPMD) 语义。JAX 代码旨在并发在单个主机上的多个设备中运行,而不是让每个进程与单个 XLA 设备进行交互。

JAX 专为高性能计算而设计,可高效利用 TPU 进行大规模训练和推理。如果您熟悉函数式编程概念,则此模式非常适合您,以便充分利用 JAX 的潜力。

以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照使用 TPU 为 KubeRay 启动 Google Cloud GKE 集群中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用

在单主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在具有单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您有多主机 TPU,则由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载

  1. 为 TPU 创建参数创建环境变量。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本

  2. 使用以下命令创建一个具有 4 个核心的 v6e TPU 虚拟机:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用以下命令连接到 TPU 虚拟机:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. 在 TPU 上安装 JAX 和 Ray。

    pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. 将以下代码保存到文件中。例如 ray-jax-single-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    h = my_function.remote()
    print(ray.get(h)) # => 4
    

    如果您习惯使用 GPU 运行 Ray,则在使用 TPU 时会有一些主要差异:

    • 请勿设置 num_gpus,而是将 TPU 指定为自定义资源并设置 TPU 芯片数。
    • 使用每个 Ray 工作器节点的芯片数指定 TPU。例如,如果您使用的是 v6e-4,则运行远程函数并将 TPU 设置为 4 会占用整个 TPU 主机。
    • 这与 GPU 的常规运行方式不同,后者是每个主机一个进程。不建议将 TPU 设置为非 4 的数字。
      • 例外情况:如果您使用的是单主机 v6e-8v5litepod-8,则应将此值设置为 8。
  6. 运行脚本。

    python ray-jax-single-host.py

在多主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。

  1. 为 TPU 创建参数创建环境变量。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-16
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本

  2. 使用以下命令创建一个具有 16 个核心的 v6e TPU 虚拟机:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 在所有 TPU 工作器上安装 JAX 和 Ray。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. 将以下代码保存到文件中。例如 ray-jax-multi-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    ray.init()
    num_tpus = ray.available_resources()["TPU"]
    num_hosts = int(num_tpus) # 4
    h = [my_function.remote() for _ in range(num_hosts)]
    print(ray.get(h)) # [16, 16, 16, 16]
    

    如果您习惯使用 GPU 运行 Ray,则在使用 TPU 时会有一些主要差异:

    • 与 GPU 上的 PyTorch 工作负载类似:
    • 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。
  5. 将脚本复制到所有 TPU 工作器。

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
  6. 运行脚本。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="python ray-jax-multi-host.py"

运行多切片 JAX 工作负载

借助多切片,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以通过数据中心网络在多个 Pod 中运行工作负载。

您可以使用 ray-tpu 软件包来简化 Ray 与 TPU 切片的交互。

使用 pip 安装 ray-tpu

pip install ray-tpu

如需详细了解如何使用 ray-tpu 软件包,请参阅 GitHub 代码库中的使用入门。如需查看使用多切片的示例,请参阅在多切片上运行

使用 Ray 和 MaxText 编排工作负载

如需详细了解如何将 Ray 与 MaxText 搭配使用,请参阅使用 MaxText 运行训练作业

TPU 和 Ray 资源

Ray 对 TPU 的处理方式与对 GPU 的处理方式不同,以适应使用方面的差异。在以下示例中,共有九个 Ray 节点:

  • Ray 头节点在 n1-standard-16 虚拟机上运行。
  • Ray 工作器节点在两个 v6e-16 TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

资源用量字段说明:

  • CPU:集群中可用的 CPU 总数。
  • TPU:集群中的 TPU 芯片数。
  • TPU-v6e-16-head:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU 切片非常重要。
  • memory:应用使用的工作器堆内存。
  • object_store_memory:应用使用 ray.put 在对象存储区中创建对象以及从远程函数返回值时使用的内存。
  • tpu-group-0tpu-group-1:各个 TPU 切片的唯一标识符。这对于在切片上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。