Bibliothèque de surveillance des TPU

Obtenez des insights approfondis sur les performances et le comportement de votre matériel Cloud TPU grâce à des fonctionnalités de surveillance avancées, directement intégrées à la couche logicielle de base, LibTPU. Alors que LibTPU englobe les pilotes, les bibliothèques réseau, le compilateur XLA et l'environnement d'exécution TPU pour interagir avec les TPU, ce document se concentre sur la bibliothèque de surveillance TPU.

La bibliothèque de surveillance des TPU fournit les éléments suivants :

  • Observabilité complète : accédez à l'API de télémétrie et à la suite de métriques. Vous pouvez ainsi obtenir des insights détaillés sur les performances opérationnelles et les comportements spécifiques de vos TPU.

  • Kits d'outils de diagnostic : fournissent un SDK et une interface de ligne de commande (CLI) conçus pour permettre le débogage et l'analyse approfondie des performances de vos ressources TPU.

Ces fonctionnalités de surveillance sont conçues pour être une solution de premier niveau destinée aux clients. Elles vous fournissent les outils essentiels pour optimiser efficacement vos charges de travail TPU.

La bibliothèque de surveillance des TPU vous fournit des informations détaillées sur les performances des charges de travail de machine learning sur le matériel TPU. Il est conçu pour vous aider à comprendre votre utilisation des TPU, à identifier les goulots d'étranglement et à résoudre les problèmes de performances. Elle fournit des informations plus détaillées que les métriques d'interruption, les métriques de débit utile et d'autres métriques.

Premiers pas avec la bibliothèque de surveillance des TPU

Il est facile d'accéder à ces insights puissants. La fonctionnalité de surveillance des TPU est intégrée au SDK LibTPU. Elle est donc incluse lorsque vous installez LibTPU.

Installer LibTPU

pip install libtpu

Les mises à jour de LibTPU sont coordonnées avec les versions de JAX. Cela signifie que lorsque vous installez la dernière version de JAX (publiée tous les mois), vous êtes généralement redirigé vers la dernière version compatible de LibTPU et ses fonctionnalités.

Installer JAX

pip install -U "jax[tpu]"

Pour les utilisateurs de PyTorch, l'installation de PyTorch/XLA fournit les dernières fonctionnalités de LibTPU et de surveillance des TPU.

Installer PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Pour en savoir plus sur l'installation de PyTorch/XLA, consultez Installation dans le dépôt GitHub de PyTorch/XLA.

Importer la bibliothèque en Python

Pour commencer à utiliser la bibliothèque de surveillance des TPU, vous devez importer le module libtpu dans votre code Python.

from libtpu.sdk import tpumonitoring

Lister toutes les fonctionnalités compatibles

Liste de tous les noms de métriques et des fonctionnalités qu'elles prennent en charge :


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métriques acceptées

L'exemple de code suivant montre comment lister tous les noms de métriques compatibles :

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

Le tableau suivant présente toutes les métriques et leurs définitions correspondantes :

