剖析 PyTorch XLA 工作負載

效能最佳化是建構有效率機器學習模型的重要環節。您可以使用 XProf 分析工具測量機器學習工作負載的效能。XProf 可讓您擷取模型在 XLA 裝置上執行的詳細追蹤記錄。這些追蹤記錄可協助您找出效能瓶頸、瞭解裝置使用率,以及最佳化程式碼。

本指南說明如何以程式輔助方式從 PyTorch XLA 指令碼擷取追蹤記錄,並使用 XProf 和 Tensorboard 進行視覺化。

擷取追蹤記錄

您可以在現有的訓練指令碼中加入幾行程式碼,擷取追蹤記錄。擷取追蹤記錄的主要工具是 torch_xla.debug.profiler 模組,通常會以別名 xp 匯入。

1. 啟動剖析器伺服器

您必須先啟動剖析器伺服器,才能擷取追蹤記錄。這個伺服器會在指令碼的背景執行,並收集追蹤資料。您可以在主要執行區塊的開頭附近呼叫 xp.start_server(),啟動這項作業。

2. 定義追蹤時間長度

將要剖析的程式碼包裝在 xp.start_trace()xp.stop_trace() 呼叫中。start_trace 函式會採用路徑,指向儲存追蹤記錄檔案的目錄。

一般做法是包裝主要訓練迴圈,以擷取最相關的作業。

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. 新增自訂追蹤標籤

根據預設,擷取的追蹤記錄是低階的 PyTorch XLA 函式,可能難以瀏覽。您可以使用 xp.Trace() 內容管理工具,在程式碼的特定區段中加入自訂標籤。這些標籤會以具名區塊的形式顯示在分析器的時間軸檢視畫面中,讓您更容易識別特定作業,例如資料準備、正向傳遞或最佳化工具步驟。

以下範例說明如何為訓練步驟的不同部分新增背景資訊。

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

完整範例

下列範例說明如何根據 mnist_xla.py 檔案,從 PyTorch XLA 指令碼擷取追蹤記錄。

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

以圖表呈現追蹤記錄

指令碼完成後,追蹤記錄檔會儲存在您指定的目錄 (例如 /root/logs/)。您可以使用 XProf 和 TensorBoard 顯示這項追蹤記錄。

  1. 安裝 TensorBoard。

    pip install tensorboard_plugin_profile tensorboard
  2. 啟動 TensorBoard。將 TensorBoard 指向您在 xp.start_trace() 中使用的記錄目錄:

    tensorboard --logdir /root/logs/
  3. 查看設定檔。在網路瀏覽器中開啟 TensorBoard 提供的網址 (通常是 http://localhost:6006)。前往「PyTorch XLA - Profile」分頁,即可查看互動式追蹤記錄。您可以查看建立的自訂標籤,並分析模型不同部分的執行時間。

如果您使用 Google Cloud 執行工作負載,建議使用 cloud-diagnostics-xprof 工具。使用執行 TensorBoard 和 XProf 的 VM,即可簡化設定檔的收集和檢視體驗。