Profile für PyTorch XLA-Arbeitslasten erstellen
Die Leistungsoptimierung ist ein wichtiger Bestandteil der Entwicklung effizienter Modelle für maschinelles Lernen. Mit dem Profiling-Tool XProf können Sie die Leistung Ihrer Machine-Learning-Arbeitslasten messen. Mit XProf können Sie detaillierte Traces der Ausführung Ihres Modells auf XLA-Geräten erfassen. Mithilfe dieser Traces können Sie Leistungsengpässe identifizieren, die Gerätenutzung nachvollziehen und Ihren Code optimieren.
In diesem Leitfaden wird beschrieben, wie Sie programmatisch einen Trace aus Ihrem PyTorch XLA-Script erfassen und mit XProf und TensorBoard visualisieren.
Trace aufzeichnen
Sie können einen Trace erfassen, indem Sie Ihrem vorhandenen Trainingsskript einige Codezeilen hinzufügen. Das primäre Tool zum Erfassen eines Traces ist das Modul torch_xla.debug.profiler
, das normalerweise mit dem Alias xp
importiert wird.
1. Profiler-Server starten
Bevor Sie einen Trace aufzeichnen können, müssen Sie den Profiler-Server starten. Dieser Server wird im Hintergrund Ihres Skripts ausgeführt und erfasst die Tracedaten. Sie können sie aufrufen, indem Sie xp.start_server()
am Anfang Ihres Hauptausführungsblocks aufrufen.
2. Tracedauer definieren
Schließen Sie den Code, den Sie profilieren möchten, in xp.start_trace()
- und xp.stop_trace()
-Aufrufe ein. Die Funktion start_trace
verwendet einen Pfad zu einem Verzeichnis, in dem die Tracedateien gespeichert werden.
Es ist üblich, die Haupttrainingsschleife zu umschließen, um die relevantesten Vorgänge zu erfassen.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Benutzerdefinierte Trace-Labels hinzufügen
Standardmäßig sind die erfassten Traces Low-Level-Pytorch XLA-Funktionen und können schwer zu navigieren sein. Mit dem xp.Trace()
-Kontextmanager können Sie bestimmten Abschnitten Ihres Codes benutzerdefinierte Labels hinzufügen. Diese Labels werden als benannte Blöcke in der Zeitachse des Profilers angezeigt. So lassen sich bestimmte Vorgänge wie die Datenvorbereitung, der Forward Pass oder der Optimierungsschritt viel leichter identifizieren.
Im folgenden Beispiel sehen Sie, wie Sie verschiedenen Teilen eines Trainingsschritts Kontext hinzufügen können.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Vollständiges Beispiel
Im folgenden Beispiel wird gezeigt, wie Sie einen Trace aus einem PyTorch XLA-Skript auf Grundlage der Datei mnist_xla.py erfassen.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Trace visualisieren
Wenn das Skript fertig ist, werden die Tracedateien im angegebenen Verzeichnis gespeichert (z. B. /root/logs/
). Sie können diesen Trace mit XProf und TensorBoard visualisieren.
TensorBoard installieren
pip install tensorboard_plugin_profile tensorboard
TensorBoard starten Weisen Sie TensorBoard auf das Log-Verzeichnis hin, das Sie in
xp.start_trace()
verwendet haben:tensorboard --logdir /root/logs/
Profil ansehen Öffnen Sie die von TensorBoard bereitgestellte URL in Ihrem Webbrowser (normalerweise
http://localhost:6006
). Rufen Sie den Tab PyTorch XLA – Profil auf, um den interaktiven Trace anzusehen. Sie können die benutzerdefinierten Labels sehen, die Sie erstellt haben, und die Ausführungszeit der verschiedenen Teile Ihres Modells analysieren.
Wenn Sie Google Cloud zum Ausführen Ihrer Arbeitslasten verwenden, empfehlen wir das cloud-diagnostics-xprof-Tool. Es bietet eine optimierte Erfassung und Anzeige von Profilen mithilfe von VMs, auf denen Tensorboard und XProf ausgeführt werden.