Scala i workload ML utilizzando Ray

Questo documento fornisce dettagli su come eseguire carichi di lavoro di machine learning (ML) con Ray e JAX su TPU. Esistono due modalità diverse per utilizzare le TPU con Ray: modalità incentrata sul dispositivo (PyTorch/XLA) e modalità incentrata sull'host (JAX).

Questo documento presuppone che tu abbia già configurato un ambiente TPU. Per maggiori informazioni, consulta le seguenti risorse:

Modalità incentrata sul dispositivo (PyTorch/XLA)

La modalità incentrata sul dispositivo mantiene gran parte dello stile programmatico del PyTorch classico. In questa modalità, aggiungi un nuovo tipo di dispositivo XLA, che funziona come qualsiasi altro dispositivo PyTorch. Ogni singolo processo interagisce con un dispositivo XLA.

Questa modalità è ideale se hai già dimestichezza con PyTorch con GPU e vuoi utilizzare astrazioni di programmazione simili.

Le seguenti sezioni descrivono come eseguire un carico di lavoro PyTorch/XLA su uno o più dispositivi senza utilizzare Ray e come eseguire lo stesso carico di lavoro su più host utilizzando Ray.

Crea una TPU

  1. Crea variabili di ambiente per i parametri di creazione della TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    Descrizioni delle variabili di ambiente

    TPU_NAME
    Il nome del nuovo Cloud TPU.
    ZONE
    La zona in cui creare la Cloud TPU.
    accelerator-type
    Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni, consulta Versioni TPU.
    version
    La versione software della TPU che vuoi utilizzare. Per ulteriori informazioni, consulta Immagini VM TPU.
  2. Utilizza il seguente comando per creare una VM TPU v5p con 8 core:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. Connettiti alla VM TPU utilizzando il seguente comando:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Se utilizzi GKE, consulta la guida KubeRay su GKE per informazioni sulla configurazione.

Requisiti di installazione

Esegui i seguenti comandi sulla VM TPU per installare le dipendenze richieste:

  1. Salva quanto segue in un file, ad esempio requirements.txt:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Esegui il seguente comando per installare le dipendenze richieste:

    pip install -r requirements.txt
    

Se esegui il tuo carico di lavoro su GKE, ti consigliamo di creare un file Dockerfile che installi le dipendenze richieste. Per un esempio, consulta Eseguire il carico di lavoro sui nodi dei sezioni TPU nella documentazione di GKE.

Esegui un carico di lavoro PyTorch/XLA su un singolo dispositivo

L'esempio seguente mostra come creare un tensore XLA su un singolo dispositivo, ovvero un chip TPU. È simile al modo in cui PyTorch gestisce altri tipi di dispositivi.

  1. Salva il seguente snippet di codice in un file, ad esempio workload.py:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    L'istruzione di importazione import torch_xla inizializza PyTorch/XLA e la funzione xm.xla_device() restituisce il dispositivo XLA corrente, un chip TPU.

  2. Imposta la variabile di ambiente PJRT_DEVICE su TPU:

    export PJRT_DEVICE=TPU
    
  3. Esegui lo script:

    python workload.py
    

    L'output è simile al seguente. Assicurati che l'output indicate che il dispositivo XLA è stato trovato.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

Esegui PyTorch/XLA su più dispositivi

  1. Aggiorna lo snippet di codice della sezione precedente in modo che venga eseguito su più dispositivi:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Esegui lo script:

    python workload.py
    

    Se esegui lo snippet di codice su una TPU v5p-8, l'output è simile al seguente:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() accetta due argomenti, una funzione e un elenco di parametri. Crea un processo per ogni dispositivo XLA disponibile e chiama la funzione specificata negli argomenti. In questo esempio sono disponibili 4 dispositivi TPU, quindi torch_xla.launch() crea 4 processi e chiama _mp_fn() su ogni dispositivo. Ogni processo ha accesso a un solo dispositivo, quindi ogni dispositivo ha l'indice 0 e viene stampato xla:0 per tutti i processi.

Esegui PyTorch/XLA su più host con Ray

Le sezioni seguenti mostrano come eseguire lo stesso snippet di codice su uno slice TPU multi-host più grande. Per ulteriori informazioni sull'architettura TPU multi-host, consulta Architettura di sistema.