Métrique Définition Nom de la métrique pour l'API Exemples de valeur
Utilisation des Tensor Cores Mesure le pourcentage d'utilisation de votre TensorCore, calculé comme le pourcentage d'opérations faisant partie des opérations TensorCore. Échantillonné toutes les secondes à 10 microsecondes. Vous ne pouvez pas modifier le taux d'échantillonnage. Cette métrique vous permet de surveiller l'efficacité de vos charges de travail sur les appareils TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3
Pourcentage du cycle d'utilisation Pourcentage de temps au cours de la dernière période d'échantillonnage (toutes les cinq secondes ; peut être ajusté en définissant l'indicateur LIBTPU_INIT_ARG) pendant lequel l'accélérateur a été en mode de traitement actif (enregistré avec les cycles utilisés pour exécuter les programmes HLO au cours de la dernière période d'échantillonnage). Cette métrique représente la charge d'un TPU. Elle est émise par puce. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3
Capacité totale de la mémoire HBM Cette métrique indique la capacité totale de la HBM en octets. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacité HBM totale en octets associée aux ID d'accélérateur 0 à 3
Utilisation de la capacité HBM Cette métrique indique l'utilisation de la capacité HBM en octets au cours de la période d'échantillonnage précédente (toutes les cinq secondes ; peut être ajustée en définissant l'indicateur LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Utilisation de la capacité pour la mémoire HBM en octets associée aux ID d'accélérateur 0 à 3
Latence de transfert du tampon Latences de transfert réseau pour le trafic multislice à grande échelle. Cette visualisation vous permet de comprendre l'environnement global des performances du réseau. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution
Métriques de distribution du temps d'exécution des opérations de haut niveau Fournit des insights précis sur les performances de l'état d'exécution du binaire compilé HLO, ce qui permet de détecter les régressions et de déboguer au niveau du modèle. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999
Taille de la file d'attente de l'optimiseur de haut niveau La surveillance de la taille de la file d'exécution HLO permet de suivre le nombre de programmes HLO compilés en attente ou en cours d'exécution. Cette métrique révèle la congestion du pipeline d'exécution, ce qui permet d'identifier les goulots d'étranglement des performances dans l'exécution matérielle, la surcharge du pilote ou l'allocation des ressources. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mesure la taille de la file d'attente pour CoreType-CoreID.
Latence collective de bout en bout Cette métrique mesure la latence collective de bout en bout sur le DCN en microsecondes, depuis l'hôte qui lance l'opération jusqu'à ce que tous les pairs reçoivent le résultat. Cela inclut la réduction des données côté hôte et l'envoi de la sortie au TPU. Les résultats sont des chaînes détaillant la taille du tampon, le type et les latences moyennes, p50, p90, p95 et p99,9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Taille du transfert-op collectif, moyenne, p50, p90, p95, p999 de la latence collective de bout en bout

Lire les données de métrique en mode instantané

Pour activer le mode instantané, spécifiez le nom de la métrique lorsque vous appelez la fonction tpumonitoring.get_metric. Le mode instantané vous permet d'insérer des vérifications de métriques ponctuelles dans le code à faibles performances pour déterminer si les problèmes de performances proviennent du logiciel ou du matériel.

L'exemple de code suivant montre comment utiliser le mode instantané pour lire duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Accéder aux métriques à l'aide de la CLI

Les étapes suivantes montrent comment interagir avec les métriques LibTPU à l'aide de la CLI :

  1. Installez tpu-info :

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Exécutez la vision par défaut de tpu-info :

    tpu-info
    

    Le résultat ressemble à ce qui suit :

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Utiliser des métriques pour vérifier l'utilisation des TPU

Les exemples suivants montrent comment utiliser les métriques de la bibliothèque de surveillance des TPU pour suivre l'utilisation des TPU.

Surveiller le cycle d'utilisation des TPU pendant l'entraînement JAX

Scénario : Vous exécutez un script d'entraînement JAX et vous souhaitez surveiller la métrique duty_cycle_pct du TPU tout au long du processus d'entraînement pour confirmer que vos TPU sont utilisés efficacement. Vous pouvez enregistrer cette métrique périodiquement pendant l'entraînement pour suivre l'utilisation des TPU.

L'exemple de code suivant montre comment surveiller le taux d'utilisation des TPU lors de l'entraînement JAX :

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Vérifier l'utilisation de la HBM avant d'exécuter l'inférence JAX

Scénario  : Avant d'exécuter l'inférence avec votre modèle JAX, vérifiez l'utilisation actuelle de la mémoire HBM (High Bandwidth Memory) sur la TPU pour vous assurer que vous disposez de suffisamment de mémoire et obtenir une mesure de référence avant le début de l'inférence.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Fréquence d'exportation des métriques TPU

La fréquence d'actualisation des métriques TPU est limitée à une seconde minimum. Les données de métriques hôtes sont exportées à une fréquence fixe de 1 Hz. La latence introduite par ce processus d'exportation est négligeable. Les métriques d'exécution de LibTPU ne sont pas soumises à la même contrainte de fréquence. Toutefois, par souci de cohérence, ces métriques sont également échantillonnées à 1 Hz, soit un échantillon par seconde.