使用 Ray 扩缩机器学习工作负载
本文档详细介绍了如何在 TPU 上使用 Ray 和 JAX 运行机器学习 (ML) 工作负载。您可以通过两种不同的模式将 TPU 与 Ray 搭配使用:以设备为中心的模式 (PyTorch/XLA) 和以主机为中心的模式 (JAX)。
本文档假定您已设置 TPU 环境。如需了解详情,请参阅以下资源:
- Cloud TPU:设置 Cloud TPU 环境和管理 TPU 资源
- Google Kubernetes Engine (GKE):在 GKE Autopilot 中部署 TPU 工作负载或在 GKE Standard 中部署 TPU 工作负载
以设备为中心的模式 (PyTorch/XLA)
以设备为中心的模式保留了传统 PyTorch 的大部分编程风格。在此模式下,您可以添加新的 XLA 设备类型,该类型的运作方式与任何其他 PyTorch 设备一样。每个单独的进程都与一个 XLA 设备互动。
如果您已经熟悉 PyTorch with GPUs,并希望使用类似的编码抽象,此模式非常适合。
以下部分介绍了如何在不使用 Ray 的情况下在一个或多个设备上运行 PyTorch/XLA 工作负载,以及如何使用 Ray 在多个主机上运行同一工作负载。
创建 TPU
为 TPU 创建参数创建环境变量:
export TPU_NAME=TPU_NAME export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export VERSION=v2-alpha-tpuv5
使用以下命令创建具有 8 个核心的 v5p TPU VM:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$VERSION
使用以下命令连接到 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
如果您使用的是 GKE,请参阅 KubeRay on GKE 指南,了解设置信息。
安装要求
在 TPU 虚拟机上运行以下命令以安装所需的依赖项:
将以下内容保存到文件(例如
requirements.txt
)中:--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
运行以下命令以安装所需的依赖项:
pip install -r requirements.txt
如果您在 GKE 上运行工作负载,我们建议您创建一个用于安装所需依赖项的 Dockerfile。如需查看示例,请参阅 GKE 文档中的在 TPU 切片节点上运行工作负载。
在单个设备上运行 PyTorch/XLA 工作负载
以下示例演示了如何在单个设备(即 TPU 芯片)上创建 XLA 张量。这与 PyTorch 处理其他设备类型的方式类似。
将以下代码段保存到文件(例如
workload.py
)中:import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
import torch_xla
导入语句会初始化 PyTorch/XLA,xm.xla_device()
函数会返回当前的 XLA 设备(即 TPU 芯片)。将
PJRT_DEVICE
环境变量设置为 TPU:export PJRT_DEVICE=TPU
运行脚本:
python workload.py
输出类似于以下内容。确保输出指示找到了 XLA 设备。
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
在多台设备上运行 PyTorch/XLA
更新上一部分中的代码段,以便在多部设备上运行:
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
运行脚本:
python workload.py
如果您在 TPU v5p-8 上运行代码段,输出将类似于以下内容:
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
接受两个参数:一个函数和一个参数列表。它会为每个可用的 XLA 设备创建一个进程,并调用参数中指定的函数。在此示例中,有 4 个可用的 TPU 设备,因此 torch_xla.launch()
会创建 4 个进程,并在每个设备上调用 _mp_fn()
。每个进程只能访问一个设备,因此每个设备的编号为 0,并且系统会为所有进程输出 xla:0
。
使用 Ray 在多台主机上运行 PyTorch/XLA
以下部分介绍了如何在更大的多主机 TPU slice 上运行相同的代码段。如需详细了解多主机 TPU 架构,请参阅系统架构。
在此示例中,您将手动设置 Ray。如果您已熟悉 Ray 的设置,可以跳到最后一部分,即运行 Ray 工作负载。如需详细了解如何为生产环境设置 Ray,请参阅以下资源:
创建多主机 TPU 虚拟机
为 TPU 创建参数创建环境变量:
export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST export ZONE=europe-west4-b export ACCELERATOR_TYPE_MULTIHOST=v5p-16 export VERSION=v2-alpha-tpuv5
使用以下命令创建一个包含 2 个主机的多主机 TPU v5p(v5p-16,每个主机上有 4 个 TPU 芯片):
gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \ --version=$VERSION
设置 Ray
TPU v5p-16 有 2 个 TPU 主机,每个主机有 4 个 TPU 芯片。在此示例中,您将在一个主机上启动 Ray 主节点,并将第二个主机添加为 Ray 集群的工作器节点。
使用 SSH 连接到第一个主机:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
使用与“安装要求”部分中相同的要求文件安装依赖项:
pip install -r requirements.txt
启动 Ray 进程:
ray start --head --port=6379
输出类似于以下内容:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
此 TPU 主机现在是 Ray 主节点。记下显示如何将其他节点添加到 Ray 集群的行,类似于以下内容:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
您将在后面的步骤中使用此命令。
检查 Ray 集群状态:
ray status
输出类似于以下内容:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
该集群仅包含 4 个 TPU (
0.0/4.0 TPU
),因为您目前只添加了主节点。
现在,主节点已在运行,您可以将第二个主机添加到集群中。
使用 SSH 连接到第二个主机:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
使用与“安装要求”部分中相同的要求文件安装依赖项:
pip install -r requirements.txt
启动 Ray 进程。使用
ray start
命令的输出中的命令将此节点添加到现有 Ray 集群。请务必替换以下命令中的 IP 地址和端口:ray start --address='10.130.0.76:6379'
输出类似于以下内容:
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
再次检查 Ray 状态:
ray status
输出类似于以下内容:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
第二个 TPU 主机现在是集群中的节点。可用资源列表现在显示 8 个 TPU (
0.0/8.0 TPU
)。
运行 Ray 工作负载
更新要在 Ray 集群上运行的代码段:
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster NUM_OF_HOSTS = 2 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
在 Ray 主节点上运行脚本。将 ray-workload.py 替换为脚本的路径。
python ray-workload.py
输出类似于以下内容:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
输出结果表明,系统已成功在多主机 TPU 切片中的每个 XLA 设备(在此示例中为 8 个设备)上调用该函数。
以主机为中心的模式 (JAX)
以下部分介绍了使用 JAX 的以主机为中心的模式。 JAX 采用函数式编程范式,支持更高级别的单程序多数据 (SPMD) 语义。JAX 代码旨在跨单个主机上的多个设备并发运行,而不是让每个进程与单个 XLA 设备互动。
JAX 专为高性能计算而设计,可高效利用 TPU 进行大规模训练和推理。如果您熟悉函数式编程概念,则非常适合使用此模式,以便充分利用 JAX 的潜力。
以下说明假定您已设置 Ray 和 TPU 环境,包括包含 JAX 和其他相关软件包的软件环境。如需创建 Ray TPU 集群,请按照启动Google Cloud 适用于 KubeRay 的 GKE 集群(含 TPU)中的说明操作。如需详细了解如何将 TPU 与 KubeRay 搭配使用,请参阅将 TPU 与 KubeRay 搭配使用。
在单主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在包含单主机 TPU(例如 v6e-4)的 Ray 集群上运行 JAX 函数。如果您使用的是多主机 TPU,由于 JAX 的多控制器执行模型,此脚本会停止响应。如需详细了解如何在多主机 TPU 上运行 Ray,请参阅在多主机 TPU 上运行 JAX 工作负载。
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:
- 您可以将
TPU
指定为自定义资源,而不是设置num_gpus
,并设置 TPU 芯片的数量。 - 您可以使用每个 Ray 工作器节点的芯片数来指定 TPU。例如,如果您使用的是 v6e-4,则将
TPU
设置为 4 并运行远程函数会占用整个 TPU 主机。- 这与 GPU 的常规运行方式不同,后者是每个主机一个进程。不建议将
TPU
设置为非 4 的数字。 - 例外情况:如果您使用的是单主机
v6e-8
或v5litepod-8
,则应将此值设置为 8。
- 这与 GPU 的常规运行方式不同,后者是每个主机一个进程。不建议将
在多主机 TPU 上运行 JAX 工作负载
以下示例脚本演示了如何在具有多主机 TPU 的 Ray 集群上运行 JAX 函数。示例脚本使用 v6e-16。
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
如果您习惯使用 GPU 运行 Ray,那么在使用 TPU 时会发现一些关键差异:
- 与 GPU 上的 PyTorch 工作负载类似:
- TPU 上的 JAX 工作负载以多控制器、单程序、多数据 (SPMD) 方式运行。
- 设备之间的集合由机器学习框架处理。
- 与 GPU 上的 PyTorch 工作负载不同,JAX 可以全局查看集群中的可用设备。
运行多切片 JAX 工作负载
借助 Multislice,您可以在单个 TPU Pod 内跨多个 TPU 切片运行工作负载,也可以在数据中心网络上的多个 Pod 中运行工作负载。
您可以使用 ray-tpu
软件包来简化 Ray 与 TPU slice 的交互。使用 pip
安装 ray-tpu
:
pip install ray-tpu
以下示例脚本展示了如何使用 ray-tpu
软件包使用 Ray 演员或任务运行多片工作负载:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
使用 Ray 和 MaxText 编排工作负载
本部分介绍如何使用 Ray 使用 MaxText 编排工作负载。MaxText 是一个可扩缩且高性能的开源库,可用于使用 JAX 和 XLA 训练 LLM。
MaxText 包含一个训练脚本 train.py
,需要在每个 TPU 主机上运行。这与其他 SPMD 机器学习工作负载类似。您可以使用 ray-tpu
软件包并在 train.py
主函数周围创建封装容器来实现此目的。以下步骤展示了如何使用 ray-tpu
软件包在 TPU v4-16 上运行 MaxText。
为 TPU 创建参数设置环境变量:
export TPU_NAME=TPU_NAME export ZONE=ZONE export ACCELERATOR_TYPE=v6e-16 export VERSION=v2-alpha-tpuv6e
创建 TPU v6e-16:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$VERSION
在所有 TPU 工作器上克隆 MaxText 代码库:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="git clone https://github.com/AI-Hypercomputer/maxtext"
在所有 TPU 工作器上安装 MaxText 要求:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install -r maxtext/requirements.txt"
在所有 TPU 工作器上安装
ray-tpu
软件包:gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray-tpu"
使用 SSH 连接到 worker 0:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=0
将以下脚本保存到
~/maxtext/MaxText
目录中的名为ray_trainer.py
的文件中。此脚本使用ray-tpu
软件包,并在 MaxText 的train.py
主函数周围创建封装容器。import ray import ray_tpu from train import main as maxtext_main import logging from typing import Sequence from absl import app # Default env vars that run on all TPU VMs. MACHINE_ENV_VARS = { "ENABLE_PJRT_COMPATIBILITY": "true", "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true", "TPU_SLICE_BUILDER_DUMP_ICI": "true", "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto", # Dumps HLOs for debugging } def setup_loggers(): """Sets up loggers for Ray.""" logging.basicConfig(level=logging.INFO) @ray_tpu.remote( topology={"v4-16": 1}, ) def run_maxtext_train(argv: Sequence[str]): maxtext_main(argv=argv) def main(argv: Sequence[str]): ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers)) logging.info(f"argv: {argv}") try: ray.get(run_maxtext_train(argv=argv)) except Exception as e: logging.error("Caught error during training: %s", e) logging.error("Shutting down...") ray.shutdown() raise e logging.info("Training complete!") ray.shutdown() if __name__ == "__main__": logger = logging.getLogger() logger.setLevel(logging.INFO) app.run(main)
运行以下命令来执行脚本:
python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \ base_output_directory=/tmp/maxtext \ dataset_type=synthetic \ per_device_batch_size=2 \ max_target_length=8192 \ model_name=default \ steps=100 \ run_name=test
输出类似于以下内容:
(run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0 (run_maxtext_train pid=78967, ip=10.130.0.11) (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized: (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066 (run_maxtext_train pid=80538, ip=10.130.0.12) (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/' (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint... (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing. (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0 (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized: (run_maxtext_train pid=80538, ip=10.130.0.12) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] ... (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes. I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete! (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
TPU 和 Ray 资源
Ray 对 TPU 的处理方式与 GPU 不同,以适应使用方面的差异。在以下示例中,共有 9 个 Ray 节点:
- Ray 主节点在
n1-standard-16
虚拟机上运行。 - Ray 工作器节点在两个
v6e-16
TPU 上运行。每个 TPU 由四个工作器组成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
资源使用情况字段说明:
CPU
:集群中可用的 CPU 总数。TPU
:集群中的 TPU 芯片数量。TPU-v6e-16-head
:与 TPU 切片的工作器 0 对应的资源的特殊标识符。这对于访问各个 TPU slice 至关重要。memory
:应用使用的 Worker 堆内存。object_store_memory
:应用使用ray.put
在对象存储区中创建对象以及从远程函数返回值时所用的内存。tpu-group-0
和tpu-group-1
:各个 TPU slice 的唯一标识符。这对于在 slice 上运行作业非常重要。这些字段设置为 4,因为 v6e-16 中的每个 TPU 切片都有 4 个主机。