In questo esempio, configurerai manualmente Ray. Se hai già dimestichezza con la configurazione di Ray, puoi passare all'ultima sezione, Eseguire un carico di lavoro Ray. Per saperne di più sulla configurazione di Ray per un ambiente di produzione, consulta le seguenti risorse:

Crea una VM TPU multi-host

  1. Crea le variabili di ambiente per i parametri di creazione della TPU:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. Crea una TPU v5p multi-host con 2 host (un v5p-16 con 4 chip TPU su ogni host) utilizzando il seguente comando:

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

Configurare Ray

Una TPU v5p-16 ha 2 host TPU, ciascuno con 4 chip TPU. In questo esempio, avvia il nodo principale Ray su un host e aggiungi il secondo host come nodo worker al cluster Ray.

  1. Connettiti al primo host tramite SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. Installa le dipendenze con lo stesso file requirements della sezione Install requirements:

    pip install -r requirements.txt
    
  3. Avvia il processo Ray:

    ray start --head --port=6379
    

    L'output è simile al seguente:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Questo host TPU è ora il nodo principale di Ray. Prendi nota delle righe che mostrano come aggiungere un altro nodo al cluster Ray, in modo simile al seguente:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Utilizzerai questo comando in un passaggio successivo.

  4. Controlla lo stato del cluster Ray:

    ray status
    

    L'output è simile al seguente:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Il cluster contiene solo 4 TPU (0.0/4.0 TPU) perché finora hai aggiunto solo il nodo principale.

Ora che il nodo principale è in esecuzione, puoi aggiungere il secondo host al cluster.

  1. Connettiti al secondo host tramite SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. Installa le dipendenze con lo stesso file requirements della sezione Install requirements:

    pip install -r requirements.txt
    
  3. Avvia il processo Ray. Utilizza il comando nell'output del comando ray start per aggiungere questo nodo al cluster Ray esistente. Assicurati di sostituire l'indirizzo IP e la porta nel comando seguente:

    ray start --address='10.130.0.76:6379'

    L'output è simile al seguente:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. Controlla di nuovo lo stato di Ray:

    ray status
    

    L'output è simile al seguente:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Il secondo host TPU è ora un nodo del cluster. L'elenco delle risorse disponibili ora mostra 8 TPU (0.0/8.0 TPU).

Esegui un carico di lavoro Ray

  1. Aggiorna lo snippet di codice da eseguire sul cluster Ray:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Esegui lo script sul nodo principale di Ray. Sostituisci ray-workload.py con il percorso dello script.

    python ray-workload.py

    L'output è simile al seguente:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    L'output indica che la funzione è stata chiamata correttamente su ogni dispositivo XLA (8 dispositivi in questo esempio) nello slice TPU multi-host.

Modalità incentrata sull'host (JAX)

Le sezioni seguenti descrivono la modalità host-centric con JAX. JAX utilizza un paradigma di programmazione funzionale e supporta la semantica SPMD (single program, multiple data) di livello superiore. Anziché fare in modo che ogni processo interagisca con un singolo dispositivo XLA, il codice JAX è progettato per funzionare su più dispositivi contemporaneamente su un singolo host.

JAX è progettato per il calcolo ad alte prestazioni e può utilizzare in modo efficiente le TPU per l'addestramento e l'inferenza su larga scala. Questa modalità è ideale se hai familiarità con i concetti di programmazione funzionale, in modo da poter sfruttare tutto il potenziale di JAX.

Queste istruzioni presuppongono che tu abbia già configurato un ambiente Ray e TPU, incluso un ambiente software che includa JAX e altri pacchetti correlati. Per creare un cluster Ray TPU, segui le istruzioni riportate in Avvia Google Cloud un cluster GKE con TPU per KubeRay. Per ulteriori informazioni sull'utilizzo delle TPU con KubeRay, consulta Utilizzare le TPU con KubeRay.

Esegui un carico di lavoro JAX su una TPU con un solo host

