使用 Ray 擴充機器學習工作負載

本文件將詳細說明如何在 TPU 上使用 Ray 和 JAX 執行機器學習 (ML) 工作負載。您可以透過兩種模式,在 Ray 中使用 TPU:以裝置為中心的模式 (PyTorch/XLA)以主機為中心的模式 (JAX)

本文假設您已設定 TPU 環境。詳情請參閱下列資源:

以裝置為中心的模式 (PyTorch/XLA)

以裝置為中心的模式保留了許多傳統 PyTorch 的程式設計風格。在這個模式下,您可以新增 XLA 裝置類型,其運作方式與任何其他 PyTorch 裝置相同。每個個別程序都會與一個 XLA 裝置互動。

如果您已熟悉 PyTorch 與 GPU,且想使用類似的程式碼抽象,這個模式就非常適合您。

以下各節將說明如何在不使用 Ray 的情況下,在一個或多個裝置上執行 PyTorch/XLA 工作負載,然後說明如何使用 Ray 在多個主機上執行相同的工作負載。

建立 TPU

  1. 建立 TPU 建立參數的環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export RUNTIME_VERSION=v2-alpha-tpuv5

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立 8 核心的 v5p TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用下列指令連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

如果您使用的是 GKE,請參閱 KubeRay on GKE 指南,瞭解設定資訊。

安裝需求

在 TPU VM 上執行下列指令,安裝必要的依附元件:

  1. 將以下內容儲存為檔案。例如 requirements.txt

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. 如要安裝必要的依附元件,請執行以下指令:

    pip install -r requirements.txt
    

如果您在 GKE 上執行工作負載,建議您建立 Dockerfile 來安裝必要的依附元件。如需範例,請參閱 GKE 說明文件中的「在 TPU 區塊節點上執行工作負載」。

在單一裝置上執行 PyTorch/XLA 工作負載

以下範例說明如何在單一裝置 (即 TPU 晶片) 上建立 XLA 張量。這與 PyTorch 處理其他裝置類型的做法類似。

  1. 將下列程式碼片段儲存到檔案中。例如 workload.py

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    import torch_xla 匯入陳述式會初始化 PyTorch/XLA,而 xm.xla_device() 函式會傳回目前的 XLA 裝置 (即 TPU 晶片)。

  2. PJRT_DEVICE 環境變數設為 TPU。

    export PJRT_DEVICE=TPU
    
  3. 執行指令碼。

    python workload.py
    

    輸出結果看起來與下列內容相似。請確認輸出內容顯示已找到 XLA 裝置。

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

在多部裝置上執行 PyTorch/XLA

  1. 更新上一個部分的程式碼片段,讓程式碼可在多部裝置上執行。

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. 執行指令碼。

    python workload.py
    

    如果您在 TPU v5p-8 上執行程式碼片段,輸出結果會類似以下內容:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() 會使用兩個引數:函式和參數清單。它會為每個可用的 XLA 裝置建立程序,並呼叫在引數中指定的函式。在這個範例中,有 4 個可用的 TPU 裝置,因此 torch_xla.launch() 會建立 4 個程序,並在每個裝置上呼叫 _mp_fn()。每個程序只能存取一個裝置,因此每個裝置都有索引 0,且所有程序都會列印 xla:0

使用 Ray 在多個主機上執行 PyTorch/XLA

以下各節將說明如何在較大的多主機 TPU 切片上執行相同的程式碼片段。如要進一步瞭解多主機 TPU 架構,請參閱「系統架構」。

在本例中,您會手動設定 Ray。如果您已熟悉 Ray 的設定方式,可以直接跳到最後一節「執行 Ray 工作負載」。如要進一步瞭解如何為正式環境設定 Ray,請參閱以下資源:

建立多主機 TPU VM

  1. 建立 TPU 建立參數的環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=v2-alpha-tpuv5

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立多主機 TPU v5p,其中包含 2 個主機 (v5p-16,每個主機有 4 個 TPU 晶片):

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE \
       --version=$RUNTIME_VERSION

