Biblioteca de supervisión de TPU

Obtén estadísticas detalladas sobre el rendimiento y el comportamiento del hardware de tu Cloud TPU con las capacidades avanzadas de supervisión de TPU, integradas directamente en la capa de software fundamental, LibTPU. Si bien LibTPU abarca controladores, bibliotecas de redes, el compilador XLA y el entorno de ejecución de TPU para interactuar con las TPU, este documento se centra en la biblioteca de supervisión de TPU.

La biblioteca de supervisión de TPU proporciona lo siguiente:

  • Observabilidad integral: Obtén acceso a la API de telemetría y al conjunto de métricas. Esto te permite obtener estadísticas detalladas sobre el rendimiento operativo y los comportamientos específicos de tus TPU.

  • Kits de herramientas de diagnóstico: Proporcionan un SDK y una interfaz de línea de comandos (CLI) diseñados para permitir la depuración y el análisis de rendimiento detallado de tus recursos de TPU.

Estas funciones de supervisión están diseñadas para ser una solución de alto nivel orientada al cliente, ya que te proporcionan las herramientas esenciales para optimizar tus cargas de trabajo de TPU de manera eficaz.

La biblioteca de supervisión de TPU te brinda información detallada sobre el rendimiento de las cargas de trabajo de aprendizaje automático en el hardware de TPU. Está diseñado para ayudarte a comprender el uso de la TPU, identificar cuellos de botella y depurar problemas de rendimiento. Te brinda información más detallada que las métricas de interrupción, las métricas de buen rendimiento y otras métricas.

Comienza a usar la biblioteca de supervisión de TPU

Acceder a estas estadísticas útiles es sencillo. La funcionalidad de supervisión de TPU está integrada en el SDK de LibTPU, por lo que se incluye cuando instalas LibTPU.

Instala LibTPU

pip install libtpu

De forma alternativa, las actualizaciones de LibTPU se coordinan con los lanzamientos de JAX, lo que significa que, cuando instalas el lanzamiento de JAX más reciente (que se lanza mensualmente), por lo general, se fijará en la versión de LibTPU compatible más reciente y sus funciones.

Instala JAX

pip install -U "jax[tpu]"

Para los usuarios de PyTorch, instalar PyTorch/XLA proporciona las funciones más recientes de LibTPU y supervisión de TPU.

Instala PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Para obtener más información sobre la instalación de PyTorch/XLA, consulta Installation en el repositorio de GitHub de PyTorch/XLA.

Importa la biblioteca en Python

Para comenzar a usar la biblioteca de supervisión de TPU, debes importar el módulo libtpu en tu código de Python.

from libtpu.sdk import tpumonitoring

Enumera todas las funciones compatibles

Enumera todos los nombres de métricas y la funcionalidad que admiten:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métricas admitidas

En el siguiente muestra de código, se muestra cómo enumerar todos los nombres de métricas admitidos:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

En la siguiente tabla, se muestran todas las métricas y sus definiciones correspondientes:

