对 PyTorch XLA 工作负载进行性能分析

性能优化是构建高效机器学习模型的关键环节。您可以使用 XProf 性能剖析工具来衡量机器学习工作负载的性能。借助 XProf,您可以捕获模型在 XLA 设备上的执行情况的详细跟踪记录。这些跟踪记录可帮助您识别性能瓶颈、了解设备利用率并优化代码。

本指南介绍了如何以编程方式从 PyTorch XLA 脚本捕获跟踪记录,并使用 XProf 和 Tensorboard 直观呈现跟踪记录。

捕获跟踪记录

您只需向现有训练脚本添加几行代码即可捕获跟踪记录。用于捕获跟踪记录的主要工具是 torch_xla.debug.profiler 模块,该模块通常以别名 xp 导入。

1. 启动性能分析器服务器

您需要先启动性能分析器服务器,然后才能捕获跟踪记录。此服务器在脚本后台运行,并收集跟踪记录数据。您可以在主执行块的开头附近调用 xp.start_server() 来启动它。

2. 定义跟踪记录时长

将要分析的代码封装在 xp.start_trace()xp.stop_trace() 调用中。start_trace 函数接受保存跟踪记录文件的目录的路径。

通常的做法是将主要训练循环封装起来,以捕获最相关的操作。

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. 添加自定义跟踪记录标签

默认情况下,捕获的跟踪记录是低级 Pytorch XLA 函数,可能难以浏览。您可以使用 xp.Trace() 上下文管理器为代码的特定部分添加自定义标签。这些标签将以命名块的形式显示在性能分析器的时间轴视图中,从而更轻松地识别特定操作,例如数据准备、前向传递或优化器步骤。

以下示例展示了如何为训练步骤的不同部分添加上下文。

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

完整示例

以下示例展示了如何基于 mnist_xla.py 文件捕获 PyTorch XLA 脚本的跟踪记录。

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

直观呈现跟踪记录

脚本完成后,跟踪记录文件会保存在您指定的目录(例如 /root/logs/)中。您可以使用 XProf 和 TensorBoard 直观呈现此跟踪记录。

  1. 安装TensorBoard。

    pip install tensorboard_plugin_profile tensorboard
  2. 启动TensorBoard。将 TensorBoard 指向您在 xp.start_trace() 中使用的日志目录:

    tensorboard --logdir /root/logs/
  3. 查看性能分析。在网络浏览器中打开 TensorBoard 提供的网址(通常为 http://localhost:6006)。前往 PyTorch XLA - 性能分析标签页,查看交互式跟踪记录。您将能够看到自己创建的自定义标签,并分析模型不同部分的执行时间。

如果您使用 Google Cloud 运行工作负载,我们建议您使用cloud-diagnostics-xprof 工具。它使用运行 Tensorboard 和 XProf 的虚拟机提供简化的性能分析文件收集和查看体验。