Biblioteca de monitorização de TPUs

Aceda a estatísticas detalhadas sobre o desempenho e o comportamento do hardware da Cloud TPU com capacidades de monitorização avançadas da TPU, criadas diretamente na camada de software fundamental, a LibTPU. Embora a LibTPU abranja controladores, bibliotecas de rede, o compilador XLA e o tempo de execução do TPU para interagir com os TPUs, o foco deste documento é a biblioteca de monitorização do TPU.

A biblioteca de monitorização de TPU oferece:

  • Observabilidade abrangente: aceda à API de telemetria e ao conjunto de métricas. Isto permite-lhe obter estatísticas detalhadas sobre o desempenho operacional e os comportamentos específicos das suas UTPs.

  • Kits de ferramentas de diagnóstico: fornece um SDK e uma interface de linhas de comando (CLI) concebidos para permitir a depuração e a análise detalhada do desempenho dos seus recursos de TPU.

Estas funcionalidades de monitorização foram concebidas para serem uma solução de nível superior orientada para o cliente, que lhe oferece as ferramentas essenciais para otimizar os seus trabalhos de processamento de TPUs de forma eficaz.

A biblioteca de monitorização da TPU fornece informações detalhadas sobre o desempenho das cargas de trabalho de aprendizagem automática no hardware da TPU. Foi concebida para ajudar a compreender a utilização da TPU, identificar gargalos e depurar problemas de desempenho. Dá-lhe informações mais detalhadas do que as métricas de interrupção, as métricas de débito útil e outras métricas.

Comece a usar a biblioteca de monitorização de TPUs

Aceder a estas estatísticas avançadas é simples. A funcionalidade de monitorização da TPU está integrada no SDK LibTPU, pelo que a funcionalidade está incluída quando instala o LibTPU.

Instale a LibTPU

pip install libtpu

Em alternativa, as atualizações da LibTPU são coordenadas com os lançamentos do JAX, o que significa que, quando instala o lançamento mais recente do JAX (lançado mensalmente), este fixa normalmente a versão mais recente compatível da LibTPU e as respetivas funcionalidades.

Instale o JAX

pip install -U "jax[tpu]"

Para os utilizadores do PyTorch, a instalação do PyTorch/XLA fornece a funcionalidade de monitorização do LibTPU e do TPU mais recente.

Instale o PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Para mais informações sobre a instalação do PyTorch/XLA, consulte a secção Instalação no repositório do GitHub do PyTorch/XLA.

Importe a biblioteca em Python

Para começar a usar a biblioteca de monitorização de TPUs, tem de importar o módulo libtpu no seu código Python.

from libtpu.sdk import tpumonitoring

Apresente todas as funcionalidades suportadas

Indique todos os nomes das métricas e a funcionalidade que suportam:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métricas compatíveis

O seguinte exemplo de código mostra como listar todos os nomes de métricas suportados:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

A tabela seguinte mostra todas as métricas e as respetivas definições:

Métrica Definição Nome da métrica para a API Valores de exemplo
Utilização do núcleo Tensor Mede a percentagem da utilização dos TensorCores, calculada como a percentagem de operações que fazem parte das operações dos TensorCores. Amostragem de 10 microssegundos a cada 1 segundo. Não pode modificar a taxa de amostragem. Esta métrica permite-lhe monitorizar a eficiência das suas cargas de trabalho em dispositivos TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3
Percentagem do ciclo de atividade Percentagem do tempo durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG) durante o qual o acelerador estava a processar ativamente (registado com os ciclos usados para executar programas HLO durante o último período de amostragem). Esta métrica representa o nível de ocupação de uma TPU. A métrica é emitida por chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Percentagem do ciclo de serviço para o ID do acelerador 0-3
HBM Capacity Total Esta métrica indica a capacidade total de HBM em bytes. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacidade total de HBM em bytes associada ao ID do acelerador 0-3
Utilização da capacidade de HBM Esta métrica comunica a utilização da capacidade de HBM em bytes durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3
Latência de transferência do buffer Latências de transferência de rede para tráfego multissegmentado de grande escala. Esta visualização permite-lhe compreender o ambiente de desempenho geral da rede. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution
Métricas de distribuição do tempo de execução da operação de nível elevado Fornece estatísticas de desempenho detalhadas sobre o estado de execução do binário compilado HLO, o que permite a deteção de regressões e a depuração ao nível do modelo. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# A distribuição da duração do tempo de execução do HLO para CoreType-CoreID com média, p50, p90, p95 e p999
Tamanho da fila do otimizador de nível elevado A monitorização do tamanho da fila de execução do HLO acompanha o número de programas HLO compilados que estão à espera ou em execução. Esta métrica revela o congestionamento do pipeline de execução, o que permite identificar restrições de desempenho na execução de hardware, sobrecarga do controlador ou atribuição de recursos. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
Latência completa coletiva Esta métrica mede a latência coletiva ponto a ponto na DCN em microssegundos, desde o anfitrião que inicia a operação até todos os pares receberem o resultado. Inclui a redução de dados no anfitrião e o envio de resultados para a TPU. Os resultados são strings que detalham o tamanho da memória intermédia, o tipo e as latências média, p50, p90, p95 e p99,9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency

Ler dados de métricas – modo de resumo

Para ativar o modo de instantâneo, especifique o nome da métrica quando chamar a função tpumonitoring.get_metric. O modo Snapshot permite-lhe inserir verificações de métricas ad hoc em código de baixo desempenho para identificar se os problemas de desempenho têm origem no software ou no hardware.

O exemplo de código seguinte mostra como usar o modo de instantâneo para ler o duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Aceda às métricas através da CLI

Os passos seguintes mostram como interagir com as métricas da LibTPU através da CLI:

  1. Instale tpu-info:

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Executar a visão predefinida de tpu-info:

    tpu-info
    

    O resultado é semelhante ao seguinte:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Use métricas para verificar a utilização da TPU

Os exemplos seguintes mostram como usar métricas da biblioteca de monitorização de TPUs para acompanhar a utilização de TPUs.

Monitorize o ciclo de serviço da TPU durante o treino JAX

Cenário: está a executar um script de preparação do JAX e quer monitorizar a métrica duty_cycle_pct da TPU ao longo do processo de preparação para confirmar que as TPUs estão a ser usadas de forma eficaz. Pode registar esta métrica periodicamente durante a preparação para monitorizar a utilização da TPU.

O seguinte exemplo de código mostra como monitorizar o ciclo de serviço da TPU durante a preparação do JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Verifique a utilização da HBM antes de executar a inferência JAX

Cenário: antes de executar a inferência com o seu modelo JAX, verifique a utilização atual da HBM (memória de largura de banda elevada) na TPU para confirmar que tem memória suficiente disponível e obter uma medição de base antes de iniciar a inferência.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Frequência de exportação das métricas de TPU

A frequência de atualização das métricas da TPU está limitada a um mínimo de um segundo. Os dados das métricas do anfitrião são exportados a uma frequência fixa de 1 Hz. A latência introduzida por este processo de exportação é insignificante. As métricas de tempo de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para manter a consistência, estas métricas também são amostradas a 1 Hz ou 1 amostra por segundo.