設定 Ray

TPU v5p-16 有 2 個 TPU 主機,每個主機有 4 個 TPU 晶片。在這個範例中,您將在一個主機上啟動 Ray 主節點,並將第二個主機新增為 Ray 叢集中的工作節點。

  1. 使用 SSH 連線至第一個主機。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. 使用與安裝需求部分相同的依附元件需求檔案安裝依附元件。

    pip install -r requirements.txt
    
  3. 啟動 Ray 程序。

    ray start --head --port=6379
    

    輸出看起來類似以下內容:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    這個 TPU 主機現在是 Ray 主節點。請記下說明如何在 Ray 叢集中新增其他節點的行,如下所示:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    您將在後續步驟中使用這個指令。

  4. 檢查 Ray 叢集狀態:

    ray status
    

    輸出看起來類似以下內容:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    叢集只包含 4 個 TPU (0.0/4.0 TPU),因為您目前只新增了主節點。

    主節點已開始運作,您可以將第二個主機新增至叢集。

  5. 使用 SSH 連線至第二個主機。

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. 使用與「安裝需求」一節相同的依附元件需求檔案安裝依附元件。

    pip install -r requirements.txt
    
  7. 啟動 Ray 程序。如要將這個節點新增至現有的 Ray 叢集,請使用 ray start 指令輸出的指令。請務必在下列指令中替換 IP 位址和通訊埠:

    ray start --address='10.130.0.76:6379'

    輸出看起來類似以下內容:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  8. 再次檢查 Ray 狀態:

    ray status
    

    輸出看起來類似以下內容:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    第二個 TPU 主機現在是叢集中的節點。可用資源清單現在會顯示 8 個 TPU (0.0/8.0 TPU)。

執行 Ray 工作負載

  1. 更新要在 Ray 叢集中執行的程式碼片段:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host.
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster.
    NUM_OF_HOSTS = 4
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster.
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
    ray.get(tasks)
    
    ray.shutdown()
    
  2. 在 Ray 主節點上執行指令碼。將 ray-workload.py 替換為指令碼的路徑。

    python ray-workload.py

    輸出看起來類似以下內容:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    輸出內容表示函式已成功在多主機 TPU 區塊中的每個 XLA 裝置 (此範例中有 8 個裝置) 上呼叫。

以主機為中心的模式 (JAX)

以下各節將說明 JAX 的以主機為中心模式。JAX 採用功能式程式設計模式,並支援高層級單一程式、多個資料 (SPMD) 語意。JAX 程式碼的設計目的是讓單一主機上的多部裝置同時運作,而非讓每個程序與單一 XLA 裝置互動。

JAX 專為高效能運算而設計,可有效運用 TPU 進行大規模訓練和推論。如果您熟悉函式程式設計概念,這個模式非常適合您,可讓您充分發揮 JAX 的潛力。

這些操作說明假設您已設定 Ray 和 TPU 環境,包括包含 JAX 和其他相關套件的軟體環境。如要建立 Ray TPU 叢集,請按照「啟動 Google Cloud 支援 KubeRay 的 TPU 叢集」一文中的操作說明進行。如要進一步瞭解如何在 KubeRay 中使用 TPU,請參閱「在 KubeRay 中使用 TPU」。

在單主機 TPU 上執行 JAX 工作負載

