PyTorch XLA 워크로드 프로파일링

성능 최적화는 효율적인 머신러닝 모델을 빌드하는 데 중요한 부분입니다. XProf 프로파일링 도구를 사용하여 머신러닝 워크로드의 성능을 측정할 수 있습니다. XProf를 사용하면 XLA 기기에서 모델 실행의 자세한 trace를 캡처할 수 있습니다. 이러한 trace를 사용하면 성능 병목 현상을 식별하고, 기기 활용도를 파악하고, 코드를 최적화할 수 있습니다.

이 가이드에서는 PyTorch XLA 스크립트에서 trace를 프로그래매틱 방식으로 캡처하고 XProf 및 Tensorboard를 사용하여 시각화하는 프로세스를 설명합니다.

trace 캡처

기존 학습 스크립트에 몇 줄의 코드를 추가하여 trace를 캡처할 수 있습니다. trace를 캡처하는 기본 도구는 torch_xla.debug.profiler 모듈이며, 이 모듈은 일반적으로 xp 별칭으로 가져옵니다.

1. 프로파일러 서버 시작

trace를 캡처하려면 먼저 프로파일러 서버를 시작해야 합니다. 이 서버는 스크립트의 백그라운드에서 실행되며 trace 데이터를 수집합니다. 기본 실행 블록의 시작 부분 근처에서 xp.start_server()를 호출하여 시작할 수 있습니다.

2. trace 기간 정의

프로파일링할 코드를 xp.start_trace()xp.stop_trace() 호출 내에 래핑합니다. start_trace 함수는 trace 파일이 저장되는 디렉터리의 경로를 사용합니다.

가장 관련성 높은 작업을 포착하기 위해 기본 학습 루프를 래핑하는 것이 일반적입니다.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. 맞춤 trace 라벨 추가

기본적으로 캡처되는 trace는 하위 수준의 PyTorch XLA 함수이며 탐색하기 어려울 수 있습니다. xp.Trace() 컨텍스트 관리자를 사용하여 코드의 특정 섹션에 맞춤 라벨을 추가할 수 있습니다. 이러한 라벨은 프로파일러의 타임라인 뷰에 이름이 지정된 블록으로 표시되므로 데이터 준비, 순방향 패스 또는 옵티마이저 단계와 같은 특정 작업을 훨씬 쉽게 식별할 수 있습니다.

다음 예시에서는 학습 단계의 여러 부분에 컨텍스트를 추가하는 방법을 보여줍니다.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

전체 예시

다음 예시에서는 mnist_xla.py 파일을 기반으로 PyTorch XLA 스크립트에서 trace를 캡처하는 방법을 보여줍니다.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

trace 시각화

스크립트가 완료되면 trace 파일이 지정한 디렉터리(예: /root/logs/)에 저장됩니다. XProf 및 TensorBoard를 사용하여 이 trace를 시각화할 수 있습니다.

  1. TensorBoard를 설치합니다.

    pip install tensorboard_plugin_profile tensorboard
  2. TensorBoard를 실행합니다. TensorBoard가 xp.start_trace()에서 사용한 로그 디렉터리를 가리키도록 합니다.

    tensorboard --logdir /root/logs/
  3. 프로필을 확인합니다. 웹브라우저에서 TensorBoard가 제공한 URL(일반적으로 http://localhost:6006)을 엽니다. PyTorch XLA - Profile 탭으로 이동하여 대화형 trace를 확인합니다. 만든 맞춤 라벨을 확인하고 모델의 여러 부분의 실행 시간을 분석할 수 있습니다.

Google Cloud 를 사용하여 워크로드를 실행하는 경우 cloud-diagnostics-xprof 도구를 사용하는 것이 좋습니다. Tensorboard 및 XProf를 실행하는 VM을 사용하여 간소화된 프로필 수집 및 보기 환경을 제공합니다.