TPU 監控程式庫
透過進階 TPU 監控功能,深入瞭解 Cloud TPU 硬體的效能和行為。這項功能直接建構在基礎軟體層 LibTPU 上,LibTPU 包含驅動程式、網路程式庫、XLA 編譯器和 TPU 執行階段,可用於與 TPU 互動,但本文重點是 TPU 監控程式庫。
TPU 監控程式庫提供下列功能:
全面監控:存取遙測 API 和指標套件。您可以藉此深入瞭解 TPU 的運作效能和特定行為。
診斷工具包:提供 SDK 和指令列介面 (CLI),可對 TPU 資源進行偵錯及深入分析效能。
這些監控功能是專為客戶設計的頂層解決方案,可提供必要工具,協助您有效最佳化 TPU 工作負載。
TPU 監控程式庫會提供詳細資訊,說明機器學習工作負載在 TPU 硬體上的執行情況。這項工具可協助您瞭解 TPU 使用率、找出瓶頸,以及偵錯效能問題。與中斷指標、有效輸送量指標和其他指標相比,這項指標可提供更詳細的資訊。
開始使用 TPU 監控程式庫
取得這些實用洞察資料的方式非常簡單。 TPU 監控功能已與 LibTPU SDK 整合,因此安裝 LibTPU 時會一併安裝這項功能。
安裝 LibTPU
pip install libtpu
此外,LibTPU 更新會與 JAX 版本協調,也就是說,安裝最新 JAX 版本 (每月發布) 時,通常會將您固定在最新相容的 LibTPU 版本及其功能。
安裝 JAX
pip install -U "jax[tpu]"
對於 PyTorch 使用者,安裝 PyTorch/XLA 可提供最新的 LibTPU 和 TPU 監控功能。
安裝 PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA GitHub 存放區中的「安裝」一節。
在 Python 中匯入程式庫
如要開始使用 TPU 監控程式庫,您需要在 Python 程式碼中匯入 libtpu
模組。
from libtpu.sdk import tpumonitoring
列出所有支援的功能
列出所有指標名稱和支援的功能:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
支援的指標
以下程式碼範例說明如何列出所有支援的指標名稱:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
下表列出所有指標及其對應定義:
指標 | 定義 | API 的指標名稱 | 範例值 |
---|---|---|---|
Tensor Core 使用率 | 測量 TensorCore 用量百分比,計算方式為 TensorCore 作業所占的作業百分比。每秒取樣一次,每次取樣 10 微秒。取樣率無法修改。 這個指標可讓您監控 TPU 裝置上工作負載的效率。 |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3 |
工作週期百分比 | 在過去取樣期間 (每 5 秒一次;可透過設定 LIBTPU_INIT_ARG 旗標調整) 內,加速器主動處理作業的時間百分比 (以過去取樣期間內執行 HLO 程式所用的週期數記錄)。這項指標代表 TPU 的忙碌程度,並會針對每個晶片發出。 |
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Duty cycle percentage for accelerator ID 0-3 |
HBM 容量總計 | 這項指標會以位元組為單位,回報 HBM 總容量。 |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Total HBM capacity in bytes that attached to accelerator ID 0-3 |
HBM 容量用量 | 這項指標會回報過去樣本期間 (每 5 秒一次;可透過設定 LIBTPU_INIT_ARG 旗標調整) 的 HBM 容量用量 (以位元組為單位)。 |
hbm_capacity_usage
|
['100', '200', '300', '400']
# Capacity usage for HBM in bytes that attached to accelerator ID 0-3 |
緩衝區傳輸延遲 | 巨型多切片流量的網路傳輸延遲時間。 這項視覺化功能可協助您瞭解整體網路效能環境。 |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution |
高階作業執行時間分配指標 | 提供 HLO 編譯二進位檔執行狀態的精細效能洞察資料,方便偵測迴歸和模型層級的偵錯。 |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999 |
高階最佳化工具佇列大小 | HLO 執行佇列大小監控會追蹤等待或正在執行的已編譯 HLO 程式數量。這項指標會顯示執行管道壅塞情形,協助找出硬體執行、驅動程式負擔或資源分配方面的效能瓶頸。 |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
集體端對端延遲時間 | 這項指標會測量 DCN 的端對端集體延遲時間 (以微秒為單位),從主機啟動作業到所有對等互連接收輸出內容為止。包括減少主機端資料,以及將輸出內容傳送至 TPU。結果是字串,詳細說明緩衝區大小、類型,以及平均、p50、p90、p95 和 p99.9 延遲時間。 |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency |
讀取指標資料 - 快照模式
如要啟用快照模式,請在呼叫 tpumonitoring.get_metric
函式時指定指標名稱。您可以在效能不佳的程式碼中插入臨時指標檢查,藉此判斷效能問題是源自軟體還是硬體。
下列程式碼範例說明如何使用快照模式讀取 duty_cycle
。
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
使用 CLI 存取指標
下列步驟說明如何使用 CLI 與 LibTPU 指標互動:
安裝
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
執行
tpu-info
的預設視覺版本:tpu-info
輸出結果會與下列內容相似:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
使用指標檢查 TPU 使用率
下列範例說明如何使用 TPU 監控程式庫中的指標,追蹤 TPU 使用率。
在 JAX 訓練期間監控 TPU 任務週期
情境:您正在執行 JAX 訓練指令碼,並想在整個訓練過程中監控 TPU 的 duty_cycle_pct
指標,確認 TPU 得到有效運用。您可以在訓練期間定期記錄這項指標,追蹤 TPU 使用率。
下列程式碼範例說明如何在 JAX 訓練期間監控 TPU 負載週期:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
執行 JAX 推論前,請先檢查 HBM 使用率
情境: 使用 JAX 模型執行推論前,請先檢查 TPU 的 HBM (高頻寬記憶體) 使用率,確認有足夠的可用記憶體,並在推論開始前取得基準測量結果。
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
TPU 指標的匯出頻率
TPU 指標的重新整理頻率下限為一秒。主機指標資料會以 1 Hz 的固定頻率匯出,這個匯出程序造成的延遲可忽略不計。LibTPU 的執行階段指標不受相同頻率限制。不過,為確保一致性,這些指標也會以 1 Hz 的取樣率取樣,也就是每秒取樣一次。