Métrica Definición Nombre de la métrica para la API Valores de ejemplo
Uso de Tensor Core Mide el porcentaje de uso de TensorCore, calculado como el porcentaje de operaciones que forman parte de las operaciones de TensorCore. Se toman muestras cada 10 microsegundos por segundo. No puedes modificar la tasa de muestreo. Esta métrica te permite supervisar la eficiencia de tus cargas de trabajo en dispositivos TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# Porcentaje de uso para los IDs de acelerador del 0 al 3
Porcentaje del ciclo de trabajo Porcentaje de tiempo durante el último período de muestra (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG) durante el cual el acelerador estuvo realizando procesamiento de forma activa (registrado con los ciclos que se usaron para ejecutar programas de HLO durante el último período de muestreo). Esta métrica representa qué tan ocupada está una TPU y se emite por chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Porcentaje del ciclo de trabajo para los IDs de acelerador del 0 al 3
Capacidad total de HBM Esta métrica informa la capacidad total de HBM en bytes. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacidad total de HBM en bytes que se adjuntó a los IDs de acelerador del 0 al 3
Uso de la capacidad de HBM Esta métrica informa el uso de la capacidad de HBM en bytes durante el período de muestra anterior (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Uso de capacidad para HBM en bytes que se adjuntan a los IDs de acelerador del 0 al 3
Latencia de transferencia de búfer Latencias de transferencia de red para el tráfico de varias porciones a gran escala. Esta visualización te permite comprender el entorno general de rendimiento de la red. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution
Métricas de distribución del tiempo de ejecución de operaciones de alto nivel Proporciona estadísticas detalladas del rendimiento sobre el estado de ejecución del archivo binario compilado de HLO, lo que permite la detección de regresiones y la depuración a nivel del modelo. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# La distribución de la duración del tiempo de ejecución del HLO para CoreType-CoreID con la media, P50, P90, P95 y P999
Tamaño de la cola del optimizador de alto nivel El monitoreo del tamaño de la cola de ejecución de HLO hace un seguimiento de la cantidad de programas HLO compilados que están en espera o en ejecución. Esta métrica revela la congestión de la canalización de ejecución, lo que permite identificar los cuellos de botella en el rendimiento de la ejecución del hardware, la sobrecarga del controlador o la asignación de recursos. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mide el tamaño de la cola para CoreType-CoreID.
Latencia colectiva de extremo a extremo Esta métrica mide la latencia colectiva de extremo a extremo en la DCN en microsegundos, desde el host que inicia la operación hasta todos los pares que reciben el resultado. Incluye la reducción de datos del host y el envío de resultados a la TPU. Los resultados son cadenas que detallan el tamaño del búfer, el tipo y las latencias medias, p50, p90, p95 y p99.9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency

Lee datos de métricas: modo de instantánea

Para habilitar el modo de instantánea, especifica el nombre de la métrica cuando llames a la función tpumonitoring.get_metric. El modo de instantánea te permite insertar verificaciones de métricas ad hoc en el código de bajo rendimiento para identificar si los problemas de rendimiento provienen del software o del hardware.

En la siguiente muestra de código, se muestra cómo usar el modo de instantánea para leer el objeto duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Cómo acceder a las métricas con la CLI

En los siguientes pasos, se muestra cómo interactuar con las métricas de LibTPU a través de la CLI:

  1. Instala tpu-info:

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Ejecuta la visión predeterminada de tpu-info:

    tpu-info
    

    El resultado es similar a este:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Usa métricas para verificar la utilización de la TPU

En los siguientes ejemplos, se muestra cómo usar las métricas de la biblioteca de supervisión de TPU para hacer un seguimiento del uso de la TPU.

Supervisa el ciclo de trabajo de la TPU durante el entrenamiento de JAX

Situación: Estás ejecutando una secuencia de comandos de entrenamiento de JAX y deseas supervisar la métrica duty_cycle_pct de la TPU durante todo el proceso de entrenamiento para confirmar que las TPU se utilizan de manera eficaz. Puedes registrar esta métrica periódicamente durante el entrenamiento para hacer un seguimiento de la utilización de la TPU.

En el siguiente muestra de código, se muestra cómo supervisar el ciclo de trabajo de la TPU durante el entrenamiento de JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Verifica la utilización de la HBM antes de ejecutar la inferencia de JAX

Situación: Antes de ejecutar la inferencia con tu modelo de JAX, verifica el uso actual de la HBM (memoria de ancho de banda alto) en la TPU para confirmar que tienes suficiente memoria disponible y obtener una medición de referencia antes de que comience la inferencia.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Frecuencia de exportación de las métricas de TPU

La frecuencia de actualización de las métricas de TPU está restringida a un mínimo de un segundo. Los datos de las métricas del host se exportan con una frecuencia fija de 1 Hz. La latencia que introduce este proceso de exportación es insignificante. Las métricas de tiempo de ejecución de LibTPU no están sujetas a la misma restricción de frecuencia. Sin embargo, para mantener la coherencia, estas métricas también se muestrean a 1 Hz o 1 muestra por segundo.