Biblioteca de monitorización de TPU
Obtén información detallada sobre el rendimiento y el comportamiento del hardware de tu TPU de Cloud con las funciones de monitorización avanzadas de TPU, creadas directamente sobre la capa de software fundamental, LibTPU. Aunque LibTPU incluye controladores, bibliotecas de redes, el compilador XLA y el tiempo de ejecución de TPU para interactuar con las TPUs, este documento se centra en la biblioteca de monitorización de TPUs.
La biblioteca de monitorización de TPU proporciona lo siguiente:
Observabilidad integral: accede a la API de telemetría y al conjunto de métricas. De esta forma, puede obtener información detallada sobre el rendimiento operativo y los comportamientos específicos de sus TPUs.
Kits de herramientas de diagnóstico: proporciona un SDK y una interfaz de línea de comandos (CLI) diseñados para depurar y analizar en profundidad el rendimiento de tus recursos de TPU.
Estas funciones de monitorización se han diseñado para ser una solución de alto nivel orientada al cliente, que te proporciona las herramientas esenciales para optimizar tus cargas de trabajo de TPU de forma eficaz.
La biblioteca de monitorización de TPUs te proporciona información detallada sobre el rendimiento de las cargas de trabajo de aprendizaje automático en el hardware de TPUs. Está diseñada para ayudarte a entender el uso de tus TPUs, identificar cuellos de botella y depurar problemas de rendimiento. Proporciona información más detallada que las métricas de interrupción, las métricas de buen rendimiento y otras métricas.
Empezar a usar la biblioteca de monitorización de TPU
Acceder a estas valiosas estadísticas es muy sencillo. La función de monitorización de las TPU está integrada en el SDK de LibTPU, por lo que se incluye al instalar LibTPU.
Instalar LibTPU
pip install libtpu
De forma alternativa, las actualizaciones de LibTPU se coordinan con los lanzamientos de JAX, lo que significa que, cuando instales la última versión de JAX (que se lanza mensualmente), normalmente se te asignará la última versión compatible de LibTPU y sus funciones.
Instalar JAX
pip install -U "jax[tpu]"
Para los usuarios de PyTorch, instalar PyTorch/XLA proporciona las funciones más recientes de LibTPU y de monitorización de TPU.
Instalar PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación en el repositorio de GitHub de PyTorch/XLA.
Importar la biblioteca en Python
Para empezar a usar la biblioteca de monitorización de TPU, debes importar el módulo libtpu
en tu código de Python.
from libtpu.sdk import tpumonitoring
Lista de todas las funciones admitidas
Lista de todos los nombres de métricas y las funciones que admiten:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas admitidas
En el siguiente código de ejemplo se muestra cómo obtener una lista de todos los nombres de métricas admitidos:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
En la siguiente tabla se muestran todas las métricas y sus definiciones correspondientes:
Métrica | Definición | Nombre de la métrica de la API | Valores de ejemplo |
---|---|---|---|
Uso de Tensor Core | Mide el porcentaje de uso de Tensor Core, que se calcula como el porcentaje de operaciones que forman parte de las operaciones de Tensor Core. Se muestrea cada segundo durante 10 microsegundos. No puedes modificar la frecuencia de muestreo. Esta métrica le permite monitorizar la eficiencia de sus cargas de trabajo en dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# porcentaje de uso del ID de acelerador 0-3 |
Porcentaje del ciclo de trabajo | Porcentaje del tiempo durante el último periodo de muestreo (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG ) en el que el acelerador ha procesado activamente (registrado con los ciclos utilizados para ejecutar programas HLO durante el último periodo de muestreo). Esta métrica representa el nivel de ocupación de una TPU. La métrica se emite por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Porcentaje del ciclo de actividad del acelerador con ID 0-3 |
Capacidad total de HBM | Esta métrica indica la capacidad total de HBM en bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidad total de HBM en bytes conectada al acelerador con ID del 0 al 3 |
Uso de la capacidad de HBM | Esta métrica informa del uso de la capacidad de HBM en bytes durante el periodo de muestreo anterior (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG ).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Uso de la capacidad de HBM en bytes que se ha adjuntado al ID de acelerador 0-3 |
Latencia de transferencia de búfer | Latencias de transferencia de red para tráfico multisegmento a gran escala. Esta visualización te permite conocer el entorno de rendimiento general de la red. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# tamaño del búfer, media, p50, p90, p99 y p99, 9 de la distribución de la latencia de transferencia de red |
Métricas de distribución del tiempo de ejecución de operaciones de alto nivel | Proporciona estadísticas de rendimiento detalladas sobre el estado de ejecución del archivo binario compilado de HLO, lo que permite detectar regresiones y depurar a nivel de modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# Distribución de la duración del tiempo de ejecución de HLO para CoreType-CoreID con la media, p50, p90, p95 y p999 |
Tamaño de la cola del optimizador de nivel alto | La monitorización del tamaño de la cola de ejecución de HLO registra el número de programas HLO compilados que están esperando o en proceso de ejecución. Esta métrica muestra la congestión de la canalización de ejecución, lo que permite identificar cuellos de botella en el rendimiento de la ejecución de hardware, la sobrecarga de controladores o la asignación de recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Mide el tamaño de la cola de CoreType-CoreID. |
Latencia integral colectiva | Esta métrica mide la latencia colectiva de extremo a extremo en la DCN en microsegundos, desde que el host inicia la operación hasta que todos los peers reciben el resultado. Incluye la reducción de datos del host y el envío de la salida a la TPU. Los resultados son cadenas que detallan el tamaño del búfer, el tipo y las latencias media, p50, p90, p95 y p99,9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency |
Leer datos de métricas (modo de vista general)
Para habilitar el modo de instantánea, especifica el nombre de la métrica al llamar a la función tpumonitoring.get_metric
. El modo de instantánea te permite insertar comprobaciones de métricas ad hoc en código de bajo rendimiento para identificar si los problemas de rendimiento se deben al software o al hardware.
En el siguiente código de ejemplo se muestra cómo usar el modo de instantánea para leer el duty_cycle
.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Acceder a las métricas mediante la CLI
En los siguientes pasos se muestra cómo interactuar con las métricas de LibTPU mediante la CLI:
Instalar
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
Ejecuta la visión predeterminada de
tpu-info
:tpu-info
El resultado debería ser similar al siguiente:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
Usar métricas para comprobar la utilización de la TPU
En los siguientes ejemplos se muestra cómo usar las métricas de la biblioteca de monitorización de TPUs para hacer un seguimiento del uso de las TPUs.
Monitorizar el ciclo de trabajo de la TPU durante el entrenamiento de JAX
Situación: estás ejecutando una secuencia de comandos de entrenamiento de JAX y quieres monitorizar la métrica duty_cycle_pct
de la TPU durante todo el proceso de entrenamiento para confirmar que las TPUs se están utilizando de forma eficaz. Puedes registrar esta métrica periódicamente durante el entrenamiento para monitorizar el uso de la TPU.
En el siguiente ejemplo de código se muestra cómo monitorizar el ciclo de trabajo de la TPU durante el entrenamiento de JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Comprobar la utilización de HBM antes de ejecutar la inferencia de JAX
Situación: antes de ejecutar la inferencia con tu modelo de JAX, comprueba el uso actual de la memoria de alto ancho de banda (HBM) en la TPU para confirmar que tienes suficiente memoria disponible y obtener una medición de referencia antes de que empiece la inferencia.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Frecuencia de exportación de métricas de TPU
La frecuencia de actualización de las métricas de TPU está limitada a un mínimo de un segundo. Los datos de las métricas de host se exportan con una frecuencia fija de 1 Hz. La latencia introducida por este proceso de exportación es insignificante. Las métricas de tiempo de ejecución de LibTPU no están sujetas a la misma restricción de frecuencia. Sin embargo, para mantener la coherencia, estas métricas también se muestrean a 1 Hz o 1 muestra por segundo.