Biblioteca de monitorização de TPUs
Aceda a estatísticas detalhadas sobre o desempenho e o comportamento do hardware da Cloud TPU com capacidades de monitorização avançadas da TPU, criadas diretamente na camada de software fundamental, a LibTPU. Embora a LibTPU abranja controladores, bibliotecas de rede, o compilador XLA e o tempo de execução do TPU para interagir com os TPUs, o foco deste documento é a biblioteca de monitorização do TPU.
A biblioteca de monitorização de TPU oferece:
Observabilidade abrangente: aceda à API de telemetria e ao conjunto de métricas. Isto permite-lhe obter estatísticas detalhadas sobre o desempenho operacional e os comportamentos específicos das suas UTPs.
Kits de ferramentas de diagnóstico: fornece um SDK e uma interface de linhas de comando (CLI) concebidos para permitir a depuração e a análise detalhada do desempenho dos seus recursos de TPU.
Estas funcionalidades de monitorização foram concebidas para serem uma solução de nível superior orientada para o cliente, que lhe oferece as ferramentas essenciais para otimizar os seus trabalhos de processamento de TPUs de forma eficaz.
A biblioteca de monitorização da TPU fornece informações detalhadas sobre o desempenho das cargas de trabalho de aprendizagem automática no hardware da TPU. Foi concebida para ajudar a compreender a utilização da TPU, identificar gargalos e depurar problemas de desempenho. Dá-lhe informações mais detalhadas do que as métricas de interrupção, as métricas de débito útil e outras métricas.
Comece a usar a biblioteca de monitorização de TPUs
Aceder a estas estatísticas avançadas é simples. A funcionalidade de monitorização da TPU está integrada no SDK LibTPU, pelo que a funcionalidade está incluída quando instala o LibTPU.
Instale a LibTPU
pip install libtpu
Em alternativa, as atualizações da LibTPU são coordenadas com os lançamentos do JAX, o que significa que, quando instala o lançamento mais recente do JAX (lançado mensalmente), este fixa normalmente a versão mais recente compatível da LibTPU e as respetivas funcionalidades.
Instale o JAX
pip install -U "jax[tpu]"
Para os utilizadores do PyTorch, a instalação do PyTorch/XLA fornece a funcionalidade de monitorização do LibTPU e do TPU mais recente.
Instale o PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Para mais informações sobre a instalação do PyTorch/XLA, consulte a secção Instalação no repositório do GitHub do PyTorch/XLA.
Importe a biblioteca em Python
Para começar a usar a biblioteca de monitorização de TPUs, tem de importar o módulo libtpu
no seu código Python.
from libtpu.sdk import tpumonitoring
Apresente todas as funcionalidades suportadas
Indique todos os nomes das métricas e a funcionalidade que suportam:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas compatíveis
O seguinte exemplo de código mostra como listar todos os nomes de métricas suportados:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
A tabela seguinte mostra todas as métricas e as respetivas definições:
Métrica | Definição | Nome da métrica para a API | Valores de exemplo |
---|---|---|---|
Utilização do núcleo Tensor | Mede a percentagem da utilização dos TensorCores, calculada como a percentagem de operações que fazem parte das operações dos TensorCores. Amostragem de 10 microssegundos a cada 1 segundo. Não pode modificar a taxa de amostragem. Esta métrica permite-lhe monitorizar a eficiência das suas cargas de trabalho em dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3 |
Percentagem do ciclo de atividade | Percentagem do tempo durante o período de amostragem anterior (a cada 5 segundos; pode
ser ajustada através da definição da flag LIBTPU_INIT_ARG ) durante o qual
o acelerador estava a processar ativamente (registado com os ciclos usados para
executar programas HLO durante o último período de amostragem). Esta métrica
representa o nível de ocupação de uma TPU. A métrica é emitida por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Percentagem do ciclo de serviço para o ID do acelerador 0-3 |
HBM Capacity Total | Esta métrica indica a capacidade total de HBM em bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidade total de HBM em bytes associada ao ID do acelerador 0-3 |
Utilização da capacidade de HBM | Esta métrica comunica a utilização da capacidade de HBM em bytes durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG ).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Capacity usage for HBM in bytes that attached to accelerator ID 0-3 |
Latência de transferência do buffer | Latências de transferência de rede para tráfego multissegmentado de grande escala. Esta visualização permite-lhe compreender o ambiente de desempenho geral da rede. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution |
Métricas de distribuição do tempo de execução da operação de nível elevado | Fornece estatísticas de desempenho detalhadas sobre o estado de execução do binário compilado HLO, o que permite a deteção de regressões e a depuração ao nível do modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# A distribuição da duração do tempo de execução do HLO para CoreType-CoreID com média, p50, p90, p95 e p999 |
Tamanho da fila do otimizador de nível elevado | A monitorização do tamanho da fila de execução do HLO acompanha o número de programas HLO compilados que estão à espera ou em execução. Esta métrica revela o congestionamento do pipeline de execução, o que permite identificar restrições de desempenho na execução de hardware, sobrecarga do controlador ou atribuição de recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
Latência completa coletiva | Esta métrica mede a latência coletiva ponto a ponto na DCN em microssegundos, desde o anfitrião que inicia a operação até todos os pares receberem o resultado. Inclui a redução de dados no anfitrião e o envio de resultados para a TPU. Os resultados são strings que detalham o tamanho da memória intermédia, o tipo e as latências média, p50, p90, p95 e p99,9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency |
Ler dados de métricas – modo de resumo
Para ativar o modo de instantâneo, especifique o nome da métrica quando chamar a função tpumonitoring.get_metric
. O modo Snapshot
permite-lhe inserir verificações de métricas ad hoc em código de baixo desempenho para
identificar se os problemas de desempenho têm origem no software ou no hardware.
O exemplo de código seguinte mostra como usar o modo de instantâneo para ler o duty_cycle
.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Aceda às métricas através da CLI
Os passos seguintes mostram como interagir com as métricas da LibTPU através da CLI:
Instale
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
Executar a visão predefinida de
tpu-info
:tpu-info
O resultado é semelhante ao seguinte:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
Use métricas para verificar a utilização da TPU
Os exemplos seguintes mostram como usar métricas da biblioteca de monitorização de TPUs para acompanhar a utilização de TPUs.
Monitorize o ciclo de serviço da TPU durante o treino JAX
Cenário: está a executar um script de preparação do JAX e quer monitorizar a métrica duty_cycle_pct
da TPU ao longo do processo de preparação para confirmar que as TPUs estão a ser usadas de forma eficaz. Pode registar esta métrica periodicamente durante a preparação para monitorizar a utilização da TPU.
O seguinte exemplo de código mostra como monitorizar o ciclo de serviço da TPU durante a preparação do JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Verifique a utilização da HBM antes de executar a inferência JAX
Cenário: antes de executar a inferência com o seu modelo JAX, verifique a utilização atual da HBM (memória de largura de banda elevada) na TPU para confirmar que tem memória suficiente disponível e obter uma medição de base antes de iniciar a inferência.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Frequência de exportação das métricas de TPU
A frequência de atualização das métricas da TPU está limitada a um mínimo de um segundo. Os dados das métricas do anfitrião são exportados a uma frequência fixa de 1 Hz. A latência introduzida por este processo de exportação é insignificante. As métricas de tempo de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para manter a consistência, estas métricas também são amostradas a 1 Hz ou 1 amostra por segundo.