ML-Arbeitslasten mit Ray skalieren
In diesem Dokument finden Sie Details zum Ausführen von ML-Arbeitslasten mit Ray und JAX auf TPUs. Es gibt zwei verschiedene Modi für die Verwendung von TPUs mit Ray: geräteorientierter Modus (PyTorch/XLA) und hostorientierter Modus (JAX).
In diesem Dokument wird davon ausgegangen, dass Sie bereits eine TPU-Umgebung eingerichtet haben. Weitere Informationen finden Sie in den folgenden Ressourcen:
- Cloud TPU: Cloud TPU-Umgebung einrichten und TPU-Ressourcen verwalten
- Google Kubernetes Engine (GKE): TPU-Arbeitslasten in GKE Autopilot bereitstellen oder TPU-Arbeitslasten in GKE Standard bereitstellen
Geräteorientierter Modus (PyTorch/XLA)
Im geräteorientierten Modus bleibt der programmatische Stil der klassischen PyTorch-Version weitgehend erhalten. In diesem Modus fügen Sie einen neuen XLA-Gerätetyp hinzu, der wie jedes andere PyTorch-Gerät funktioniert. Jeder einzelne Prozess interagiert mit einem XLA-Gerät.
Dieser Modus eignet sich ideal, wenn Sie bereits mit PyTorch mit GPUs vertraut sind und ähnliche Codierungsabstraktionen verwenden möchten.
In den folgenden Abschnitten wird beschrieben, wie Sie eine PyTorch-/XLA-Arbeitslast auf einem oder mehreren Geräten ohne Ray ausführen und dann dieselbe Arbeitslast mit Ray auf mehreren Hosts ausführen.
TPU erstellen
Erstellen Sie Umgebungsvariablen für die Parameter zur TPU-Erstellung.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export RUNTIME_VERSION=v2-alpha-tpuv5
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Cloud TPU-Softwareversion. Verwenden Sie den folgenden Befehl, um eine v5p-TPU-VM mit 8 Kernen zu erstellen:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Stellen Sie mit dem folgenden Befehl eine Verbindung zur TPU-VM her:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
Wenn Sie GKE verwenden, finden Sie Informationen zur Einrichtung im Leitfaden KubeRay in GKE.
Installationsanforderungen
Führen Sie die folgenden Befehle auf Ihrer TPU-VM aus, um die erforderlichen Abhängigkeiten zu installieren:
Speichern Sie Folgendes in einer Datei. Beispiel:
requirements.txt
.--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
Führen Sie den folgenden Befehl aus, um die erforderlichen Abhängigkeiten zu installieren:
pip install -r requirements.txt
Wenn Sie Ihre Arbeitslast in GKE ausführen, empfehlen wir, ein Dockerfile zu erstellen, mit dem die erforderlichen Abhängigkeiten installiert werden. Ein Beispiel finden Sie in der GKE-Dokumentation unter Arbeitslast auf TPU-Slice-Knoten ausführen.
PyTorch/XLA-Arbeitslast auf einem einzelnen Gerät ausführen
Im folgenden Beispiel wird gezeigt, wie Sie einen XLA-Tensor auf einem einzelnen Gerät erstellen, also auf einem TPU-Chip. Das entspricht der Vorgehensweise von PyTorch bei anderen Gerätetypen.
Speichern Sie das folgende Code-Snippet in einer Datei. Beispiel:
workload.py
.import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
Die
import torch_xla
-Importanweisung initialisiert PyTorch/XLA und diexm.xla_device()
-Funktion gibt das aktuelle XLA-Gerät zurück, einen TPU-Chip.Legen Sie die Umgebungsvariable
PJRT_DEVICE
auf „TPU“ fest.export PJRT_DEVICE=TPU
Führen Sie das Skript aus.
python workload.py
Die Ausgabe sieht ungefähr so aus: Achte darauf, dass in der Ausgabe angezeigt wird, dass das XLA-Gerät gefunden wurde.
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
PyTorch/XLA auf mehreren Geräten ausführen
Aktualisieren Sie das Code-Snippet aus dem vorherigen Abschnitt, damit es auf mehreren Geräten ausgeführt werden kann.
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
Führen Sie das Skript aus.
python workload.py
Wenn Sie das Code-Snippet auf einer TPU v5p-8 ausführen, sieht die Ausgabe in etwa so aus:
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
nimmt zwei Argumente an: eine Funktion und eine Liste von Parametern. Es wird ein Prozess für jedes verfügbare XLA-Gerät erstellt und die in den Argumenten angegebene Funktion aufgerufen. In diesem Beispiel sind 4 TPU-Geräte verfügbar. torch_xla.launch()
erstellt also 4 Prozesse und ruft _mp_fn()
auf jedem Gerät auf. Jeder Prozess hat nur Zugriff auf ein Gerät. Daher hat jedes Gerät den Index 0 und xla:0
wird für alle Prozesse ausgegeben.
PyTorch/XLA mit Ray auf mehreren Hosts ausführen
In den folgenden Abschnitten wird gezeigt, wie Sie dasselbe Code-Snippet auf einem größeren TPU-Speicherbereich mit mehreren Hosts ausführen. Weitere Informationen zur TPU-Architektur mit mehreren Hosts finden Sie unter Systemarchitektur.
In diesem Beispiel richten Sie Ray manuell ein. Wenn Sie mit der Einrichtung von Ray bereits vertraut sind, können Sie mit dem letzten Abschnitt Ray-Arbeitslast ausführen fortfahren. Weitere Informationen zum Einrichten von Ray für eine Produktionsumgebung finden Sie in den folgenden Ressourcen:
TPU-VM mit mehreren Hosts erstellen
Erstellen Sie Umgebungsvariablen für die Parameter zur TPU-Erstellung.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=v2-alpha-tpuv5
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Cloud TPU-Softwareversion. Erstellen Sie mit dem folgenden Befehl eine TPU v5p mit mehreren Hosts und zwei Hosts (v5p-16 mit jeweils 4 TPU-Chips auf jedem Host):
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Ray einrichten
Ein TPU v5p-16 hat 2 TPU-Hosts mit jeweils 4 TPU-Chips. In diesem Beispiel starten Sie den Ray-Leitknoten auf einem Host und fügen den zweiten Host als Workerknoten dem Ray-Cluster hinzu.
Stellen Sie über SSH eine Verbindung zum ersten Host her.
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
Installieren Sie die Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Anforderungen an die Installation.
pip install -r requirements.txt
Starten Sie den Ray-Prozess.
ray start --head --port=6379
Die Ausgabe sieht dann ungefähr so aus:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
Dieser TPU-Host ist jetzt der Ray-Leitknoten. Notieren Sie sich die Zeilen, die zeigen, wie Sie dem Ray-Cluster einen weiteren Knoten hinzufügen, ähnlich wie in der folgenden Abbildung:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
Sie verwenden diesen Befehl in einem späteren Schritt.
Prüfen Sie den Status des Ray-Clusters:
ray status
Die Ausgabe sieht dann ungefähr so aus:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
Der Cluster enthält nur vier TPUs (
0.0/4.0 TPU
), da Sie bisher nur den Kopfknoten hinzugefügt haben.Nachdem der Head-Knoten ausgeführt wird, können Sie dem Cluster den zweiten Host hinzufügen.
Stellen Sie eine SSH-Verbindung zum zweiten Host her.
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
Installieren Sie die Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Installationsanforderungen.
pip install -r requirements.txt
Starten Sie den Ray-Prozess. Wenn Sie diesen Knoten dem vorhandenen Ray-Cluster hinzufügen möchten, verwenden Sie den Befehl aus der Ausgabe des Befehls
ray start
. Ersetzen Sie im folgenden Befehl die IP-Adresse und den Port:ray start --address='10.130.0.76:6379'
Die Ausgabe sieht dann ungefähr so aus:
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
Prüfen Sie noch einmal den Ray-Status:
ray status
Die Ausgabe sieht dann ungefähr so aus:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
Der zweite TPU-Host ist jetzt ein Knoten im Cluster. Die Liste der verfügbaren Ressourcen enthält jetzt 8 TPUs (
0.0/8.0 TPU
).
Ray-Arbeitslast ausführen
Aktualisieren Sie das Code-Snippet, damit es im Ray-Cluster ausgeführt wird:
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host. LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster. NUM_OF_HOSTS = 4 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip. @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster. tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
Führen Sie das Script auf dem Ray-Leitknoten aus. Ersetzen Sie ray-workload.py durch den Pfad zu Ihrem Script.
python ray-workload.py
Die Ausgabe sieht dann ungefähr so aus:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
Die Ausgabe zeigt an, dass die Funktion auf jedem XLA-Gerät (in diesem Beispiel 8 Geräte) im TPU-Slice mit mehreren Hosts erfolgreich aufgerufen wurde.
Host-zentrierter Modus (JAX)
In den folgenden Abschnitten wird der hostzentrierte Modus mit JAX beschrieben. JAX verwendet ein Paradigma der funktionalen Programmierung und unterstützt die höhere Ebene der SPMD-Semantik (Single Program, Multiple Data). Anstatt dass jeder Prozess mit einem einzelnen XLA-Gerät interagiert, ist JAX-Code so konzipiert, dass er gleichzeitig auf mehreren Geräten auf einem einzigen Host ausgeführt werden kann.
JAX wurde für Hochleistungs-Computing entwickelt und kann TPUs effizient für die groß angelegte Modellerstellung und Inferenz nutzen. Dieser Modus eignet sich ideal, wenn Sie mit den Konzepten der funktionalen Programmierung vertraut sind, damit Sie das volle Potenzial von JAX nutzen können.
In dieser Anleitung wird davon ausgegangen, dass Sie bereits eine Ray- und TPU-Umgebung eingerichtet haben, einschließlich einer Softwareumgebung mit JAX und anderen zugehörigen Paketen. Folgen Sie der Anleitung unter GKE-Cluster mit TPUs für KubeRay starten Google Cloud , um einen Ray-TPU-Cluster zu erstellen. Weitere Informationen zur Verwendung von TPUs mit KubeRay finden Sie unter TPUs mit KubeRay verwenden.
JAX-Arbeitslast auf einer TPU mit einem einzelnen Host ausführen
Das folgende Beispielskript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit einem einzelnen Host ausgeführt wird, z. B. einer v6e-4. Wenn Sie eine TPU mit mehreren Hosts haben, reagiert dieses Script aufgrund des Multi-Controller-Ausführungsmodells von JAX nicht mehr. Weitere Informationen zum Ausführen von Ray auf einer TPU mit mehreren Hosts finden Sie unter JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen.
Erstellen Sie Umgebungsvariablen für die Parameter zur TPU-Erstellung.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Cloud TPU-Softwareversion. Verwenden Sie den folgenden Befehl, um eine v6e-TPU-VM mit 4 Kernen zu erstellen:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Stellen Sie mit dem folgenden Befehl eine Verbindung zur TPU-VM her:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
Installieren Sie JAX und Ray auf Ihrer TPU.
pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Speichern Sie den folgenden Code in einer Datei. Beispiel:
ray-jax-single-host.py
.import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() h = my_function.remote() print(ray.get(h)) # => 4
Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:
- Anstatt
num_gpus
festzulegen, geben SieTPU
als benutzerdefinierte Ressource an und legen Sie die Anzahl der TPU-Chips fest. - Geben Sie die TPU mit der Anzahl der Chips pro Ray-Worker-Knoten an. Wenn Sie beispielsweise v6e-4 verwenden, wird beim Ausführen einer Remote-Funktion mit
TPU
= 4 der gesamte TPU-Host belegt. - Das unterscheidet sich von der üblichen Ausführung von GPUs mit einem Prozess pro Host.
Es wird nicht empfohlen,
TPU
auf eine andere Zahl als 4 festzulegen.- Ausnahme: Wenn Sie einen
v6e-8
- oderv5litepod-8
-Wert für einen einzelnen Host haben, sollten Sie diesen Wert auf 8 festlegen.
- Ausnahme: Wenn Sie einen
- Anstatt
Führen Sie das Skript aus.
python ray-jax-single-host.py
JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen
Das folgende Beispielscript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit mehreren Hosts ausgeführt wird. Im Beispielscript wird eine v6e-16 verwendet.
Erstellen Sie Umgebungsvariablen für die Parameter zur TPU-Erstellung.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-a export ACCELERATOR_TYPE=v6e-16 export RUNTIME_VERSION=v2-alpha-tpuv6e
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Cloud TPU-Softwareversion. Verwenden Sie den folgenden Befehl, um eine v6e-TPU-VM mit 16 Kernen zu erstellen:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Installieren Sie JAX und Ray auf allen TPU-Workern.
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Speichern Sie den folgenden Code in einer Datei. Beispiel:
ray-jax-multi-host.py
.import ray import jax @ray.remote(resources={"TPU": 4}) def my_function() -> int: return jax.device_count() ray.init() num_tpus = ray.available_resources()["TPU"] num_hosts = int(num_tpus) # 4 h = [my_function.remote() for _ in range(num_hosts)] print(ray.get(h)) # [16, 16, 16, 16]
Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:
- Ähnlich wie bei PyTorch-Arbeitslasten auf GPUs:
- JAX-Arbeitslasten auf TPUs werden im SPMD-Verfahren (Single Program Multiple Data) ausgeführt.
- Kollektive zwischen Geräten werden vom Framework für maschinelles Lernen verarbeitet.
- Im Gegensatz zu PyTorch-Arbeitslasten auf GPUs hat JAX eine globale Ansicht der verfügbaren Geräte im Cluster.
- Ähnlich wie bei PyTorch-Arbeitslasten auf GPUs:
Kopieren Sie das Script auf alle TPU-Worker.
gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
Führen Sie das Skript aus.
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="python ray-jax-multi-host.py"
Multislice-JAX-Arbeitslast ausführen
Mit Multislice können Sie Arbeitslasten ausführen, die mehrere TPU-Slices innerhalb eines einzelnen TPU-Pod oder in mehreren Pods über das Rechenzentrumsnetzwerk hinweg umfassen.
Mit dem Paket ray-tpu
können Sie die Interaktionen von Ray mit TPU-Scheiben vereinfachen.
Installieren Sie ray-tpu
mit pip
.
pip install ray-tpu
Weitere Informationen zur Verwendung des ray-tpu
-Pakets finden Sie im GitHub-Repository unter Einstieg. Ein Beispiel für die Verwendung von Multislice finden Sie unter Auf Multislice ausführen.
Arbeitslasten mit Ray und MaxText orchestrieren
Weitere Informationen zur Verwendung von Ray mit MaxText finden Sie unter Trainingsjob mit MaxText ausführen.
TPU- und Ray-Ressourcen
Ray behandelt TPUs anders als GPUs, um den Unterschied bei der Nutzung zu berücksichtigen. Im folgenden Beispiel gibt es insgesamt neun Ray-Knoten:
- Der Ray-Leitknoten wird auf einer
n1-standard-16
-VM ausgeführt. - Die Ray-Worker-Knoten werden auf zwei
v6e-16
TPUs ausgeführt. Jede TPU besteht aus vier Workern.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Beschreibungen der Felder zur Ressourcennutzung:
CPU
: Die Gesamtzahl der im Cluster verfügbaren CPUs.TPU
: Die Anzahl der TPU-Chips im Cluster.TPU-v6e-16-head
: Eine spezielle Kennung für die Ressource, die Worker 0 eines TPU-Slabs entspricht. Das ist wichtig für den Zugriff auf einzelne TPU-Scheiben.memory
: Der von Ihrer Anwendung verwendete Worker-Heap-Speicher.object_store_memory
: Speicher, der verwendet wird, wenn Ihre Anwendung Objekte im Objektspeicher mitray.put
erstellt und Werte von Remotefunktionen zurückgibt.tpu-group-0
undtpu-group-1
: Eindeutige Kennungen für die einzelnen TPU-Scheiben. Das ist wichtig, wenn Sie Jobs auf Segmenten ausführen möchten. Diese Felder sind auf „4“ festgelegt, da es in einem v6e-16 vier Hosts pro TPU-Speichereinheit gibt.