Lo script di esempio riportato di seguito mostra come eseguire una funzione JAX su un cluster Ray con una TPU a un solo host, ad esempio una v6e-4. Se hai una TPU multi-host, questo script non risponde a causa del modello di esecuzione con più controller di JAX. Per ulteriori informazioni sull'esecuzione di Ray su una TPU multi-host, consulta Eseguire un carico di lavoro JAX su una TPU multi-host.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:

  • Anziché impostare num_gpus, specifica TPU come risorsa personalizzata e imposta il numero di chip TPU.
  • Specifica la TPU utilizzando il numero di chip per nodo worker Ray. Ad esempio, se utilizzi una v6e-4, l'esecuzione di una funzione remota con TPU impostato su 4 consuma l'intero host TPU.
    • Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host. L'impostazione di TPU su un numero diverso da 4 non è consigliata.
    • Eccezione: se hai un v6e-8 o un v5litepod-8 a host singolo, devi impostare questo valore su 8.

Esegui un carico di lavoro JAX su una TPU multi-host

Lo script di esempio seguente mostra come eseguire una funzione JAX su un cluster Ray con una TPU multi-host. Lo script di esempio utilizza una versione v6e-16.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:

  • Simile ai carichi di lavoro PyTorch sulle GPU:
  • A differenza dei carichi di lavoro PyTorch sulle GPU, JAX ha una vista globale dei dispositivi disponibili nel cluster.

Esegui un carico di lavoro JAX multislice

Multislice ti consente di eseguire carichi di lavoro che coprono più slice TPU all'interno di un singolo pod TPU o in più pod sulla rete del data center.

Puoi utilizzare il pacchetto ray-tpu per semplificare le interazioni di Ray con i sezioni TPU. Installa ray-tpu utilizzando pip:

pip install ray-tpu

Lo script di esempio seguente mostra come utilizzare il pacchetto ray-tpu per eseguire carichi di lavoro multislice utilizzando gli attori o le attività Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Orchestra i carichi di lavoro utilizzando Ray e MaxText

Questa sezione descrive come utilizzare Ray per orchestrare i carichi di lavoro utilizzando MaxText, una libreria open source scalabile e ad alte prestazioni per l'addestramento di LLM utilizzando JAX e XLA.

MaxText contiene uno script di addestramento, train.py, che deve essere eseguito su ogni host TPU. È simile ad altri carichi di lavoro di machine learning SPMD. Puoi ottenere questo risultato utilizzando il pacchetto ray-tpu e creando un wrapper attorno alla funzione principale train.py. I passaggi riportati di seguito mostrano come utilizzare il pacchetto ray-tpu per eseguire MaxText su una TPU v4-16.

  1. Imposta le variabili di ambiente per i parametri di creazione della TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. Crea una TPU v6e-16:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. Clona il repository MaxText su tutti i worker TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. Installa i requisiti di MaxText su tutti i worker TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. Installa il pacchetto ray-tpu su tutti i worker TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. Connettiti al worker 0 tramite SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. Salva il seguente script in un file denominato ray_trainer.py nella directory ~/maxtext/MaxText. Questo script utilizza il pacchetto ray-tpu e crea un wrapper per la funzione principale train.py di MaxText.

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. Esegui lo script eseguendo il seguente comando:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    L'output è simile al seguente:

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

Risorse TPU e Ray

Ray tratta le TPU in modo diverso dalle GPU per tenere conto della differenza di utilizzo. Nell'esempio seguente sono presenti in totale nove nodi Ray:

  • Il nodo principale di Ray è in esecuzione su una VM n1-standard-16.
  • I nodi worker di Ray sono in esecuzione su due TPU v6e-16. Ogni TPU è costituita da quattro worker.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descrizioni dei campi di utilizzo delle risorse:

  • CPU: il numero totale di CPU disponibili nel cluster.
  • TPU: il numero di chip TPU nel cluster.
  • TPU-v6e-16-head: un identificatore speciale per la risorsa corrispondente al worker 0 di una sezione TPU. Questo è importante per accedere alle singole sezioni di TPU.
  • memory: memoria heap del worker utilizzata dall'applicazione.
  • object_store_memory: memoria utilizzata quando l'applicazione crea oggetti nell'object store utilizzando ray.put e quando restituisce valori dalle funzioni remote.
  • tpu-group-0 e tpu-group-1: identificatori univoci per le singole sezioni TPU. Questo è importante per l'esecuzione di job su slice. Questi campi sono impostati su 4 perché in una v6e-16 sono presenti 4 host per ogni slice TPU.