使用 Ray 擴充機器學習工作負載
本文件將詳細說明如何在 TPU 上使用 Ray 和 JAX 執行機器學習 (ML) 工作負載。您可以透過兩種模式,在 Ray 中使用 TPU:以裝置為中心的模式 (PyTorch/XLA) 和以主機為中心的模式 (JAX)。
本文假設您已設定 TPU 環境。詳情請參閱下列資源:
- Cloud TPU:設定 Cloud TPU 環境和管理 TPU 資源
- Google Kubernetes Engine (GKE):在 GKE Autopilot 中部署 TPU 工作負載,或在 GKE Standard 中部署 TPU 工作負載
以裝置為中心的模式 (PyTorch/XLA)
以裝置為中心的模式保留了許多傳統 PyTorch 的程式設計風格。在這個模式下,您可以新增 XLA 裝置類型,其運作方式與任何其他 PyTorch 裝置相同。每個個別程序都會與一個 XLA 裝置互動。
如果您已熟悉 PyTorch 與 GPU,且想使用類似的程式碼抽象,這個模式就非常適合您。
以下各節將說明如何在不使用 Ray 的情況下,在一個或多個裝置上執行 PyTorch/XLA 工作負載,然後說明如何使用 Ray 在多個主機上執行相同的工作負載。
建立 TPU
建立 TPU 建立參數的環境變數。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export RUNTIME_VERSION=v2-alpha-tpuv5
使用下列指令建立 8 核心的 v5p TPU VM:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
使用下列指令連線至 TPU VM:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
如果您使用的是 GKE,請參閱 KubeRay on GKE 指南,瞭解設定資訊。
安裝需求
在 TPU VM 上執行下列指令,安裝必要的依附元件:
將以下內容儲存為檔案。例如
requirements.txt
。--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
如要安裝必要的依附元件,請執行以下指令:
pip install -r requirements.txt
如果您在 GKE 上執行工作負載,建議您建立 Dockerfile 來安裝必要的依附元件。如需範例,請參閱 GKE 說明文件中的「在 TPU 區塊節點上執行工作負載」。
在單一裝置上執行 PyTorch/XLA 工作負載
以下範例說明如何在單一裝置 (即 TPU 晶片) 上建立 XLA 張量。這與 PyTorch 處理其他裝置類型的做法類似。
將下列程式碼片段儲存到檔案中。例如
workload.py
。import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
import torch_xla
匯入陳述式會初始化 PyTorch/XLA,而xm.xla_device()
函式會傳回目前的 XLA 裝置 (即 TPU 晶片)。將
PJRT_DEVICE
環境變數設為 TPU。export PJRT_DEVICE=TPU
執行指令碼。
python workload.py
輸出結果看起來與下列內容相似。請確認輸出內容顯示已找到 XLA 裝置。
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
在多部裝置上執行 PyTorch/XLA
更新上一個部分的程式碼片段,讓程式碼可在多部裝置上執行。
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
執行指令碼。
python workload.py
如果您在 TPU v5p-8 上執行程式碼片段,輸出結果會類似以下內容:
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
會使用兩個引數:函式和參數清單。它會為每個可用的 XLA 裝置建立程序,並呼叫在引數中指定的函式。在這個範例中,有 4 個可用的 TPU 裝置,因此 torch_xla.launch()
會建立 4 個程序,並在每個裝置上呼叫 _mp_fn()
。每個程序只能存取一個裝置,因此每個裝置都有索引 0,且所有程序都會列印 xla:0
。
使用 Ray 在多個主機上執行 PyTorch/XLA
以下各節將說明如何在較大的多主機 TPU 切片上執行相同的程式碼片段。如要進一步瞭解多主機 TPU 架構,請參閱「系統架構」。
在本例中,您會手動設定 Ray。如果您已熟悉 Ray 的設定方式,可以直接跳到最後一節「執行 Ray 工作負載」。如要進一步瞭解如何為正式環境設定 Ray,請參閱以下資源:
建立多主機 TPU VM
建立 TPU 建立參數的環境變數。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=v2-alpha-tpuv5
使用下列指令建立多主機 TPU v5p,其中包含 2 個主機 (v5p-16,每個主機有 4 個 TPU 晶片):
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
設定 Ray
TPU v5p-16 有 2 個 TPU 主機,每個主機有 4 個 TPU 晶片。在這個範例中,您將在一個主機上啟動 Ray 主節點,並將第二個主機新增為 Ray 叢集中的工作節點。
使用 SSH 連線至第一個主機。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
使用與安裝需求部分相同的依附元件需求檔案安裝依附元件。
pip install -r requirements.txt
啟動 Ray 程序。
ray start --head --port=6379
輸出看起來類似以下內容:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
這個 TPU 主機現在是 Ray 主節點。請記下說明如何在 Ray 叢集中新增其他節點的行,如下所示:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
您將在後續步驟中使用這個指令。
檢查 Ray 叢集狀態:
ray status
輸出看起來類似以下內容:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
叢集只包含 4 個 TPU (
0.0/4.0 TPU
),因為您目前只新增了主節點。主節點已開始運作,您可以將第二個主機新增至叢集。
使用 SSH 連線至第二個主機。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
使用與「安裝需求」一節相同的依附元件需求檔案安裝依附元件。
pip install -r requirements.txt
啟動 Ray 程序。如要將這個節點新增至現有的 Ray 叢集,請使用
ray start
指令輸出的指令。請務必在下列指令中替換 IP 位址和通訊埠:ray start --address='10.130.0.76:6379'
輸出看起來類似以下內容:
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
再次檢查 Ray 狀態:
ray status
輸出看起來類似以下內容:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
第二個 TPU 主機現在是叢集中的節點。可用資源清單現在會顯示 8 個 TPU (
0.0/8.0 TPU
)。
執行 Ray 工作負載
更新要在 Ray 叢集中執行的程式碼片段:
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host. LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster. NUM_OF_HOSTS = 4 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip. @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster. tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
在 Ray 主節點上執行指令碼。將 ray-workload.py 替換為指令碼的路徑。
python ray-workload.py
輸出看起來類似以下內容:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
輸出內容表示函式已成功在多主機 TPU 區塊中的每個 XLA 裝置 (此範例中有 8 個裝置) 上呼叫。
以主機為中心的模式 (JAX)
以下各節將說明 JAX 的以主機為中心模式。JAX 採用功能式程式設計模式,並支援高層級單一程式、多個資料 (SPMD) 語意。JAX 程式碼的設計目的是讓單一主機上的多部裝置同時運作,而非讓每個程序與單一 XLA 裝置互動。
JAX 專為高效能運算而設計,可有效運用 TPU 進行大規模訓練和推論。如果您熟悉函式程式設計概念,這個模式非常適合您,可讓您充分發揮 JAX 的潛力。
這些操作說明假設您已設定 Ray 和 TPU 環境,包括包含 JAX 和其他相關套件的軟體環境。如要建立 Ray TPU 叢集,請按照「啟動 Google Cloud 支援 KubeRay 的 TPU 叢集」一文中的操作說明進行。如要進一步瞭解如何在 KubeRay 中使用 TPU,請參閱「在 KubeRay 中使用 TPU」。
在單主機 TPU 上執行 JAX 工作負載
以下指令碼範例示範如何在 Ray 叢集上執行 JAX 函式,並使用單一主機 TPU (例如 v6e-4)。如果您有多主機 TPU,則此指令碼會因 JAX 的多控制器執行模式而停止回應。如要進一步瞭解如何在多主機 TPU 上執行 Ray,請參閱「在多主機 TPU 上執行 JAX 工作負載」。
建立 TPU 建立參數的環境變數。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e
使用下列指令建立 4 核心的 v6e TPU VM:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
使用下列指令連線至 TPU VM:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
在 TPU 上安裝 JAX 和 Ray。
pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
將下列程式碼儲存到檔案中。例如
ray-jax-single-host.py
。import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() h = my_function.remote() print(ray.get(h)) # => 4
如果您習慣使用 GPU 執行 Ray,使用 TPU 時會有幾個主要差異:
- 請勿設定
num_gpus
,而是將TPU
指定為自訂資源,並設定 TPU 晶片數量。 - 使用每個 Ray 工作站節點的晶片數量指定 TPU。舉例來說,如果您使用 v6e-4,將
TPU
設為 4 時,執行遠端函式會耗用整個 TPU 主機。 - 這與 GPU 通常執行的方式不同,後者每個主機都有一個程序。不建議將
TPU
設為非 4 的數字。- 例外狀況:如果您有單一主機
v6e-8
或v5litepod-8
,應將這個值設為 8。
- 例外狀況:如果您有單一主機
- 請勿設定
執行指令碼。
python ray-jax-single-host.py
在多主機 TPU 上執行 JAX 工作負載
以下範例指令碼示範如何在具有多主機 TPU 的 Ray 叢集中執行 JAX 函式。範例指令碼使用 v6e-16。
建立 TPU 建立參數的環境變數。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-16 export RUNTIME_VERSION=v2-alpha-tpuv6e
使用下列指令建立 16 核心的 v6e TPU VM:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
在所有 TPU 工作站上安裝 JAX 和 Ray。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
將下列程式碼儲存到檔案中。例如
ray-jax-multi-host.py
。import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() ray.init() num_tpus = ray.available_resources()["TPU"] num_hosts = int(num_tpus) # 4 h = [my_function.remote() for _ in range(num_hosts)] print(ray.get(h)) # [16, 16, 16, 16]
如果您習慣使用 GPU 執行 Ray,使用 TPU 時會有幾個主要差異:
- 與 GPU 上的 PyTorch 工作負載類似:
- TPU 上的 JAX 工作負載會以多控制器、單一程式、多資料 (SPMD) 模式執行。
- 裝置之間的集合會由機器學習架構處理。
- 與 GPU 上的 PyTorch 工作負載不同,JAX 可全面查看叢集中的可用裝置。
- 與 GPU 上的 PyTorch 工作負載類似:
將指令碼複製到所有 TPU 工作站。
gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
執行指令碼。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="python ray-jax-multi-host.py"
執行多切片 JAX 工作負載
Multislice 可讓您在單一 TPU Pod 中,或透過資料中心網路中的多個 Pod 執行跨越多個 TPU 切片的工作負載。
您可以使用 ray-tpu
套件,簡化 Ray 與 TPU 切片的互動。
使用 pip
安裝 ray-tpu
。
pip install ray-tpu
如要進一步瞭解如何使用 ray-tpu
套件,請參閱 GitHub 存放區中的「開始使用」一文。如需使用多配量的範例,請參閱「在多配量上執行」。
使用 Ray 和 MaxText 自動化調度管理工作負載
如要進一步瞭解如何搭配 MaxText 使用 Ray,請參閱「使用 MaxText 執行訓練工作」。
TPU 和 Ray 資源
Ray 會以不同於 GPU 的方式處理 TPU,以因應使用方式的差異。在以下範例中,共有九個 Ray 節點:
- Ray 主節點會在
n1-standard-16
VM 上執行。 - Ray 工作站節點會在兩個
v6e-16
TPU 上執行。每個 TPU 都由四個 worker 組成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
資源使用量欄位說明:
CPU
:叢集中可用的 CPU 總數。TPU
:叢集中的 TPU 晶片數量。TPU-v6e-16-head
:與 TPU 區塊的 worker 0 相對應的資源專屬 ID。這對於存取個別 TPU 分片至關重要。memory
:應用程式使用的 worker 堆積記憶體。object_store_memory
:應用程式使用ray.put
在物件儲存庫中建立物件,以及從遠端函式傳回值時所使用的記憶體。tpu-group-0
和tpu-group-1
:個別 TPU 切片的專屬 ID。這對於在區塊上執行工作至關重要。這些欄位會設為 4,因為 v6e-16 中的每個 TPU 切片都有 4 個主機。