TPU 모니터링 라이브러리

LibTPU라는 기본 소프트웨어 계층 위에 직접 구축된 고급 TPU 모니터링 기능을 통해 Cloud TPU 하드웨어의 성능과 동작을 심층적으로 분석할 수 있습니다. LibTPU는 TPU와 상호작용하기 위한 드라이버, 네트워킹 라이브러리, XLA 컴파일러, TPU 런타임 등을 포함하고 있지만, 이 문서에서는 그중에서도 TPU 모니터링 라이브러리에 초점을 맞춰 설명합니다.

TPU 모니터링 라이브러리는 다음을 제공합니다.

  • 포괄적인 모니터링 가능성: Telemetry API 및 측정항목 모음에 액세스하여 TPU의 운영 성능과 특정 동작에 대한 상세한 정보를 얻을 수 있습니다.

  • 진단 툴킷: TPU 리소스를 디버깅하고 심층적인 성능 분석을 수행할 수 있도록 설계된 SDK 및 명령줄 인터페이스(CLI)를 제공합니다.

이러한 모니터링 기능은 고객을 대상으로 하는 최상위 솔루션으로 설계되어, TPU 워크로드를 효과적으로 최적화하는 데 필요한 핵심 도구를 제공합니다.

TPU 모니터링 라이브러리는 TPU 하드웨어에서 실행되는 머신러닝 워크로드 성능에 대한 상세한 정보를 제공합니다. 이 API는 TPU 사용률을 파악하고, 병목 현상을 식별하며, 성능 문제를 디버깅할 수 있도록 설계되었습니다. 제공되는 정보는 중단(interruption) 측정항목, 유효 처리량(goodput) 측정항목, 기타 측정항목보다 더 상세합니다.

TPU 모니터링 라이브러리 시작하기

이러한 고급 정보는 간단히 확인할 수 있습니다. TPU 모니터링 기능은 LibTPU SDK에 통합되어 있어, LibTPU를 설치하면 함께 제공됩니다.

LibTPU 설치

pip install libtpu

또는 LibTPU 업데이트는 JAX 버전과 연동되어 관리되므로, 매달 릴리스되는 최신 JAX를 설치하면 일반적으로 최신 호환 LibTPU 버전과 그 기능이 함께 설치됩니다.

JAX 설치

pip install -U "jax[tpu]"

PyTorch 사용자의 경우 PyTorch/XLA를 설치하면 최신 LibTPU 및 TPU 모니터링 기능이 제공됩니다.

PyTorch/XLA 설치

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

PyTorch/XLA 설치에 대한 자세한 내용은 PyTorch/XLA GitHub 저장소에서 설치를 참조하세요.

Python에서 라이브러리 가져오기

TPU 모니터링 라이브러리를 사용하려면 Python 코드에서 libtpu 모듈을 가져와야 합니다.

from libtpu.sdk import tpumonitoring

지원되는 모든 기능 나열

모든 측정항목 이름과 지원되는 기능을 나열합니다.


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

지원되는 측정항목

다음 코드 샘플은 지원되는 모든 측정항목 이름을 나열하는 방법을 보여줍니다.

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

다음 표에서는 모든 측정항목과 해당 정의를 보여줍니다.

측정항목 정의 API의 측정항목 이름 예시 값
TensorCore 사용률 TensorCore 연산 중 일부로 수행된 연산 비율을 기준으로 TensorCore 사용률의 백분율을 측정합니다. 1초마다 10마이크로초 동안 샘플링됩니다. 샘플링 레이트는 수정할 수 없습니다. 이 측정항목을 통해 TPU 기기에서 워크로드의 효율성을 모니터링할 수 있습니다. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# 가속기 ID 0~3에 대한 사용률 백분율
듀티 사이클 백분율 가속기가 활성 상태로 작업을 수행한 시간의 비율을 나타냅니다. 이 값은 최근 샘플링 주기(기본 5초, LIBTPU_INIT_ARG 플래그를 설정하여 조정 가능) 동안 HLO 프로그램을 실행하는 데 사용된 사이클 수를 기준으로 측정됩니다. 이 측정항목은 TPU가 얼마나 바쁘게 작동하고 있는지를 나타내며, 칩 단위로 보고됩니다. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# 가속기 ID 0~3에 대한 듀티 사이클 백분율
HBM 용량 합계 이 측정항목은 HBM 용량 합계(바이트)를 보고합니다. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# 가속기 ID 0~3에 대한 HBM 용량 합계(바이트)
HBM 용량 사용량 이 측정항목은 최근 샘플링 주기(기본값은 5초, LIBTPU_INIT_ARG 플래그를 설정하여 조정 가능) 동안 HBM 용량 사용량(바이트 단위)을 보고합니다. hbm_capacity_usage ['100', '200', '300', '400']

