ML-Arbeitslasten mit Ray skalieren

In diesem Dokument finden Sie Details zum Ausführen von ML-Arbeitslasten mit Ray und JAX auf TPUs. Es gibt zwei verschiedene Modi für die Verwendung von TPUs mit Ray: geräteorientierter Modus (PyTorch/XLA) und hostorientierter Modus (JAX).

In diesem Dokument wird davon ausgegangen, dass Sie bereits eine TPU-Umgebung eingerichtet haben. Weitere Informationen finden Sie in den folgenden Ressourcen:

Geräteorientierter Modus (PyTorch/XLA)

Im geräteorientierten Modus bleibt der programmatische Stil der klassischen PyTorch-Version weitgehend erhalten. In diesem Modus fügen Sie einen neuen XLA-Gerätetyp hinzu, der wie jedes andere PyTorch-Gerät funktioniert. Jeder einzelne Prozess interagiert mit einem XLA-Gerät.

Dieser Modus eignet sich ideal, wenn Sie bereits mit PyTorch mit GPUs vertraut sind und ähnliche Codierungsabstraktionen verwenden möchten.

In den folgenden Abschnitten wird beschrieben, wie Sie eine PyTorch-/XLA-Arbeitslast auf einem oder mehreren Geräten ohne Ray ausführen und dann dieselbe Arbeitslast mit Ray auf mehreren Hosts ausführen.

TPU erstellen

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    Beschreibungen von Umgebungsvariablen

    TPU_NAME
    Der Name Ihrer neuen Cloud TPU.
    ZONE
    Die Zone, in der die Cloud TPU erstellt werden soll.
    accelerator-type
    Mit dem Beschleunigertyp geben Sie die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen finden Sie unter TPU-Versionen.
    version
    Die TPU-Softwareversion, die Sie verwenden möchten. Weitere Informationen finden Sie unter TPU-VM-Images.
  2. Verwenden Sie den folgenden Befehl, um eine v5p-TPU-VM mit 8 Kernen zu erstellen:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. Stellen Sie mit dem folgenden Befehl eine Verbindung zur TPU-VM her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Wenn Sie GKE verwenden, finden Sie Informationen zur Einrichtung im Leitfaden für KubeRay in GKE.

Installationsanforderungen

Führen Sie die folgenden Befehle auf Ihrer TPU-VM aus, um die erforderlichen Abhängigkeiten zu installieren:

  1. Speichern Sie Folgendes in einer Datei, z. B. requirements.txt:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Führen Sie den folgenden Befehl aus, um die erforderlichen Abhängigkeiten zu installieren:

    pip install -r requirements.txt
    

Wenn Sie Ihre Arbeitslast in GKE ausführen, empfehlen wir, ein Dockerfile zu erstellen, mit dem die erforderlichen Abhängigkeiten installiert werden. Ein Beispiel finden Sie in der GKE-Dokumentation unter Arbeitslast auf TPU-Slice-Knoten ausführen.

PyTorch/XLA-Arbeitslast auf einem einzelnen Gerät ausführen

Im folgenden Beispiel wird gezeigt, wie Sie einen XLA-Tensor auf einem einzelnen Gerät erstellen, also auf einem TPU-Chip. Das entspricht der Vorgehensweise von PyTorch bei anderen Gerätetypen.

  1. Speichern Sie das folgende Code-Snippet in einer Datei, z. B. workload.py:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    Die import torch_xla-Importanweisung initialisiert PyTorch/XLA und die xm.xla_device()-Funktion gibt das aktuelle XLA-Gerät zurück, einen TPU-Chip.

  2. Legen Sie die Umgebungsvariable PJRT_DEVICE auf TPU fest:

    export PJRT_DEVICE=TPU
    
  3. Führen Sie das Skript aus:

    python workload.py
    

    Die Ausgabe sieht ungefähr so aus: Achte darauf, dass in der Ausgabe angezeigt wird, dass das XLA-Gerät gefunden wurde.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

