Libreria di monitoraggio TPU

Ottieni informazioni approfondite sulle prestazioni e sul comportamento dell'hardware Cloud TPU con funzionalità di monitoraggio avanzate, create direttamente sul livello software di base, LibTPU. Sebbene LibTPU comprenda driver, librerie di rete, il compilatore XLA e il runtime TPU per l'interazione con le TPU, questo documento si concentra sulla libreria di monitoraggio TPU.

La libreria di monitoraggio TPU fornisce:

  • Osservabilità completa: accedi all'API Telemetry e alla suite di metriche. In questo modo puoi ottenere informazioni dettagliate sul rendimento operativo e sui comportamenti specifici delle tue TPU.

  • Toolkit di diagnostica: fornisce un SDK e un'interfaccia a riga di comando (CLI) progettati per consentire il debug e l'analisi approfondita delle prestazioni delle tue risorse TPU.

Queste funzionalità di monitoraggio sono progettate per essere una soluzione di primo livello rivolta ai clienti, fornendoti gli strumenti essenziali per ottimizzare in modo efficace i carichi di lavoro TPU.

La libreria di monitoraggio TPU fornisce informazioni dettagliate sul rendimento dei carichi di lavoro di machine learning sull'hardware TPU. È progettato per aiutarti a comprendere l'utilizzo delle TPU, identificare i colli di bottiglia ed eseguire il debug dei problemi di prestazioni. Fornisce informazioni più dettagliate rispetto alle metriche di interruzione, goodput e altre metriche.

Inizia a utilizzare la libreria di monitoraggio TPU

Accedere a questi potenti approfondimenti è semplice. La funzionalità di monitoraggio delle TPU è integrata nell'SDK LibTPU, quindi è inclusa quando installi LibTPU.

Installare LibTPU

pip install libtpu

In alternativa, gli aggiornamenti di LibTPU sono coordinati con le release di JAX, il che significa che quando installi l'ultima release di JAX (rilasciata mensilmente), in genere viene bloccata l'ultima versione compatibile di LibTPU e le relative funzionalità.

Installare JAX

pip install -U "jax[tpu]"

Per gli utenti di PyTorch, l'installazione di PyTorch/XLA fornisce le funzionalità di monitoraggio di LibTPU e TPU più recenti.

Installa PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Per maggiori informazioni sull'installazione di PyTorch/XLA, consulta la sezione Installazione nel repository GitHub di PyTorch/XLA.

Importa la libreria in Python

Per iniziare a utilizzare la libreria di monitoraggio TPU, devi importare il modulo libtpu nel codice Python.

from libtpu.sdk import tpumonitoring

Elencare tutte le funzionalità supportate

Elenca tutti i nomi delle metriche e le funzionalità che supportano:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Metriche supportate

Il seguente esempio di codice mostra come elencare tutti i nomi delle metriche supportate:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

La tabella seguente mostra tutte le metriche e le relative definizioni:

