PyTorch XLA 워크로드 프로파일링
성능 최적화는 효율적인 머신러닝 모델을 빌드하는 데 중요한 부분입니다. XProf 프로파일링 도구를 사용하여 머신러닝 워크로드의 성능을 측정할 수 있습니다. XProf를 사용하면 XLA 기기에서 모델 실행의 자세한 trace를 캡처할 수 있습니다. 이러한 trace를 사용하면 성능 병목 현상을 식별하고, 기기 활용도를 파악하고, 코드를 최적화할 수 있습니다.
이 가이드에서는 PyTorch XLA 스크립트에서 trace를 프로그래매틱 방식으로 캡처하고 XProf 및 Tensorboard를 사용하여 시각화하는 프로세스를 설명합니다.
trace 캡처
기존 학습 스크립트에 몇 줄의 코드를 추가하여 trace를 캡처할 수 있습니다. trace를 캡처하는 기본 도구는 torch_xla.debug.profiler
모듈이며, 이 모듈은 일반적으로 xp
별칭으로 가져옵니다.
1. 프로파일러 서버 시작
trace를 캡처하려면 먼저 프로파일러 서버를 시작해야 합니다. 이 서버는 스크립트의 백그라운드에서 실행되며 trace 데이터를 수집합니다. 기본 실행 블록의 시작 부분 근처에서 xp.start_server()
를 호출하여 시작할 수 있습니다.
2. trace 기간 정의
프로파일링할 코드를 xp.start_trace()
및 xp.stop_trace()
호출 내에 래핑합니다. start_trace
함수는 trace 파일이 저장되는 디렉터리의 경로를 사용합니다.
가장 관련성 높은 작업을 포착하기 위해 기본 학습 루프를 래핑하는 것이 일반적입니다.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. 맞춤 trace 라벨 추가
기본적으로 캡처되는 trace는 하위 수준의 PyTorch XLA 함수이며 탐색하기 어려울 수 있습니다. xp.Trace()
컨텍스트 관리자를 사용하여 코드의 특정 섹션에 맞춤 라벨을 추가할 수 있습니다. 이러한 라벨은 프로파일러의 타임라인 뷰에 이름이 지정된 블록으로 표시되므로 데이터 준비, 순방향 패스 또는 옵티마이저 단계와 같은 특정 작업을 훨씬 쉽게 식별할 수 있습니다.
다음 예시에서는 학습 단계의 여러 부분에 컨텍스트를 추가하는 방법을 보여줍니다.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
전체 예시
다음 예시에서는 mnist_xla.py 파일을 기반으로 PyTorch XLA 스크립트에서 trace를 캡처하는 방법을 보여줍니다.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
trace 시각화
스크립트가 완료되면 trace 파일이 지정한 디렉터리(예: /root/logs/
)에 저장됩니다. XProf 및 TensorBoard를 사용하여 이 trace를 시각화할 수 있습니다.
TensorBoard를 설치합니다.
pip install tensorboard_plugin_profile tensorboard
TensorBoard를 실행합니다. TensorBoard가
xp.start_trace()
에서 사용한 로그 디렉터리를 가리키도록 합니다.tensorboard --logdir /root/logs/
프로필을 확인합니다. 웹브라우저에서 TensorBoard가 제공한 URL(일반적으로
http://localhost:6006
)을 엽니다. PyTorch XLA - Profile 탭으로 이동하여 대화형 trace를 확인합니다. 만든 맞춤 라벨을 확인하고 모델의 여러 부분의 실행 시간을 분석할 수 있습니다.
Google Cloud 를 사용하여 워크로드를 실행하는 경우 cloud-diagnostics-xprof 도구를 사용하는 것이 좋습니다. Tensorboard 및 XProf를 실행하는 VM을 사용하여 간소화된 프로필 수집 및 보기 환경을 제공합니다.