TPU 監控程式庫

透過進階 TPU 監控功能,深入瞭解 Cloud TPU 硬體的效能和行為。這項功能直接建構在基礎軟體層 LibTPU 上,LibTPU 包含驅動程式、網路程式庫、XLA 編譯器和 TPU 執行階段,可用於與 TPU 互動,但本文重點是 TPU 監控程式庫。

TPU 監控程式庫提供下列功能:

  • 全面監控:存取遙測 API 和指標套件。您可以藉此深入瞭解 TPU 的運作效能和特定行為。

  • 診斷工具包:提供 SDK 和指令列介面 (CLI),可對 TPU 資源進行偵錯及深入分析效能。

這些監控功能是專為客戶設計的頂層解決方案,可提供必要工具,協助您有效最佳化 TPU 工作負載。

TPU 監控程式庫會提供詳細資訊,說明機器學習工作負載在 TPU 硬體上的執行情況。這項工具可協助您瞭解 TPU 使用率、找出瓶頸,以及偵錯效能問題。與中斷指標、有效輸送量指標和其他指標相比,這項指標可提供更詳細的資訊。

開始使用 TPU 監控程式庫

取得這些實用洞察資料的方式非常簡單。 TPU 監控功能已與 LibTPU SDK 整合,因此安裝 LibTPU 時會一併安裝這項功能。

安裝 LibTPU

pip install libtpu

此外,LibTPU 更新會與 JAX 版本協調,也就是說,安裝最新 JAX 版本 (每月發布) 時,通常會將您固定在最新相容的 LibTPU 版本及其功能。

安裝 JAX

pip install -U "jax[tpu]"

對於 PyTorch 使用者,安裝 PyTorch/XLA 可提供最新的 LibTPU 和 TPU 監控功能。

安裝 PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA GitHub 存放區中的「安裝」一節。

在 Python 中匯入程式庫

如要開始使用 TPU 監控程式庫,您需要在 Python 程式碼中匯入 libtpu 模組。

from libtpu.sdk import tpumonitoring

列出所有支援的功能

列出所有指標名稱和支援的功能:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

支援的指標

以下程式碼範例說明如何列出所有支援的指標名稱:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

下表列出所有指標及其對應定義:

指標 定義 API 的指標名稱 範例值
Tensor Core 使用率 測量 TensorCore 用量百分比,計算方式為 TensorCore 作業所占的作業百分比。每秒取樣一次,每次取樣 10 微秒。取樣率無法修改。 這個指標可讓您監控 TPU 裝置上工作負載的效率。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3
工作週期百分比 在過去取樣期間 (每 5 秒一次;可透過設定 LIBTPU_INIT_ARG 旗標調整) 內,加速器主動處理作業的時間百分比 (以過去取樣期間內執行 HLO 程式所用的週期數記錄)。這項指標代表 TPU 的忙碌程度,並會針對每個晶片發出。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3
HBM 容量總計 這項指標會以位元組為單位,回報 HBM 總容量。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total HBM capacity in bytes that attached to accelerator ID 0-3
HBM 容量用量 這項指標會回報過去樣本期間 (每 5 秒一次;可透過設定 LIBTPU_INIT_ARG 旗標調整) 的 HBM 容量用量 (以位元組為單位)。 hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3
緩衝區傳輸延遲 巨型多切片流量的網路傳輸延遲時間。 這項視覺化功能可協助您瞭解整體網路效能環境。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution
高階作業執行時間分配指標 提供 HLO 編譯二進位檔執行狀態的精細效能洞察資料,方便偵測迴歸和模型層級的偵錯。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999
高階最佳化工具佇列大小 HLO 執行佇列大小監控會追蹤等待或正在執行的已編譯 HLO 程式數量。這項指標會顯示執行管道壅塞情形,協助找出硬體執行、驅動程式負擔或資源分配方面的效能瓶頸。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
集體端對端延遲時間 這項指標會測量 DCN 的端對端集體延遲時間 (以微秒為單位),從主機啟動作業到所有對等互連接收輸出內容為止。包括減少主機端資料,以及將輸出內容傳送至 TPU。結果是字串,詳細說明緩衝區大小、類型,以及平均、p50、p90、p95 和 p99.9 延遲時間。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency

讀取指標資料 - 快照模式

如要啟用快照模式,請在呼叫 tpumonitoring.get_metric 函式時指定指標名稱。您可以在效能不佳的程式碼中插入臨時指標檢查,藉此判斷效能問題是源自軟體還是硬體。

下列程式碼範例說明如何使用快照模式讀取 duty_cycle

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

使用 CLI 存取指標

下列步驟說明如何使用 CLI 與 LibTPU 指標互動:

  1. 安裝 tpu-info

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. 執行 tpu-info 的預設視覺版本:

    tpu-info
    

    輸出結果會與下列內容相似:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

使用指標檢查 TPU 使用率

下列範例說明如何使用 TPU 監控程式庫中的指標,追蹤 TPU 使用率。

在 JAX 訓練期間監控 TPU 任務週期

情境:您正在執行 JAX 訓練指令碼,並想在整個訓練過程中監控 TPU 的 duty_cycle_pct 指標,確認 TPU 得到有效運用。您可以在訓練期間定期記錄這項指標,追蹤 TPU 使用率。

下列程式碼範例說明如何在 JAX 訓練期間監控 TPU 負載週期:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

執行 JAX 推論前,請先檢查 HBM 使用率

情境: 使用 JAX 模型執行推論前,請先檢查 TPU 的 HBM (高頻寬記憶體) 使用率,確認有足夠的可用記憶體,並在推論開始前取得基準測量結果。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

TPU 指標的匯出頻率

TPU 指標的重新整理頻率下限為一秒。主機指標資料會以 1 Hz 的固定頻率匯出,這個匯出程序造成的延遲可忽略不計。LibTPU 的執行階段指標不受相同頻率限制。不過,為確保一致性,這些指標也會以 1 Hz 的取樣率取樣,也就是每秒取樣一次。