以下指令碼範例示範如何在 Ray 叢集上執行 JAX 函式,並使用單一主機 TPU (例如 v6e-4)。如果您有多主機 TPU,則此指令碼會因 JAX 的多控制器執行模式而停止回應。如要進一步瞭解如何在多主機 TPU 上執行 Ray,請參閱「在多主機 TPU 上執行 JAX 工作負載」。

  1. 建立 TPU 建立參數的環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立 4 核心的 v6e TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 使用下列指令連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. 在 TPU 上安裝 JAX 和 Ray。

    pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. 將下列程式碼儲存到檔案中。例如 ray-jax-single-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    h = my_function.remote()
    print(ray.get(h)) # => 4
    

    如果您習慣使用 GPU 執行 Ray,使用 TPU 時會有幾個主要差異:

    • 請勿設定 num_gpus,而是將 TPU 指定為自訂資源,並設定 TPU 晶片數量。
    • 使用每個 Ray 工作站節點的晶片數量指定 TPU。舉例來說,如果您使用 v6e-4,將 TPU 設為 4 時,執行遠端函式會耗用整個 TPU 主機。
    • 這與 GPU 通常執行的方式不同,後者每個主機都有一個程序。不建議將 TPU 設為非 4 的數字。
      • 例外狀況:如果您有單一主機 v6e-8v5litepod-8,應將這個值設為 8。
  6. 執行指令碼。

    python ray-jax-single-host.py

在多主機 TPU 上執行 JAX 工作負載

以下範例指令碼示範如何在具有多主機 TPU 的 Ray 叢集中執行 JAX 函式。範例指令碼使用 v6e-16。

  1. 建立 TPU 建立參數的環境變數。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-16
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用下列指令建立 16 核心的 v6e TPU VM:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. 在所有 TPU 工作站上安裝 JAX 和 Ray。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. 將下列程式碼儲存到檔案中。例如 ray-jax-multi-host.py

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    ray.init()
    num_tpus = ray.available_resources()["TPU"]
    num_hosts = int(num_tpus) # 4
    h = [my_function.remote() for _ in range(num_hosts)]
    print(ray.get(h)) # [16, 16, 16, 16]
    

    如果您習慣使用 GPU 執行 Ray,使用 TPU 時會有幾個主要差異:

    • 與 GPU 上的 PyTorch 工作負載類似:
    • 與 GPU 上的 PyTorch 工作負載不同,JAX 可全面查看叢集中的可用裝置。
  5. 將指令碼複製到所有 TPU 工作站。

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
  6. 執行指令碼。

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="python ray-jax-multi-host.py"

執行多切片 JAX 工作負載

Multislice 可讓您在單一 TPU Pod 中,或透過資料中心網路中的多個 Pod 執行跨越多個 TPU 切片的工作負載。

您可以使用 ray-tpu 套件,簡化 Ray 與 TPU 切片的互動。

使用 pip 安裝 ray-tpu

pip install ray-tpu

如要進一步瞭解如何使用 ray-tpu 套件,請參閱 GitHub 存放區中的「開始使用」一文。如需使用多配量的範例,請參閱「在多配量上執行」。

使用 Ray 和 MaxText 自動化調度管理工作負載

如要進一步瞭解如何搭配 MaxText 使用 Ray,請參閱「使用 MaxText 執行訓練工作」。

TPU 和 Ray 資源

Ray 會以不同於 GPU 的方式處理 TPU,以因應使用方式的差異。在以下範例中,共有九個 Ray 節點:

  • Ray 主節點會在 n1-standard-16 VM 上執行。
  • Ray 工作站節點會在兩個 v6e-16 TPU 上執行。每個 TPU 都由四個 worker 組成。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

資源使用量欄位說明:

  • CPU:叢集中可用的 CPU 總數。
  • TPU:叢集中的 TPU 晶片數量。
  • TPU-v6e-16-head:與 TPU 區塊的 worker 0 相對應的資源專屬 ID。這對於存取個別 TPU 分片至關重要。
  • memory:應用程式使用的 worker 堆積記憶體。
  • object_store_memory:應用程式使用 ray.put 在物件儲存庫中建立物件,以及從遠端函式傳回值時所使用的記憶體。
  • tpu-group-0tpu-group-1:個別 TPU 切片的專屬 ID。這對於在區塊上執行工作至關重要。這些欄位會設為 4,因為 v6e-16 中的每個 TPU 切片都有 4 個主機。