Ray を使用して ML ワークロードをスケーリングする
このドキュメントでは、TPU で Ray と JAX を使用して機械学習(ML)ワークロードを実行する方法について詳しく説明します。Ray で TPU を使用するには、デバイス中心モード(PyTorch/XLA)とホスト中心モード(JAX)の 2 つのモードがあります。
このドキュメントでは、TPU 環境がすでに設定されていることを前提としています。詳しくは、次のリソースをご覧ください。
- Cloud TPU: Cloud TPU 環境を設定すると TPU リソースを管理する
- Google Kubernetes Engine(GKE): GKE Autopilot に TPU ワークロードをデプロイするまたは GKE Standard に TPU ワークロードをデプロイする
デバイス中心モード(PyTorch/XLA)
デバイス中心モードでは、従来の PyTorch のプログラマティック スタイルの多くが保持されます。このモードでは、他の PyTorch デバイスと同様に動作する新しい XLA デバイスタイプを追加します。各プロセスは 1 つの XLA デバイスとやり取りします。
このモードは、GPU を使用した PyTorch にすでに精通していて、同様のコード抽象化を使用する場合に最適です。
以降のセクションでは、Ray を使用せずに 1 つ以上のデバイスで PyTorch/XLA ワークロードを実行する方法と、Ray を使用して複数のホストで同じワークロードを実行する方法について説明します。
TPU を作成する
TPU 作成パラメータの環境変数を作成します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export RUNTIME_VERSION=v2-alpha-tpuv5
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 次のコマンドを使用して、8 コアの v5p TPU VM を作成します。
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
次のコマンドを使用して TPU VM に接続します。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
GKE を使用している場合は、GKE 上の KubeRay ガイドで設定情報をご覧ください。
インストール要件
TPU VM で次のコマンドを実行して、必要な依存関係をインストールします。
以下をファイルに保存します。例:
requirements.txt
--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
必要な依存関係をインストールするには、以下を実行します。
pip install -r requirements.txt
GKE でワークロードを実行する場合は、必要な依存関係をインストールする Dockerfile を作成することをおすすめします。例については、GKE のドキュメントの TPU スライスノードでワークロードを実行するをご覧ください。
単一のデバイスで PyTorch/XLA ワークロードを実行する
次の例は、単一のデバイス(TPU チップ)で XLA テンソルを作成する方法を示しています。これは、PyTorch が他のデバイスタイプを処理する方法と似ています。
次のコード スニペットをファイルに保存します。例:
workload.py
import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
import torch_xla
import ステートメントは PyTorch/XLA を初期化し、xm.xla_device()
関数は現在の XLA デバイス(TPU チップ)を返します。PJRT_DEVICE
環境変数を TPU に設定します。export PJRT_DEVICE=TPU
スクリプトを実行します。
python workload.py
出力は次のようになります。出力に、XLA デバイスが検出されたことが示されていることを確認します。
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
複数のデバイスで PyTorch/XLA を実行する
前のセクションのコード スニペットを、複数のデバイスで実行されるように更新します。
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
スクリプトを実行します。
python workload.py
TPU v5p-8 でコード スニペットを実行すると、出力は次のようになります。
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
は、関数とパラメータのリストという 2 つの引数を取り、使用可能な XLA デバイスごとにプロセスを作成して、引数で指定された関数を呼び出します。この例では、使用可能な TPU デバイスが 4 つあるため、torch_xla.launch()
は 4 つのプロセスを作成し、各デバイスで _mp_fn()
を呼び出します。各プロセスは 1 つのデバイスにのみアクセスできるため、各デバイスのインデックスは 0 になり、すべてのプロセスに xla:0
が出力されます。
Ray を使用して複数のホストで PyTorch/XLA を実行する
以降のセクションでは、より大きなマルチホスト TPU スライスで同じコード スニペットを実行する方法について説明します。マルチホスト TPU アーキテクチャの詳細については、システム アーキテクチャをご覧ください。
この例では、Ray を手動で設定します。Ray の設定に慣れている場合は、最後のセクション Ray ワークロードを実行するに進んでください。本番環境用に Ray を設定する方法については、次のリソースをご覧ください。
マルチホスト TPU VM を作成する
TPU 作成パラメータの環境変数を作成します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=v2-alpha-tpuv5
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 次のコマンドを使用して、2 つのホストを持つマルチホスト TPU v5p(v5p-16、各ホストに 4 つの TPU チップ)を作成します。
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Ray を設定する
TPU v5p-16 には 2 つの TPU ホストがあり、それぞれに 4 つの TPU チップがあります。この例では、1 つのホストで Ray ヘッドノードを起動し、2 つ目のホストをワーカーノードとして Ray クラスタに追加します。
SSH を使用して最初のホストに接続します。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
インストール要件セクションと同じ要件ファイルを使用して依存関係をインストールします。
pip install -r requirements.txt
Ray プロセスを開始します。
ray start --head --port=6379
出力は次のようになります。
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
この TPU ホストが Ray ヘッドノードになりました。以下のような、Ray クラスタに別のノードを追加する方法を示している行をメモします。
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
このコマンドは後の手順で使用します。
Ray クラスタのステータスを確認します。
ray status
出力は次のようになります。
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
これまでに追加したのはヘッドノードのみであるため、クラスタには 4 TPU(
0.0/4.0 TPU
)のみが含まれています。ヘッドノードが実行されたら、2 つ目のホストをクラスタに追加できます。
SSH を使用して 2 番目のホストに接続します。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
インストール要件セクションと同じ要件ファイルを使用して依存関係をインストールします。
pip install -r requirements.txt
Ray プロセスを開始します。このノードを既存の Ray クラスタに追加するには、
ray start
コマンドの出力のコマンドを使用します。次のコマンドの IP アドレスとポートを置き換えてください。ray start --address='10.130.0.76:6379'
出力は次のようになります。
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
Ray のステータスをもう一度確認します。
ray status
出力は次のようになります。
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
2 つ目の TPU ホストがクラスタ内のノードになり、使用可能なリソースのリストの表示が 8 TPU(
0.0/8.0 TPU
)になりました。
Ray ワークロードを実行する
Ray クラスタで実行されるようにコード スニペットを更新します。
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host. LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster. NUM_OF_HOSTS = 4 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip. @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster. tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
Ray ヘッドノードでスクリプトを実行します。ray-workload.py は、スクリプトへのパスに置き換えます。
python ray-workload.py
出力は次のようになります。
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
出力は、マルチホスト TPU スライスの各 XLA デバイス(この例では 8 台のデバイス)で関数が正常に呼び出されたことを示しています。
ホスト中心モード(JAX)
以降のセクションでは、JAX を使用したホスト中心モードについて説明します。JAX は関数型プログラミングのパラダイムを使用し、高レベルの単一プログラム、複数データ(SPMD)セマンティクスをサポートしています。JAX コードは、各プロセスが単一の XLA デバイスとやり取りするのではなく、単一ホスト上の複数のデバイスで同時に動作するように設計されています。
JAX はハイ パフォーマンス コンピューティング用に設計されており、大規模なトレーニングと推論に TPU を効率的に使用できます。このモードは、関数型プログラミングの概念に精通していて、JAX の可能性を最大限に活用できる場合に最適です。
以降の手順では、JAX やその他の関連パッケージを含むソフトウェア環境など、Ray と TPU 環境がすでに設定されていることを前提としています。Ray TPU クラスタを作成するには、KubeRay 用の TPU で Google Cloud GKE クラスタを起動するの手順を行います。KubeRay での TPU の使用について詳しくは、KubeRay で TPU を使用するをご覧ください。
単一ホスト TPU で JAX ワークロードを実行する
次のサンプル スクリプトは、単一ホスト TPU(v6e-4 など)を使用して Ray クラスタで JAX 関数を実行する方法を示しています。マルチホスト TPU を使用している場合、JAX のマルチコントローラ実行モデルが原因で、このスクリプトは応答を停止します。マルチホスト TPU で Ray を実行する方法については、マルチホスト TPU で JAX ワークロードを実行するをご覧ください。
TPU 作成パラメータの環境変数を作成します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 次のコマンドを使用して、4 コアの v6e TPU VM を作成します。
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
次のコマンドを使用して TPU VM に接続します。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
TPU に JAX と Ray をインストールします。
pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
次のコードをファイルに保存します。例:
ray-jax-single-host.py
import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() h = my_function.remote() print(ray.get(h)) # => 4
GPU を使用して Ray を実行することに慣れている場合、TPU を使用する場合はいくつかの重要な違いがあります。
num_gpus
を設定する代わりに、カスタム リソースとしてTPU
を指定し、TPU チップの数を設定します。- Ray ワーカーノードあたりのチップ数を使用して TPU を指定します。たとえば、v6e-4 を使用している場合、
TPU
を 4 に設定してリモート関数を実行すると、TPU ホスト全体が使用されます。 - これは、通常の GPU の実行方法(ホストごとに 1 つのプロセス)とは異なります。
TPU
を 4 以外の数値に設定することはおすすめしません。- 例外: 単一ホストの
v6e-8
またはv5litepod-8
を使用している場合は、この値を 8 に設定する必要があります。
- 例外: 単一ホストの
スクリプトを実行します。
python ray-jax-single-host.py
マルチホスト TPU で JAX ワークロードを実行する
次のサンプル スクリプトは、マルチホスト TPU を使用して Ray クラスタで JAX 関数を実行する方法を示しています。サンプル スクリプトでは v6e-16 を使用しています。
TPU 作成パラメータの環境変数を作成します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-16 export RUNTIME_VERSION=v2-alpha-tpuv6e
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 次のコマンドを使用して、16 コアの v6e TPU VM を作成します。
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
すべての TPU ワーカーに JAX と Ray をインストールします。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
次のコードをファイルに保存します。例:
ray-jax-multi-host.py
import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() ray.init() num_tpus = ray.available_resources()["TPU"] num_hosts = int(num_tpus) # 4 h = [my_function.remote() for _ in range(num_hosts)] print(ray.get(h)) # [16, 16, 16, 16]
GPU を使用して Ray を実行することに慣れている場合、TPU を使用する場合はいくつかの重要な違いがあります。
- GPU 上の PyTorch ワークロードと同様に、次のようになります。
- TPU 上の JAX ワークロードは、マルチコントローラ、単一プログラム、複数データ(SPMD)方式で実行されます。
- デバイス間のコレクティブは、ML フレームワークによって処理されます。
- GPU 上の PyTorch ワークロードとは異なり、JAX にはクラスタで使用可能なデバイスのグローバル ビューがあります。
- GPU 上の PyTorch ワークロードと同様に、次のようになります。
スクリプトをすべての TPU ワーカーにコピーします。
gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
スクリプトを実行します。
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="python ray-jax-multi-host.py"
マルチスライス JAX ワークロードを実行する
マルチスライスを使用すると、単一の TPU Pod 内の複数の TPU スライス、またはデータセンター ネットワーク上の複数の Pod にまたがるワークロードを実行できます。
ray-tpu
パッケージを使用すると、Ray と TPU スライスのインタラクションを簡素化できます。
pip
を使用して ray-tpu
をインストールします。
pip install ray-tpu
ray-tpu
パッケージの使用方法に関する詳細については、GitHub リポジトリのスタートガイドをご覧ください。マルチスライスの使用例については、マルチスライスでの実行をご覧ください。
Ray と MaxText を使用してワークロードをオーケストレートする
Ray と MaxText の使用方法に関する詳細については、MaxText でトレーニング ジョブを実行するをご覧ください。
TPU リソースと Ray リソース
Ray は、使用方法の違いに対応するために、TPU を GPU とは異なる方法で処理します。次の例では、合計 9 個の Ray ノードがあります。
- Ray ヘッドノードは
n1-standard-16
VM で実行されています。 - Ray ワーカーノードは 2 つの
v6e-16
TPU で実行されています。各 TPU は 4 つのワーカーで構成されます。
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
リソース使用量フィールドの説明:
CPU
: クラスタで使用可能な CPU の合計数。TPU
: クラスタ内の TPU チップの数。TPU-v6e-16-head
: TPU スライスのワーカー 0 に対応するリソースの特別な識別子。これは、個々の TPU スライスにアクセスする場合に重要です。memory
: アプリケーションで使用されるワーカー ヒープメモリ。object_store_memory
: アプリケーションがray.put
を使用してオブジェクト ストアにオブジェクトを作成するとき、およびリモート関数から値を返すときに使用されるメモリ。tpu-group-0
とtpu-group-1
: 個々の TPU スライスの一意の識別子。これは、スライスでジョブを実行する場合に重要です。v6e-16 には TPU スライスごとに 4 つのホストがあるため、これらのフィールドは 4 に設定されています。