PyTorch/XLA auf mehreren Geräten ausführen

  1. Aktualisieren Sie das Code-Snippet aus dem vorherigen Abschnitt, damit es auf mehreren Geräten ausgeführt werden kann:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Führen Sie das Skript aus:

    python workload.py
    

    Wenn Sie das Code-Snippet auf einer TPU v5p-8 ausführen, sieht die Ausgabe in etwa so aus:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() nimmt zwei Argumente an, eine Funktion und eine Liste von Parametern. Es wird ein Prozess für jedes verfügbare XLA-Gerät erstellt und die in den Argumenten angegebene Funktion aufgerufen. In diesem Beispiel sind vier TPU-Geräte verfügbar. torch_xla.launch() erstellt also vier Prozesse und ruft _mp_fn() auf jedem Gerät auf. Da jeder Prozess nur auf ein Gerät zugreifen kann, hat jedes Gerät den Index 0 und xla:0 wird für alle Prozesse ausgegeben.

PyTorch/XLA mit Ray auf mehreren Hosts ausführen

In den folgenden Abschnitten wird gezeigt, wie Sie dasselbe Code-Snippet auf einem größeren TPU-Speicherbereich mit mehreren Hosts ausführen. Weitere Informationen zur TPU-Architektur mit mehreren Hosts finden Sie unter Systemarchitektur.

In diesem Beispiel richten Sie Ray manuell ein. Wenn Sie bereits mit der Einrichtung von Ray vertraut sind, können Sie mit dem letzten Abschnitt Ray-Arbeitslast ausführen fortfahren. Weitere Informationen zum Einrichten von Ray für eine Produktionsumgebung finden Sie unter den folgenden Links:

TPU-VM mit mehreren Hosts erstellen

  1. Erstellen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. Erstellen Sie mit dem folgenden Befehl eine TPU v5p mit mehreren Hosts und zwei Hosts (v5p-16 mit jeweils 4 TPU-Chips auf jedem Host):

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

Ray einrichten

Ein TPU v5p-16 hat 2 TPU-Hosts mit jeweils 4 TPU-Chips. In diesem Beispiel starten Sie den Ray-Leitknoten auf einem Host und fügen den zweiten Host als Workerknoten dem Ray-Cluster hinzu.

  1. Stellen Sie über SSH eine Verbindung zum ersten Host her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. Installieren Sie die Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Anforderungen an die Installation beschrieben:

    pip install -r requirements.txt
    
  3. Starten Sie den Ray-Prozess:

    ray start --head --port=6379
    

    Die Ausgabe sieht dann ungefähr so aus:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Dieser TPU-Host ist jetzt der Ray-Leitknoten. Notieren Sie sich die Zeilen, die zeigen, wie Sie dem Ray-Cluster einen weiteren Knoten hinzufügen, ähnlich wie hier:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Sie verwenden diesen Befehl in einem späteren Schritt.

  4. Prüfen Sie den Status des Ray-Clusters:

    ray status
    

    Die Ausgabe sieht dann ungefähr so aus:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Der Cluster enthält nur vier TPUs (0.0/4.0 TPU), da Sie bisher nur den Head-Knoten hinzugefügt haben.

Nachdem der Hauptknoten ausgeführt wird, können Sie dem Cluster den zweiten Host hinzufügen.

  1. Stellen Sie über SSH eine Verbindung zum zweiten Host her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. Installieren Sie die Abhängigkeiten mit derselben Anforderungsdatei wie im Abschnitt Installationsanforderungen beschrieben:

    pip install -r requirements.txt
    
  3. Starten Sie den Ray-Prozess. Verwenden Sie den Befehl aus der Ausgabe des Befehls ray start, um diesen Knoten dem vorhandenen Ray-Cluster hinzuzufügen. Ersetzen Sie die IP-Adresse und den Port im folgenden Befehl:

    ray start --address='10.130.0.76:6379'

    Die Ausgabe sieht dann ungefähr so aus:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. Prüfen Sie den Ray-Status noch einmal:

    ray status
    

    Die Ausgabe sieht dann ungefähr so aus:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Der zweite TPU-Host ist jetzt ein Knoten im Cluster. Die Liste der verfügbaren Ressourcen enthält jetzt 8 TPUs (0.0/8.0 TPU).

