使用 Ray 扩缩机器学习工作负载

本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。您可以通过两种不同的模式将 TPU 与 Ray 搭配使用:以设备为中心的模式 (PyTorch/XLA)以主机为中心的模式 (JAX)

本文档假定您已设置 TPU 环境。如需了解详情,请参阅以下资源:

以设备为中心的模式 (PyTorch/XLA)

以设备为中心的模式保留了传统 PyTorch 的大部分编程风格。在此模式下,您可以添加新的 XLA 设备类型,该类型的运作方式与任何其他 PyTorch 设备一样。每个单独的进程都与一个 XLA 设备互动。

如果您已经熟悉 PyTorch with GPUs,并希望使用类似的编码抽象,此模式非常适合。

以下部分介绍了如何在不使用 Ray 的情况下在一个或多个设备上运行 PyTorch/XLA 工作负载,以及如何使用 Ray 在多个主机上运行同一工作负载。

创建 TPU

  1. 为 TPU 创建参数创建环境变量:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    环境变量说明

    TPU_NAME
    新 Cloud TPU 的名称。
    ZONE
    要在其中创建 Cloud TPU 的可用区
    accelerator-type
    加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需了解详情,请参阅 TPU 版本
    version
    您要使用的 TPU 软件版本。如需了解详情,请参阅 TPU 虚拟机映像
  2. 使用以下命令创建具有 8 个核心的 v5p TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. 使用以下命令连接到 TPU 虚拟机:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

如果您使用的是 GKE,请参阅 KubeRay on GKE 指南,了解设置信息。

安装要求

在 TPU 虚拟机上运行以下命令以安装所需的依赖项:

  1. 将以下内容保存到文件(例如 requirements.txt)中:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. 运行以下命令以安装所需的依赖项:

    pip install -r requirements.txt
    

如果您在 GKE 上运行工作负载,我们建议您创建一个用于安装所需依赖项的 Dockerfile。如需查看示例,请参阅 GKE 文档中的在 TPU 切片节点上运行工作负载

在单个设备上运行 PyTorch/XLA 工作负载

以下示例演示了如何在单个设备(即 TPU 芯片)上创建 XLA 张量。这与 PyTorch 处理其他设备类型的方式类似。

  1. 将以下代码段保存到文件(例如 workload.py)中:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    import torch_xla 导入语句会初始化 PyTorch/XLA,xm.xla_device() 函数会返回当前的 XLA 设备(即 TPU 芯片)。

  2. PJRT_DEVICE 环境变量设置为 TPU:

    export PJRT_DEVICE=TPU
    
  3. 运行脚本:

    python workload.py
    

    输出类似于以下内容。确保输出指示找到了 XLA 设备。

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

在多台设备上运行 PyTorch/XLA

  1. 更新上一部分中的代码段,以便在多部设备上运行:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. 运行脚本:

    python workload.py
    

    如果您在 TPU v5p-8 上运行代码段,输出将类似于以下内容:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() 接受两个参数:一个函数和一个参数列表。它会为每个可用的 XLA 设备创建一个进程,并调用参数中指定的函数。在此示例中,有 4 个可用的 TPU 设备,因此 torch_xla.launch() 会创建 4 个进程,并在每个设备上调用 _mp_fn()。每个进程只能访问一个设备,因此每个设备的编号为 0,并且系统会为所有进程输出 xla:0

使用 Ray 在多台主机上运行 PyTorch/XLA

以下部分介绍了如何在更大的多主机 TPU slice 上运行相同的代码段。如需详细了解多主机 TPU 架构,请参阅系统架构

在此示例中,您将手动设置 Ray。如果您已熟悉 Ray 的设置,可以跳到最后一部分,即运行 Ray 工作负载。如需详细了解如何为生产环境设置 Ray,请参阅以下资源:

创建多主机 TPU 虚拟机

  1. 为 TPU 创建参数创建环境变量:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. 使用以下命令创建一个包含 2 个主机的多主机 TPU v5p(v5p-16,每个主机上有 4 个 TPU 芯片):

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

设置 Ray

TPU v5p-16 有 2 个 TPU 主机,每个主机有 4 个 TPU 芯片。在此示例中,您将在一个主机上启动 Ray 主节点,并将第二个主机添加为 Ray 集群的工作器节点。

  1. 使用 SSH 连接到第一个主机:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. 使用与“安装要求”部分中相同的要求文件安装依赖项:

    pip install -r requirements.txt
    
  3. 启动 Ray 进程:

    ray start --head --port=6379
    

    输出类似于以下内容:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    此 TPU 主机现在是 Ray 主节点。记下显示如何将其他节点添加到 Ray 集群的行,类似于以下内容:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    您将在后面的步骤中使用此命令。

  4. 检查 Ray 集群状态:

    ray status
    

    输出类似于以下内容:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    该集群仅包含 4 个 TPU (0.0/4.0 TPU),因为您目前只添加了主节点。

