Profile für PyTorch XLA-Arbeitslasten erstellen

Die Leistungsoptimierung ist ein wichtiger Bestandteil der Entwicklung effizienter Modelle für maschinelles Lernen. Mit dem Profiling-Tool XProf können Sie die Leistung Ihrer Machine-Learning-Arbeitslasten messen. Mit XProf können Sie detaillierte Traces der Ausführung Ihres Modells auf XLA-Geräten erfassen. Mithilfe dieser Traces können Sie Leistungsengpässe identifizieren, die Gerätenutzung nachvollziehen und Ihren Code optimieren.

In diesem Leitfaden wird beschrieben, wie Sie programmatisch einen Trace aus Ihrem PyTorch XLA-Script erfassen und mit XProf und TensorBoard visualisieren.

Trace aufzeichnen

Sie können einen Trace erfassen, indem Sie Ihrem vorhandenen Trainingsskript einige Codezeilen hinzufügen. Das primäre Tool zum Erfassen eines Traces ist das Modul torch_xla.debug.profiler, das normalerweise mit dem Alias xp importiert wird.

1. Profiler-Server starten

Bevor Sie einen Trace aufzeichnen können, müssen Sie den Profiler-Server starten. Dieser Server wird im Hintergrund Ihres Skripts ausgeführt und erfasst die Tracedaten. Sie können sie aufrufen, indem Sie xp.start_server() am Anfang Ihres Hauptausführungsblocks aufrufen.

2. Tracedauer definieren

Schließen Sie den Code, den Sie profilieren möchten, in xp.start_trace()- und xp.stop_trace()-Aufrufe ein. Die Funktion start_trace verwendet einen Pfad zu einem Verzeichnis, in dem die Tracedateien gespeichert werden.

Es ist üblich, die Haupttrainingsschleife zu umschließen, um die relevantesten Vorgänge zu erfassen.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Benutzerdefinierte Trace-Labels hinzufügen

Standardmäßig sind die erfassten Traces Low-Level-Pytorch XLA-Funktionen und können schwer zu navigieren sein. Mit dem xp.Trace()-Kontextmanager können Sie bestimmten Abschnitten Ihres Codes benutzerdefinierte Labels hinzufügen. Diese Labels werden als benannte Blöcke in der Zeitachse des Profilers angezeigt. So lassen sich bestimmte Vorgänge wie die Datenvorbereitung, der Forward Pass oder der Optimierungsschritt viel leichter identifizieren.

Im folgenden Beispiel sehen Sie, wie Sie verschiedenen Teilen eines Trainingsschritts Kontext hinzufügen können.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Vollständiges Beispiel

Im folgenden Beispiel wird gezeigt, wie Sie einen Trace aus einem PyTorch XLA-Skript auf Grundlage der Datei mnist_xla.py erfassen.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Trace visualisieren

Wenn das Skript fertig ist, werden die Tracedateien im angegebenen Verzeichnis gespeichert (z. B. /root/logs/). Sie können diesen Trace mit XProf und TensorBoard visualisieren.

  1. TensorBoard installieren

    pip install tensorboard_plugin_profile tensorboard
  2. TensorBoard starten Weisen Sie TensorBoard auf das Log-Verzeichnis hin, das Sie in xp.start_trace() verwendet haben:

    tensorboard --logdir /root/logs/
  3. Profil ansehen Öffnen Sie die von TensorBoard bereitgestellte URL in Ihrem Webbrowser (normalerweise http://localhost:6006). Rufen Sie den Tab PyTorch XLA – Profil auf, um den interaktiven Trace anzusehen. Sie können die benutzerdefinierten Labels sehen, die Sie erstellt haben, und die Ausführungszeit der verschiedenen Teile Ihres Modells analysieren.

Wenn Sie Google Cloud zum Ausführen Ihrer Arbeitslasten verwenden, empfehlen wir das cloud-diagnostics-xprof-Tool. Es bietet eine optimierte Erfassung und Anzeige von Profilen mithilfe von VMs, auf denen Tensorboard und XProf ausgeführt werden.