Metrica Definizione Nome della metrica per l'API Valori di esempio
Utilizzo di Tensor Core Misura la percentuale di utilizzo di TensorCore, calcolata come la percentuale di operazioni che fanno parte delle operazioni TensorCore. Campionamento eseguito ogni secondo per 10 microsecondi. Non puoi modificare la frequenza di campionamento. Questa metrica ti consente di monitorare l'efficienza dei tuoi carichi di lavoro sui dispositivi TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3
Percentuale del ciclo di lavoro Percentuale di tempo nell'ultimo periodo di campionamento (ogni 5 secondi; può essere regolata impostando il flag LIBTPU_INIT_ARG) durante il quale l'acceleratore ha eseguito attivamente l'elaborazione (registrata con i cicli utilizzati per eseguire i programmi HLO nell'ultimo periodo di campionamento). Questa metrica rappresenta il livello di utilizzo di una TPU. La metrica viene emessa per chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3
HBM Capacity Total Questa metrica indica la capacità totale di HBM in byte. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacità totale HBM in byte collegata all'ID acceleratore 0-3
Utilizzo capacità HBM Questa metrica indica l'utilizzo della capacità HBM in byte nell'ultimo periodo di campionamento (ogni 5 secondi; può essere modificato impostando il flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3
Latenza di trasferimento del buffer Latenze di trasferimento di rete per il traffico multislice su larga scala. Questa visualizzazione ti consente di comprendere l'ambiente di rendimento complessivo della rete. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution
Metriche di distribuzione del tempo di esecuzione delle operazioni di alto livello Fornisce informazioni dettagliate sul rendimento dello stato di esecuzione del file binario compilato HLO, consentendo il rilevamento delle regressioni e il debug a livello di modello. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999
Dimensioni della coda dello strumento di ottimizzazione di alto livello Il monitoraggio delle dimensioni della coda di esecuzione HLO tiene traccia del numero di programmi HLO compilati in attesa o in fase di esecuzione. Questa metrica rivela la congestione della pipeline di esecuzione, consentendo l'identificazione di colli di bottiglia delle prestazioni nell'esecuzione hardware, nell'overhead dei driver o nell'allocazione delle risorse. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
Latenza end-to-end collettiva Questa metrica misura la latenza collettiva end-to-end su DCN in microsecondi, dall'host che avvia l'operazione a tutti i peer che ricevono l'output. Include la riduzione dei dati lato host e l'invio dell'output alla TPU. I risultati sono stringhe che descrivono in dettaglio le dimensioni, il tipo e le latenze medie, P50, P90, P95 e P99,9 del buffer. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency

Lettura dei dati delle metriche - modalità snapshot

Per attivare la modalità snapshot, specifica il nome della metrica quando chiami la funzione tpumonitoring.get_metric. La modalità Snapshot consente di inserire controlli delle metriche ad hoc nel codice con prestazioni scarse per identificare se i problemi di prestazioni derivano da software o hardware.

Il seguente esempio di codice mostra come utilizzare la modalità snapshot per leggere duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Accedere alle metriche utilizzando la CLI

I seguenti passaggi mostrano come interagire con le metriche LibTPU utilizzando la CLI:

  1. Installa tpu-info:

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Esegui la visione predefinita di tpu-info:

    tpu-info
    

    L'output è simile al seguente:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Utilizzare le metriche per controllare l'utilizzo delle TPU

Gli esempi seguenti mostrano come utilizzare le metriche della libreria di monitoraggio TPU per monitorare l'utilizzo della TPU.

Monitorare il ciclo di servizio TPU durante l'addestramento JAX

Scenario: stai eseguendo uno script di addestramento JAX e vuoi monitorare la metrica duty_cycle_pct della TPU durante il processo di addestramento per verificare che le TPU vengano utilizzate in modo efficace. Puoi registrare questa metrica periodicamente durante l'addestramento per monitorare l'utilizzo della TPU.

Il seguente esempio di codice mostra come monitorare il ciclo di lavoro della TPU durante l'addestramento JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Controlla l'utilizzo della HBM prima di eseguire l'inferenza JAX

Scenario: prima di eseguire l'inferenza con il modello JAX, controlla l'utilizzo attuale della HBM (High Bandwidth Memory) sulla TPU per verificare di avere memoria sufficiente e per ottenere una misurazione di base prima dell'inizio dell'inferenza.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Frequenza di esportazione delle metriche TPU

La frequenza di aggiornamento delle metriche TPU è limitata a un minimo di un secondo. I dati delle metriche host vengono esportati a una frequenza fissa di 1 Hz. La latenza introdotta da questo processo di esportazione è trascurabile. Le metriche di runtime di LibTPU non sono soggette allo stesso vincolo di frequenza. Tuttavia, per coerenza, anche queste metriche vengono campionate a 1 Hz o 1 campione al secondo.