# 가속기 ID 0~3에 대한 HBM 용량 사용량(바이트)
버퍼 전송 지연 시간 메가스케일 멀티 슬라이스 트래픽에 대한 네트워크 전송 지연 시간입니다. 이 시각화를 통해 전체적인 네트워크 성능 환경을 파악할 수 있습니다. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# 네트워크 전송 지연 시간 분포에 대한 버퍼 크기, 평균, p50, p90, p99, p99.9
고수준 연산 실행 시간 분포 측정항목 HLO로 컴파일된 바이너리의 실행 상태에 대한 세부적인 성능 통계를 제공하여, 회귀 감지 및 모델 수준 디버깅을 지원합니다. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# CoreType-CoreID에 대한 HLO 실행 시간 분포(평균, p50, p90, p95, p999)
고수준 옵티마이저 큐 크기 HLO 실행 큐 크기 모니터링은 컴파일된 HLO 프로그램 중 실행을 기다리거나 실행 중인 프로그램의 개수를 추적합니다. 이 측정항목은 실행 파이프라인의 혼잡도를 보여주며, 하드웨어 실행, 드라이버 오버헤드, 리소스 할당과 관련된 성능 병목 현상을 식별하는 데 도움이 됩니다. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# CoreType-CoreID의 큐 크기를 측정합니다.
집합적 엔드 투 엔드 지연 시간 이 측정항목은 작업을 시작하는 호스트에서 출력을 수신하는 모든 피어까지 DCN을 통한 엔드 투 엔드 집합적 지연 시간을 마이크로초 단위로 측정합니다. 여기에는 호스트 측 데이터 감소와 TPU로의 출력 전송이 포함됩니다. 결과는 버퍼 크기, 유형, 평균, p50, p90, p95, p99.9 지연 시간을 자세히 설명하는 문자열입니다. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# 전송 크기-집합 작업, 평균, p50, p90, p95, p999의 집합적 엔드 투 엔드 지연 시간

측정항목 데이터 읽기 - 스냅샷 모드

스냅샷 모드를 사용 설정하려면 tpumonitoring.get_metric 함수를 호출할 때 측정항목 이름을 지정해야 합니다. 스냅샷 모드를 사용하면 저성능 코드 영역에 임시로 측정항목 검사를 삽입하여, 성능 문제가 소프트웨어 또는 하드웨어로부터 비롯되었는지를 식별할 수 있습니다.

다음 코드 샘플은 스냅샷 모드를 사용하여 duty_cycle을 읽는 방법을 보여줍니다.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

CLI를 사용하여 측정항목 액세스

다음 단계에서는 CLI를 사용하여 LibTPU 측정항목과 상호 작용하는 방법을 보여줍니다.

  1. tpu-info을 설치합니다.

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. tpu-info의 기본 비전을 실행합니다.

    tpu-info
    

    출력은 다음과 비슷합니다.

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

측정항목을 사용하여 TPU 사용률 확인

다음 예시에서는 TPU 모니터링 라이브러리의 측정항목을 사용하여 TPU 사용률을 추적하는 방법을 보여줍니다.

JAX 학습 중 TPU 듀티 사이클 모니터링

시나리오: JAX 학습 스크립트를 실행 중이며, 학습 과정 전반에 걸쳐 TPU의 duty_cycle_pct 측정항목을 모니터링하여 TPU가 효과적으로 활용되고 있는지 확인하려고 합니다. 학습 중 이 측정항목을 주기적으로 로깅하면 TPU 사용률을 추적할 수 있습니다.

다음 코드 샘플은 JAX 학습 중 TPU 듀티 사이클을 모니터링하는 방법을 보여줍니다.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

JAX 추론 실행 전 HBM 사용률 확인

시나리오: JAX 모델로 추론을 실행하기 전에, TPU의 현재 HBM(고대역폭 메모리) 사용률을 확인하여 사용 가능한 메모리가 충분한지 점검하고, 추론 시작 전 기준선 수치를 확보합니다.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

TPU 측정항목의 내보내기 빈도

TPU 측정항목의 새로고침 빈도는 최소 1초로 제한됩니다. 호스트 측정항목 데이터는 1Hz의 고정 주기로 내보내지며, 이 과정에서 발생하는 지연 시간은 무시할 수 있을 정도로 작습니다. LibTPU에서 제공되는 런타임 측정항목은 이러한 주기 제약은 없지만, 일관성을 위해 해당 측정항목도 역시 동일하게 1Hz, 즉 초당 1회로 샘플링됩니다.