Profiler les charges de travail PyTorch/XLA
L'optimisation des performances est un élément essentiel de la création de modèles de machine learning efficaces. Vous pouvez utiliser l'outil de profilage XProf pour mesurer les performances de vos charges de travail de machine learning. XProf vous permet de capturer des traces détaillées de l'exécution de votre modèle sur les appareils XLA. Ces traces peuvent vous aider à identifier les goulots d'étranglement des performances, à comprendre l'utilisation de l'appareil et à optimiser votre code.
Ce guide décrit la procédure à suivre pour capturer de manière programmatique une trace à partir de votre script PyTorch/XLA et la visualiser à l'aide de XProf et Tensorboard.
Capturer une trace
Vous pouvez capturer une trace en ajoutant quelques lignes de code à votre script d'entraînement existant. L'outil principal pour capturer une trace est le module torch_xla.debug.profiler
, qui est généralement importé avec l'alias xp
.
1. Démarrer le serveur du profileur
Avant de pouvoir capturer une trace, vous devez démarrer le serveur du profileur. Ce serveur s'exécute en arrière-plan de votre script et collecte les données de trace. Vous pouvez le démarrer en appelant xp.start_server()
près du début de votre bloc d'exécution principal.
2. Définir la durée de la trace
Encapsulez le code que vous souhaitez profiler dans les appels xp.start_trace()
et xp.stop_trace()
. La fonction start_trace
accepte un chemin d'accès à un répertoire dans lequel les fichiers de trace sont enregistrés.
Il est courant d'encapsuler la boucle d'entraînement principale pour capturer les opérations les plus pertinentes.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Ajouter des libellés de trace personnalisés
Par défaut, les traces capturées sont des fonctions Pytorch XLA de bas niveau et peuvent être difficiles à parcourir. Vous pouvez ajouter des libellés personnalisés à des sections spécifiques de votre code à l'aide du gestionnaire de contexte xp.Trace()
. Ces libellés s'affichent sous forme de blocs nommés dans la vue chronologique du profileur, ce qui permet d'identifier beaucoup plus facilement des opérations spécifiques telles que la préparation des données, la passe avant ou l'étape de l'optimiseur.
L'exemple suivant montre comment ajouter du contexte à différentes parties d'une étape d'entraînement.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Exemple complet
L'exemple suivant montre comment capturer une trace à partir d'un script PyTorch/XLA, basé sur le fichier mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualiser la trace
Une fois votre script terminé, les fichiers de trace sont enregistrés dans le répertoire que vous avez spécifié (par exemple, /root/logs/
). Vous pouvez visualiser cette trace à l'aide de XProf et TensorBoard.
Installez TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Lancez TensorBoard. Pointez TensorBoard vers le répertoire de journaux que vous avez utilisé dans
xp.start_trace()
:tensorboard --logdir /root/logs/
Affichez le profil. Ouvrez l'URL fournie par TensorBoard dans votre navigateur Web (généralement
http://localhost:6006
). Accédez à l'onglet PyTorch XLA – Profile (PyTorch XLA – Profil) pour afficher la trace interactive. Vous pourrez voir les libellés personnalisés que vous avez créés et analyser le temps d'exécution des différentes parties de votre modèle.
Si vous utilisez Google Cloud pour exécuter vos charges de travail, nous vous recommandons l'outil cloud-diagnostics-xprof. Il offre une expérience simplifiée de collecte et d'affichage des profils à l'aide de VM exécutant Tensorboard et XProf.