Ray-Arbeitslast ausführen

  1. Aktualisieren Sie das Code-Snippet, damit es auf dem Ray-Cluster ausgeführt wird:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Führen Sie das Script auf dem Ray-Leitknoten aus. Ersetzen Sie ray-workload.py durch den Pfad zu Ihrem Script.

    python ray-workload.py

    Die Ausgabe sieht in etwa so aus:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    Die Ausgabe zeigt an, dass die Funktion auf jedem XLA-Gerät (in diesem Beispiel 8 Geräte) im TPU-Slice mit mehreren Hosts erfolgreich aufgerufen wurde.

Host-zentrierter Modus (JAX)

In den folgenden Abschnitten wird der hostzentrierte Modus mit JAX beschrieben. JAX nutzt ein funktionales Programmierparadigma und unterstützt die höhere Ebene der SPMD-Semantik (Single Program, Multiple Data). Anstatt dass jeder Prozess mit einem einzelnen XLA-Gerät interagiert, ist JAX-Code so konzipiert, dass er gleichzeitig auf mehreren Geräten auf einem einzigen Host ausgeführt werden kann.

JAX wurde für Hochleistungs-Computing entwickelt und kann TPUs effizient für die groß angelegte Modellerstellung und Inferenz nutzen. Dieser Modus eignet sich ideal, wenn Sie mit den Konzepten der funktionalen Programmierung vertraut sind, damit Sie das volle Potenzial von JAX nutzen können.

In dieser Anleitung wird davon ausgegangen, dass Sie bereits eine Ray- und TPU-Umgebung eingerichtet haben, einschließlich einer Softwareumgebung mit JAX und anderen zugehörigen Paketen. Folgen Sie der Anleitung unter GKE-Cluster mit TPUs für KubeRay startenGoogle Cloud , um einen Ray-TPU-Cluster zu erstellen. Weitere Informationen zur Verwendung von TPUs mit KubeRay finden Sie unter TPUs mit KubeRay verwenden.

JAX-Arbeitslast auf einer TPU mit einem einzelnen Host ausführen

Das folgende Beispielskript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit einem einzelnen Host wie einer v6e-4 ausgeführt wird. Wenn Sie eine TPU mit mehreren Hosts haben, reagiert dieses Script aufgrund des Multi-Controller-Ausführungsmodells von JAX nicht mehr. Weitere Informationen zum Ausführen von Ray auf einer TPU mit mehreren Hosts finden Sie unter JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

  • Anstatt num_gpus festzulegen, geben Sie TPU als benutzerdefinierte Ressource an und legen die Anzahl der TPU-Chips fest.
  • Sie geben die TPU anhand der Anzahl der Chips pro Ray-Worker-Knoten an. Wenn Sie beispielsweise v6e-4 verwenden, wird beim Ausführen einer Remotefunktion mit TPU = 4 der gesamte TPU-Host belegt.
    • Das unterscheidet sich von der üblichen Ausführung von GPUs mit einem Prozess pro Host. Es wird nicht empfohlen, TPU auf eine andere Zahl als 4 festzulegen.
    • Ausnahme: Wenn Sie v6e-8 oder v5litepod-8 für einen einzelnen Host verwenden, sollten Sie diesen Wert auf 8 festlegen.

JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen

Das folgende Beispielscript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit mehreren Hosts ausgeführt wird. Im Beispielscript wird eine v6e-16 verwendet.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

  • Ähnlich wie bei PyTorch-Arbeitslasten auf GPUs:
  • Im Gegensatz zu PyTorch-Arbeitslasten auf GPUs hat JAX eine globale Ansicht der verfügbaren Geräte im Cluster.

