TPU Monitoring ライブラリ

基盤ソフトウェア レイヤの LibTPU に直接構築された高度な TPU モニタリング機能を使用して、Cloud TPU ハードウェアのパフォーマンスや動作に関する詳細な分析情報を取得できます。LibTPU には、TPU とのやり取りに使用されるドライバ、ネットワーキング ライブラリ、XLA コンパイラ、TPU ランタイムが含まれますが、このドキュメントでは TPU Monitoring ライブラリに焦点を当てます。

TPU Monitoring ライブラリは次の機能を提供します。

  • 包括的なオブザーバビリティ: Telemetry API と指標スイートにアクセスできます。これにより、TPU の運用パフォーマンスや特定の動作に関する詳細な分析情報が得られます。

  • 診断ツールキット: TPU リソースのデバッグや詳細なパフォーマンス分析を可能にするために設計された SDK とコマンドライン インターフェース(CLI)を提供します。

これらのモニタリング機能は、トップレベルのお客様向けソリューションとして設計されており、TPU ワークロードを効果的に最適化するために不可欠なツールとなります。

TPU Monitoring ライブラリを使用すると、ML ワークロードが TPU ハードウェア上でどのように動作しているかを示す詳細な情報が得られます。これは、TPU 使用率の把握、ボトルネックの特定、パフォーマンスに関する問題のデバッグに役立つよう設計されています。得られる情報は、中断指標、グッドプット指標、その他の指標よりも詳細です。

TPU Monitoring ライブラリを使ってみる

これらの有益な分析情報にアクセスするのは簡単です。TPU モニタリング機能は LibTPU SDK と統合されているため、LibTPU をインストールするとこの機能もインストールされます。

LibTPU をインストールする

pip install libtpu

また、LibTPU のアップデートは JAX のリリースと連携しています。つまり、最新の JAX リリース(月 1 回リリースされる)をインストールすると、通常は互換性のある最新の LibTPU バージョンとその機能がご使用のシステムに紐づけされます。

JAX をインストールする

pip install -U "jax[tpu]"

PyTorch ユーザーの場合は、PyTorch/XLA をインストールすると、最新の LibTPU と TPU モニタリング機能が提供されます。

PyTorch/XLA をインストールする

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

PyTorch/XLA のインストールの詳細については、PyTorch/XLA GitHub リポジトリのインストールをご覧ください。

Python でライブラリをインポートする

TPU Monitoring ライブラリの使用を開始するには、Python コードに libtpu モジュールをインポートする必要があります。

from libtpu.sdk import tpumonitoring

サポートされているすべての機能を一覧表示する

すべての指標名と、それらの指標でサポートされている機能を一覧表示します。


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

サポートされている指標

次のコードサンプルは、サポートされているすべての指標名を一覧表示する方法を示しています。

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

次の表に、すべての指標とその定義を示します。

指標 定義 API の指標名 値の例
Tensor Core 使用率 TensorCore の使用率を測定します。これは、TensorCore 演算の一部である演算の割合として計算されます。1 秒ごとに 10 マイクロ秒のデータがサンプリングされます。サンプリング レートを変更することはできません。この指標により、TPU デバイスでのワークロードの効率をモニタリングできます。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# アクセラレータ ID 0~3 の使用率(%)
デューティ サイクルの割合 過去のサンプル期間(5 秒ごと。LIBTPU_INIT_ARG フラグによって調整可能)中にアクセラレータがアクティブに処理していた時間の割合(最後のサンプリング期間中に HLO プログラムの実行に使用されていたサイクルとともに記録される)。この指標は、TPU がどれだけビジーであるかを表します。この指標はチップごとに出力されます。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# アクセラレータ ID 0~3 のデューティ サイクルの割合(%)
HBM 容量の合計 この指標は、HBM の合計容量をバイト単位で報告します。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# アクセラレータ ID 0~3 にアタッチされている HBM 容量の合計(バイト単位)
HBM 容量の使用量 この指標は、過去のサンプル期間(5 秒ごと。LIBTPU_INIT_ARG フラグによって調整可能)中の HBM 容量の使用量をバイト単位で報告します。 hbm_capacity_usage ['100', '200', '300', '400']

