对 PyTorch XLA 工作负载进行性能分析
性能优化是构建高效机器学习模型的关键环节。您可以使用 XProf 性能剖析工具来衡量机器学习工作负载的性能。借助 XProf,您可以捕获模型在 XLA 设备上的执行情况的详细跟踪记录。这些跟踪记录可帮助您识别性能瓶颈、了解设备利用率并优化代码。
本指南介绍了如何以编程方式从 PyTorch XLA 脚本捕获跟踪记录,并使用 XProf 和 Tensorboard 直观呈现跟踪记录。
捕获跟踪记录
您只需向现有训练脚本添加几行代码即可捕获跟踪记录。用于捕获跟踪记录的主要工具是 torch_xla.debug.profiler
模块,该模块通常以别名 xp
导入。
1. 启动性能分析器服务器
您需要先启动性能分析器服务器,然后才能捕获跟踪记录。此服务器在脚本后台运行,并收集跟踪记录数据。您可以在主执行块的开头附近调用 xp.start_server()
来启动它。
2. 定义跟踪记录时长
将要分析的代码封装在 xp.start_trace()
和 xp.stop_trace()
调用中。start_trace
函数接受保存跟踪记录文件的目录的路径。
通常的做法是将主要训练循环封装起来,以捕获最相关的操作。
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. 添加自定义跟踪记录标签
默认情况下,捕获的跟踪记录是低级 Pytorch XLA 函数,可能难以浏览。您可以使用 xp.Trace()
上下文管理器为代码的特定部分添加自定义标签。这些标签将以命名块的形式显示在性能分析器的时间轴视图中,从而更轻松地识别特定操作,例如数据准备、前向传递或优化器步骤。
以下示例展示了如何为训练步骤的不同部分添加上下文。
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
完整示例
以下示例展示了如何基于 mnist_xla.py 文件捕获 PyTorch XLA 脚本的跟踪记录。
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
直观呈现跟踪记录
脚本完成后,跟踪记录文件会保存在您指定的目录(例如 /root/logs/
)中。您可以使用 XProf 和 TensorBoard 直观呈现此跟踪记录。
安装TensorBoard。
pip install tensorboard_plugin_profile tensorboard
启动TensorBoard。将 TensorBoard 指向您在
xp.start_trace()
中使用的日志目录:tensorboard --logdir /root/logs/
查看性能分析。在网络浏览器中打开 TensorBoard 提供的网址(通常为
http://localhost:6006
)。前往 PyTorch XLA - 性能分析标签页,查看交互式跟踪记录。您将能够看到自己创建的自定义标签,并分析模型不同部分的执行时间。
如果您使用 Google Cloud 运行工作负载,我们建议您使用cloud-diagnostics-xprof 工具。它使用运行 Tensorboard 和 XProf 的虚拟机提供简化的性能分析文件收集和查看体验。