Multislice-JAX-Arbeitslast ausführen

Mit Multislice können Sie Arbeitslasten ausführen, die mehrere TPU-Slices innerhalb eines einzelnen TPU-Pods oder in mehreren Pods über das Rechenzentrumsnetzwerk umfassen.

Mit dem Paket ray-tpu können Sie die Interaktionen von Ray mit TPU-Scheiben vereinfachen. Installieren Sie ray-tpu mit pip:

pip install ray-tpu

Das folgende Beispielskript zeigt, wie Sie mit dem Paket ray-tpu Multislice-Arbeitslasten mit Ray-Actors oder ‑Tasks ausführen:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Arbeitslasten mit Ray und MaxText orchestrieren

In diesem Abschnitt wird beschrieben, wie Sie mit Ray Arbeitslasten mit MaxText orchestrieren, einer skalierbaren und leistungsstarken Open-Source-Bibliothek zum Trainieren von LLMs mit JAX und XLA.

MaxText enthält ein Trainingsskript, train.py, das auf jedem TPU-Host ausgeführt werden muss. Das ist vergleichbar mit anderen SPMD-Arbeitslasten für maschinelles Lernen. Dazu können Sie das Paket ray-tpu verwenden und einen Wrapper um die Hauptfunktion train.py erstellen. In den folgenden Schritten wird gezeigt, wie Sie das ray-tpu-Paket verwenden, um MaxText auf einer TPU v4-16 auszuführen.

  1. Legen Sie Umgebungsvariablen für die Parameter zum Erstellen von TPUs fest:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. TPU v6e-16 erstellen:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. Klonen Sie das MaxText-Repository auf allen TPU-Arbeitsstationen:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. Installieren Sie die MaxText-Anforderungen auf allen TPU-Workern:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. Installieren Sie das ray-tpu-Paket auf allen TPU-Workern:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. Stellen Sie über SSH eine Verbindung zu Worker 0 her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. Speichern Sie das folgende Script im Verzeichnis ~/maxtext/MaxText in einer Datei mit dem Namen ray_trainer.py. Dieses Script verwendet das Paket ray-tpu und erstellt einen Wrapper um die Hauptfunktion train.py von MaxText.

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. Führen Sie das Skript aus, indem Sie den folgenden Befehl ausführen:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    Die Ausgabe sieht dann ungefähr so aus:

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

TPU- und Ray-Ressourcen

Ray behandelt TPUs anders als GPUs, um den Unterschied bei der Nutzung zu berücksichtigen. Im folgenden Beispiel gibt es insgesamt neun Ray-Knoten:

  • Der Ray-Leitknoten wird auf einer n1-standard-16-VM ausgeführt.
  • Die Ray-Worker-Knoten werden auf zwei v6e-16 TPUs ausgeführt. Jede TPU besteht aus vier Workern.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Beschreibungen der Felder zur Ressourcennutzung:

  • CPU: Die Gesamtzahl der im Cluster verfügbaren CPUs.
  • TPU: Die Anzahl der TPU-Chips im Cluster.
  • TPU-v6e-16-head: Eine spezielle Kennung für die Ressource, die Worker 0 eines TPU-Slabs entspricht. Das ist wichtig für den Zugriff auf einzelne TPU-Scheiben.
  • memory: Der von Ihrer Anwendung verwendete Worker-Heap-Speicher.
  • object_store_memory: Speicher, der verwendet wird, wenn Ihre Anwendung Objekte im Objektspeicher mit ray.put erstellt und Werte von Remotefunktionen zurückgibt.
  • tpu-group-0 und tpu-group-1: Eindeutige Kennungen für die einzelnen TPU-Scheiben. Das ist wichtig, wenn Jobs auf Segmenten ausgeführt werden sollen. Diese Felder sind auf „4“ festgelegt, da es in einem v6e-16 vier Hosts pro TPU-Speichere gibt.