Scala i workload ML utilizzando Ray
Questo documento fornisce dettagli su come eseguire carichi di lavoro di machine learning (ML) con Ray e JAX su TPU. Esistono due modalità diverse per utilizzare le TPU con Ray: modalità incentrata sul dispositivo (PyTorch/XLA) e modalità incentrata sull'host (JAX).
Questo documento presuppone che tu abbia già configurato un ambiente TPU. Per maggiori informazioni, consulta le seguenti risorse:
- Cloud TPU: configura l'ambiente Cloud TPU e gestisci le risorse TPU
- Google Kubernetes Engine (GKE): esegui il deployment dei carichi di lavoro TPU in GKE Autopilot o esegui il deployment dei carichi di lavoro TPU in GKE Standard
Modalità incentrata sul dispositivo (PyTorch/XLA)
La modalità incentrata sul dispositivo mantiene gran parte dello stile programmatico del PyTorch classico. In questa modalità, aggiungi un nuovo tipo di dispositivo XLA, che funziona come qualsiasi altro dispositivo PyTorch. Ogni singolo processo interagisce con un dispositivo XLA.
Questa modalità è ideale se hai già dimestichezza con PyTorch con GPU e vuoi utilizzare astrazioni di programmazione simili.
Le seguenti sezioni descrivono come eseguire un carico di lavoro PyTorch/XLA su uno o più dispositivi senza utilizzare Ray e come eseguire lo stesso carico di lavoro su più host utilizzando Ray.
Crea una TPU
Crea variabili di ambiente per i parametri di creazione della TPU:
export TPU_NAME=TPU_NAME export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-8 export VERSION=v2-alpha-tpuv5
Descrizioni delle variabili di ambiente
TPU_NAME
- Il nome del nuovo Cloud TPU.
ZONE
- La zona in cui creare la Cloud TPU.
accelerator-type
- Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni, consulta Versioni TPU.
version
- La versione software della TPU che vuoi utilizzare. Per ulteriori informazioni, consulta Immagini VM TPU.
Utilizza il seguente comando per creare una VM TPU v5p con 8 core:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$VERSION
Connettiti alla VM TPU utilizzando il seguente comando:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
Se utilizzi GKE, consulta la guida KubeRay su GKE per informazioni sulla configurazione.
Requisiti di installazione
Esegui i seguenti comandi sulla VM TPU per installare le dipendenze richieste:
Salva quanto segue in un file, ad esempio
requirements.txt
:--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-wheels/index.html torch~=2.6.0 torch_xla[tpu]~=2.6.0 ray[default]==2.40.0
Esegui il seguente comando per installare le dipendenze richieste:
pip install -r requirements.txt
Se esegui il tuo carico di lavoro su GKE, ti consigliamo di creare un file Dockerfile che installi le dipendenze richieste. Per un esempio, consulta Eseguire il carico di lavoro sui nodi dei sezioni TPU nella documentazione di GKE.
Esegui un carico di lavoro PyTorch/XLA su un singolo dispositivo
L'esempio seguente mostra come creare un tensore XLA su un singolo dispositivo, ovvero un chip TPU. È simile al modo in cui PyTorch gestisce altri tipi di dispositivi.
Salva il seguente snippet di codice in un file, ad esempio
workload.py
:import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t)
L'istruzione di importazione
import torch_xla
inizializza PyTorch/XLA e la funzionexm.xla_device()
restituisce il dispositivo XLA corrente, un chip TPU.Imposta la variabile di ambiente
PJRT_DEVICE
su TPU:export PJRT_DEVICE=TPU
Esegui lo script:
python workload.py
L'output è simile al seguente. Assicurati che l'output indicate che il dispositivo XLA è stato trovato.
xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
Esegui PyTorch/XLA su più dispositivi
Aggiorna lo snippet di codice della sezione precedente in modo che venga eseguito su più dispositivi:
import torch import torch_xla import torch_xla.core.xla_model as xm def _mp_fn(index): t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) if __name__ == '__main__': torch_xla.launch(_mp_fn, args=())
Esegui lo script:
python workload.py
Se esegui lo snippet di codice su una TPU v5p-8, l'output è simile al seguente:
xla:0 xla:0 xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') xla:0 tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0') tensor([[ 1.2309, 0.9896], [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch()
accetta due argomenti, una funzione e un elenco di parametri.
Crea un processo per ogni dispositivo XLA disponibile e chiama la funzione specificata negli argomenti. In questo esempio sono disponibili 4 dispositivi TPU, quindi torch_xla.launch()
crea 4 processi e chiama _mp_fn()
su ogni dispositivo.
Ogni processo ha accesso a un solo dispositivo, quindi ogni dispositivo ha l'indice 0 e viene stampato xla:0
per tutti i processi.
Esegui PyTorch/XLA su più host con Ray
Le sezioni seguenti mostrano come eseguire lo stesso snippet di codice su uno slice TPU multi-host più grande. Per ulteriori informazioni sull'architettura TPU multi-host, consulta Architettura di sistema.
In questo esempio, configurerai manualmente Ray. Se hai già dimestichezza con la configurazione di Ray, puoi passare all'ultima sezione, Eseguire un carico di lavoro Ray. Per saperne di più sulla configurazione di Ray per un ambiente di produzione, consulta le seguenti risorse:
Crea una VM TPU multi-host
Crea le variabili di ambiente per i parametri di creazione della TPU:
export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST export ZONE=europe-west4-b export ACCELERATOR_TYPE_MULTIHOST=v5p-16 export VERSION=v2-alpha-tpuv5
Crea una TPU v5p multi-host con 2 host (un v5p-16 con 4 chip TPU su ogni host) utilizzando il seguente comando:
gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \ --version=$VERSION
Configurare Ray
Una TPU v5p-16 ha 2 host TPU, ciascuno con 4 chip TPU. In questo esempio, avvia il nodo principale Ray su un host e aggiungi il secondo host come nodo worker al cluster Ray.
Connettiti al primo host tramite SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
Installa le dipendenze con lo stesso file requirements della sezione Install requirements:
pip install -r requirements.txt
Avvia il processo Ray:
ray start --head --port=6379
L'output è simile al seguente:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. Local node IP: 10.130.0.76 -------------------- Ray runtime started. -------------------- Next steps To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379' To connect to this Ray cluster: import ray ray.init() To terminate the Ray runtime, run ray stop To view the status of the cluster, use ray status
Questo host TPU è ora il nodo principale di Ray. Prendi nota delle righe che mostrano come aggiungere un altro nodo al cluster Ray, in modo simile al seguente:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
Utilizzerai questo comando in un passaggio successivo.
Controlla lo stato del cluster Ray:
ray status
L'output è simile al seguente:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/208.0 CPU 0.0/4.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/268.44GiB memory 0B/119.04GiB object_store_memory 0.0/1.0 your-tpu-name Demands: (no resource demands)
Il cluster contiene solo 4 TPU (
0.0/4.0 TPU
) perché finora hai aggiunto solo il nodo principale.
Ora che il nodo principale è in esecuzione, puoi aggiungere il secondo host al cluster.
Connettiti al secondo host tramite SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
Installa le dipendenze con lo stesso file requirements della sezione Install requirements:
pip install -r requirements.txt
Avvia il processo Ray. Utilizza il comando nell'output del comando
ray start
per aggiungere questo nodo al cluster Ray esistente. Assicurati di sostituire l'indirizzo IP e la porta nel comando seguente:ray start --address='10.130.0.76:6379'
L'output è simile al seguente:
Local node IP: 10.130.0.80 [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 -------------------- Ray runtime started. -------------------- To terminate the Ray runtime, run ray stop
Controlla di nuovo lo stato di Ray:
ray status
L'output è simile al seguente:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ======== Node status --------------------------------------------------------------- Active: 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1 Pending: (no pending nodes) Recent failures: (no failures) Resources --------------------------------------------------------------- Usage: 0.0/416.0 CPU 0.0/8.0 TPU 0.0/1.0 TPU-v5p-16-head 0B/546.83GiB memory 0B/238.35GiB object_store_memory 0.0/2.0 your-tpu-name Demands: (no resource demands)
Il secondo host TPU è ora un nodo del cluster. L'elenco delle risorse disponibili ora mostra 8 TPU (
0.0/8.0 TPU
).
Esegui un carico di lavoro Ray
Aggiorna lo snippet di codice da eseguire sul cluster Ray:
import os import torch import torch_xla import torch_xla.core.xla_model as xm import ray import torch.distributed as dist import torch_xla.runtime as xr from torch_xla._internal import pjrt # Defines the local PJRT world size, the number of processes per host LOCAL_WORLD_SIZE = 4 # Defines the number of hosts in the Ray cluster NUM_OF_HOSTS = 2 GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS def init_env(): local_rank = int(os.environ['TPU_VISIBLE_CHIPS']) pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE) xr._init_world_size_ordinal() # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip @ray.remote(resources={"TPU": 1}) def print_tensor(): # Initializes the runtime environment on each Ray worker. Equivalent to # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section. init_env() t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ray.init() # Uses Ray to dispatch the function call across available nodes in the cluster tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] ray.get(tasks) ray.shutdown()
Esegui lo script sul nodo principale di Ray. Sostituisci ray-workload.py con il percorso dello script.
python ray-workload.py
L'output è simile al seguente:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. xla:0 xla:0 xla:0 xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0') xla:0 tensor([[ 0.6220, -1.4707], [-1.2112, 0.7024]], device='xla:0')
L'output indica che la funzione è stata chiamata correttamente su ogni dispositivo XLA (8 dispositivi in questo esempio) nello slice TPU multi-host.
Modalità incentrata sull'host (JAX)
Le sezioni seguenti descrivono la modalità host-centric con JAX. JAX utilizza un paradigma di programmazione funzionale e supporta la semantica SPMD (single program, multiple data) di livello superiore. Anziché fare in modo che ogni processo interagisca con un singolo dispositivo XLA, il codice JAX è progettato per funzionare su più dispositivi contemporaneamente su un singolo host.
JAX è progettato per il calcolo ad alte prestazioni e può utilizzare in modo efficiente le TPU per l'addestramento e l'inferenza su larga scala. Questa modalità è ideale se hai familiarità con i concetti di programmazione funzionale, in modo da poter sfruttare tutto il potenziale di JAX.
Queste istruzioni presuppongono che tu abbia già configurato un ambiente Ray e TPU, incluso un ambiente software che includa JAX e altri pacchetti correlati. Per creare un cluster Ray TPU, segui le istruzioni riportate in Avvia Google Cloud un cluster GKE con TPU per KubeRay. Per ulteriori informazioni sull'utilizzo delle TPU con KubeRay, consulta Utilizzare le TPU con KubeRay.
Esegui un carico di lavoro JAX su una TPU con un solo host
Lo script di esempio riportato di seguito mostra come eseguire una funzione JAX su un cluster Ray con una TPU a un solo host, ad esempio una v6e-4. Se hai una TPU multi-host, questo script non risponde a causa del modello di esecuzione con più controller di JAX. Per ulteriori informazioni sull'esecuzione di Ray su una TPU multi-host, consulta Eseguire un carico di lavoro JAX su una TPU multi-host.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:
- Anziché impostare
num_gpus
, specificaTPU
come risorsa personalizzata e imposta il numero di chip TPU. - Specifica la TPU utilizzando il numero di chip per nodo worker Ray. Ad esempio, se utilizzi una v6e-4, l'esecuzione di una funzione remota con
TPU
impostato su 4 consuma l'intero host TPU.- Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host.
L'impostazione di
TPU
su un numero diverso da 4 non è consigliata. - Eccezione: se hai un
v6e-8
o unv5litepod-8
a host singolo, devi impostare questo valore su 8.
- Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host.
L'impostazione di
Esegui un carico di lavoro JAX su una TPU multi-host
Lo script di esempio seguente mostra come eseguire una funzione JAX su un cluster Ray con una TPU multi-host. Lo script di esempio utilizza una versione v6e-16.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:
- Simile ai carichi di lavoro PyTorch sulle GPU:
- I carichi di lavoro JAX sulle TPU vengono eseguiti in un modello con più controller, un singolo programma e più dati (SPMD).
- I collettivi tra dispositivi sono gestiti dal framework di machine learning.
- A differenza dei carichi di lavoro PyTorch sulle GPU, JAX ha una vista globale dei dispositivi disponibili nel cluster.
Esegui un carico di lavoro JAX multislice
Multislice ti consente di eseguire carichi di lavoro che coprono più slice TPU all'interno di un singolo pod TPU o in più pod sulla rete del data center.
Puoi utilizzare il pacchetto
ray-tpu
per semplificare
le interazioni di Ray con i sezioni TPU. Installa ray-tpu
utilizzando pip
:
pip install ray-tpu
Lo script di esempio seguente mostra come utilizzare il pacchetto ray-tpu
per eseguire carichi di lavoro multislice utilizzando gli attori o le attività Ray:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Orchestra i carichi di lavoro utilizzando Ray e MaxText
Questa sezione descrive come utilizzare Ray per orchestrare i carichi di lavoro utilizzando MaxText, una libreria open source scalabile e ad alte prestazioni per l'addestramento di LLM utilizzando JAX e XLA.
MaxText contiene uno script di addestramento, train.py
, che deve essere eseguito su ogni host TPU. È simile ad altri carichi di lavoro di machine learning SPMD. Puoi ottenere questo risultato utilizzando il pacchetto ray-tpu
e creando un wrapper attorno alla funzione principale train.py
. I passaggi riportati di seguito mostrano come utilizzare il pacchetto ray-tpu
per eseguire MaxText su una TPU v4-16.
Imposta le variabili di ambiente per i parametri di creazione della TPU:
export TPU_NAME=TPU_NAME export ZONE=ZONE export ACCELERATOR_TYPE=v6e-16 export VERSION=v2-alpha-tpuv6e
Crea una TPU v6e-16:
gcloud compute tpus tpu-vm create $TPU_NAME \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$VERSION
Clona il repository MaxText su tutti i worker TPU:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="git clone https://github.com/AI-Hypercomputer/maxtext"
Installa i requisiti di MaxText su tutti i worker TPU:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install -r maxtext/requirements.txt"
Installa il pacchetto
ray-tpu
su tutti i worker TPU:gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=all \ --command="pip install ray-tpu"
Connettiti al worker 0 tramite SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME \ --zone=$ZONE \ --worker=0
Salva il seguente script in un file denominato
ray_trainer.py
nella directory~/maxtext/MaxText
. Questo script utilizza il pacchettoray-tpu
e crea un wrapper per la funzione principaletrain.py
di MaxText.import ray import ray_tpu from train import main as maxtext_main import logging from typing import Sequence from absl import app # Default env vars that run on all TPU VMs. MACHINE_ENV_VARS = { "ENABLE_PJRT_COMPATIBILITY": "true", "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true", "TPU_SLICE_BUILDER_DUMP_ICI": "true", "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto", # Dumps HLOs for debugging } def setup_loggers(): """Sets up loggers for Ray.""" logging.basicConfig(level=logging.INFO) @ray_tpu.remote( topology={"v4-16": 1}, ) def run_maxtext_train(argv: Sequence[str]): maxtext_main(argv=argv) def main(argv: Sequence[str]): ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers)) logging.info(f"argv: {argv}") try: ray.get(run_maxtext_train(argv=argv)) except Exception as e: logging.error("Caught error during training: %s", e) logging.error("Shutting down...") ray.shutdown() raise e logging.info("Training complete!") ray.shutdown() if __name__ == "__main__": logger = logging.getLogger() logger.setLevel(logging.INFO) app.run(main)
Esegui lo script eseguendo il seguente comando:
python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \ base_output_directory=/tmp/maxtext \ dataset_type=synthetic \ per_device_batch_size=2 \ max_target_length=8192 \ model_name=default \ steps=100 \ run_name=test
L'output è simile al seguente:
(run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0 (run_maxtext_train pid=78967, ip=10.130.0.11) (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized: (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0)) (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066 (run_maxtext_train pid=80538, ip=10.130.0.12) (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/' (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint... (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing. (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0 (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized: (run_maxtext_train pid=80538, ip=10.130.0.12) Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] ... (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster] (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes. I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete! (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
Risorse TPU e Ray
Ray tratta le TPU in modo diverso dalle GPU per tenere conto della differenza di utilizzo. Nell'esempio seguente sono presenti in totale nove nodi Ray:
- Il nodo principale di Ray è in esecuzione su una VM
n1-standard-16
. - I nodi worker di Ray sono in esecuzione su due TPU
v6e-16
. Ogni TPU è costituita da quattro worker.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Descrizioni dei campi di utilizzo delle risorse:
CPU
: il numero totale di CPU disponibili nel cluster.TPU
: il numero di chip TPU nel cluster.TPU-v6e-16-head
: un identificatore speciale per la risorsa corrispondente al worker 0 di una sezione TPU. Questo è importante per accedere alle singole sezioni di TPU.memory
: memoria heap del worker utilizzata dall'applicazione.object_store_memory
: memoria utilizzata quando l'applicazione crea oggetti nell'object store utilizzandoray.put
e quando restituisce valori dalle funzioni remote.tpu-group-0
etpu-group-1
: identificatori univoci per le singole sezioni TPU. Questo è importante per l'esecuzione di job su slice. Questi campi sono impostati su 4 perché in una v6e-16 sono presenti 4 host per ogni slice TPU.