Profilare i carichi di lavoro PyTorch XLA
L'ottimizzazione delle prestazioni è una parte fondamentale della creazione di modelli di machine learning efficienti. Puoi utilizzare lo strumento di profilazione XProf per misurare il rendimento dei tuoi carichi di lavoro di machine learning. XProf ti consente di acquisire tracce dettagliate dell'esecuzione del modello sui dispositivi XLA. Queste tracce possono aiutarti a identificare i colli di bottiglia delle prestazioni, comprendere l'utilizzo del dispositivo e ottimizzare il codice.
Questa guida descrive il processo di acquisizione programmatica di una traccia dallo script PyTorch XLA e la visualizzazione utilizzando XProf e TensorBoard.
Acquisire una traccia
Puoi acquisire una traccia aggiungendo alcune righe di codice allo script di addestramento
esistente. Lo strumento principale per acquisire una traccia è il modulo torch_xla.debug.profiler
, che in genere viene importato con l'alias xp
.
1. Avvia il server del profiler
Prima di poter acquisire una traccia, devi avviare il server del profiler. Questo
server viene eseguito in background nello script e raccoglie i dati di tracciamento. Puoi
avviarlo chiamando xp.start_server()
vicino all'inizio del blocco di esecuzione
principale.
2. Definisci la durata della traccia
Racchiudi il codice che vuoi profilare all'interno delle chiamate xp.start_trace()
e
xp.stop_trace()
. La funzione start_trace
accetta un percorso a una directory
in cui vengono salvati i file di traccia.
È prassi comune racchiudere il ciclo di addestramento principale per acquisire le operazioni più pertinenti.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Aggiungere etichette di traccia personalizzate
Per impostazione predefinita, le tracce acquisite sono funzioni Pytorch XLA di basso livello e può essere
difficile da navigare. Puoi aggiungere etichette personalizzate a sezioni specifiche del codice
utilizzando il gestore di contesto xp.Trace()
. Queste etichette verranno visualizzate come blocchi denominati
nella visualizzazione della sequenza temporale del profiler, semplificando notevolmente l'identificazione di operazioni specifiche
come la preparazione dei dati, il forward pass o il passaggio dell'ottimizzatore.
L'esempio seguente mostra come aggiungere contesto a diverse parti di un passaggio di addestramento.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Esempio completo
L'esempio seguente mostra come acquisire una traccia da uno script PyTorch XLA, in base al file mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualizzare la traccia
Al termine dello script, i file di traccia vengono salvati nella directory che hai
specificato (ad esempio /root/logs/
). Puoi visualizzare questa traccia utilizzando
XProf e TensorBoard.
Installa TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Avvia TensorBoard. Indirizza TensorBoard alla directory dei log che hai utilizzato in
xp.start_trace()
:tensorboard --logdir /root/logs/
Visualizza il profilo. Apri l'URL fornito da TensorBoard nel browser web (di solito
http://localhost:6006
). Vai alla scheda PyTorch XLA - Profile per visualizzare la traccia interattiva. Potrai visualizzare le etichette personalizzate che hai creato e analizzare il tempo di esecuzione delle diverse parti del modello.
Se utilizzi Google Cloud per eseguire i tuoi carichi di lavoro, ti consigliamo lo strumento cloud-diagnostics-xprof. Offre un'esperienza semplificata di raccolta e visualizzazione dei profili utilizzando le VM che eseguono TensorBoard e XProf.