现在,主节点已在运行,您可以将第二个主机添加到集群中。

  1. 使用 SSH 连接到第二个主机:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. 使用与“安装要求”部分中相同的要求文件安装依赖项:

    pip install -r requirements.txt
    
  3. 启动 Ray 进程。使用 ray start 命令的输出中的命令将此节点添加到现有 Ray 集群。请务必替换以下命令中的 IP 地址和端口:

    ray start --address='10.130.0.76:6379'

    输出类似于以下内容:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. 再次检查 Ray 状态:

    ray status
    

    输出类似于以下内容:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    第二个 TPU 主机现在是集群中的节点。可用资源列表现在显示 8 个 TPU (0.0/8.0 TPU)。

运行 Ray 工作负载

  1. 更新要在 Ray 集群上运行的代码段:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. 在 Ray 主节点上运行脚本。将 ray-workload.py 替换为脚本的路径。

    python ray-workload.py

    输出类似于以下内容:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    输出结果表明,系统已成功在多主机 TPU 切片中的每个 XLA 设备(在此示例中为 8 个设备)上调用该函数。

以主机为中心的模式 (JAX)

以下部分介绍了使用 JAX 的以主机为中心的模式。 JAX 采用函数式编程范式,支持更高级别的单程序多数据 (SPMD) 语义。JAX 代码旨在跨单个主机上的多个设备并发运行,而不是让每个进程与单个 XLA 设备互动。

JAX 专为高性能计算而设计,可高效利用 TPU 进行大规模训练和推理。如果您熟悉函数式编程概念,则非常适合使用此模式,以便充分利用 JAX 的潜力。

以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照启动Google Cloud 适用于 KubeRay 的 GKE 集群(含 TPU)中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用

在单主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在包含单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您使用的是多主机 TPU,由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:

  • 您可以将 TPU 指定为自定义资源,而不是设置 num_gpus,并设置 TPU 芯片的数量。
  • 您可以使用每个 Ray 工作器节点的芯片数来指定 TPU。例如,如果您使用的是 v6e-4,则将 TPU 设置为 4 并运行远程函数会占用整个 TPU 主机。
    • 这与 GPU 的常规运行方式不同,后者是每个主机一个进程。不建议将 TPU 设置为非 4 的数字。
    • 例外情况:如果您使用的是单主机 v6e-8v5litepod-8,则应将此值设置为 8。

在多主机 TPU 上运行 JAX 工作负载

以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:

  • 与 GPU 上的 PyTorch 工作负载类似:
  • 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。

运行多切片 JAX 工作负载

借助 Multislice,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以在数据中心网络上的多个 Pod 中运行工作负载。

您可以使用 ray-tpu 软件包来简化 Ray 与 TPU slice 的交互。使用 pip 安装 ray-tpu

pip install ray-tpu

以下示例脚本展示了如何使用 ray-tpu 软件包使用 Ray 演员或任务运行多片工作负载:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

使用 Ray 和 MaxText 编排工作负载

本部分介绍如何使用 Ray 使用 MaxText 编排工作负载。MaxText 是一个可扩缩且高性能的开源库,可用于使用 JAX 和 XLA 训练 LLM。

MaxText 包含一个训练脚本 train.py,需要在每个 TPU 主机上运行。这与其他 SPMD 机器学习工作负载类似。您可以使用 ray-tpu 软件包并在 train.py 主函数周围创建封装容器来实现此目的。以下步骤展示了如何使用 ray-tpu 软件包在 TPU v4-16 上运行 MaxText。

  1. 为 TPU 创建参数设置环境变量:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. 创建 TPU v6e-16:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. 在所有 TPU 工作器上克隆 MaxText 代码库:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. 在所有 TPU 工作器上安装 MaxText 要求:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. 在所有 TPU 工作器上安装 ray-tpu 软件包:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. 使用 SSH 连接到 worker 0:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. 将以下脚本保存到 ~/maxtext/MaxText 目录中的名为 ray_trainer.py 的文件中。此脚本使用 ray-tpu 软件包,并在 MaxText 的 train.py 主函数周围创建封装容器。

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. 运行以下命令来执行脚本:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    输出类似于以下内容:

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

TPU 和 Ray 资源

Ray 对 TPU 的处理方式与 GPU 不同,以适应使用方面的差异。在以下示例中,共有 9 个 Ray 节点:

  • Ray 主节点在 n1-standard-16 虚拟机上运行。
  • Ray 工作器节点在两个 v6e-16 TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

资源使用情况字段说明:

  • CPU:集群中可用的 CPU 总数。
  • TPU:集群中的 TPU 芯片数量。
  • TPU-v6e-16-head:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU slice 至关重要。
  • memory:应用使用的 Worker 堆内存。
  • object_store_memory:应用使用 ray.put 在对象存储区中创建对象以及从远程函数返回值时所用的内存。
  • tpu-group-0tpu-group-1:各个 TPU slice 的唯一标识符。这对于在 slice 上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。