# アクセラレータ ID 0~3 にアタッチされている HBM 容量の使用量(バイト単位)
バッファ転送レイテンシ 大規模なマルチスライス トラフィックのネットワーク転送レイテンシ。この可視化により、全体的なネットワーク パフォーマンス環境を把握できます。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# ネットワーク転送レイテンシ分布のバッファサイズ、平均、p50、p90、p99、p99.9
高レベル演算の実行時間分布指標 HLO コンパイル済みバイナリの実行ステータスに関する詳細なパフォーマンス分析情報を提供します。これにより、回帰検出とモデルレベルのデバッグが可能になります。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# CoreType-CoreID の HLO 実行時間の分布(平均、p50、p90、p95、p999)
High Level Optimizer のキューサイズ HLO 実行キューサイズのモニタリングは、待機中または実行中のコンパイル済み HLO プログラムの数を追跡します。この指標は、実行パイプラインの輻輳を明らかにし、ハードウェア実行、ドライバ オーバーヘッド、リソース割り当てにおけるパフォーマンス ボトルネックの特定に役立ちます。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# CoreType-CoreID のキューサイズを測定します。
エンドツーエンドのグループ レイテンシ この指標は、オペレーションを開始したホストから、出力を受信したすべてのピアへの DCN を介したエンドツーエンドのグループ レイテンシをマイクロ秒単位で測定します。これには、ホスト側のデータ削減と TPU への出力の送信が含まれます。結果は、バッファサイズ、タイプ、平均、p50、p90、p95、p99.9 のレイテンシを詳細に説明する文字列です。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# 転送サイズ - エンドツーエンドのグループ レイテンシのグループ オペレーション、平均、p50、p90、p95、p999

指標データを読み取る - スナップショット モード

スナップショット モードを有効にするには、tpumonitoring.get_metric 関数を呼び出すときに指標名を指定します。スナップショット モードでは、パフォーマンスの悪いコードにアドホック指標チェックを挿入することで、パフォーマンスの問題がソフトウェアとハードウェアのどちらに起因するのかを特定できます。

次のコードサンプルは、スナップショット モードを使用して duty_cycle を読み取る方法を示しています。

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

CLI を使用して指標にアクセスする

次の手順では、CLI を使用して LibTPU 指標を扱う方法を示します。

  1. tpu-info をインストールします。

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. tpu-info のデフォルトのビジョンを実行します。

    tpu-info
    

    出力は次のようになります。

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

指標を使用して TPU 使用率を確認する

次の例は、TPU Monitoring ライブラリの指標を使用して TPU 使用率を追跡する方法を示しています。

JAX トレーニング中の TPU デューティ サイクルをモニタリングする

シナリオ: JAX トレーニング スクリプトを実行し、トレーニング プロセスの間中 TPU の duty_cycle_pct 指標をモニタリングして TPU が効果的に使用されていることを確認します。トレーニング中にこの指標を定期的にログに出力することで、TPU 使用率を追跡できます。

次のコードサンプルは、JAX トレーニング中に TPU のデューティ サイクルをモニタリングする方法を示しています。

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

JAX 推論を実行する前に HBM 使用率を確認する

シナリオ: JAX モデルで推論を実行する前に、TPU の現在の HBM(高帯域幅メモリ)使用率をチェックして十分な量のメモリが使用可能であることを確認します。また、推論の開始前にベースライン測定値を取得します。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

TPU 指標のエクスポート頻度

TPU 指標の更新頻度は、最小 1 秒に 1 回に制限されています。ホスト指標データは 1 Hz の固定頻度でエクスポートされます。このエクスポート プロセスによって発生するレイテンシはごくわずかです。LibTPU のランタイム指標には、同じ頻度の制約は適用されません。ただし、一貫性を保つため、これらの指標も 1 Hz(1 秒あたり 1 サンプル)でサンプリングされます。