Profilare i carichi di lavoro PyTorch XLA

L'ottimizzazione delle prestazioni è una parte fondamentale della creazione di modelli di machine learning efficienti. Puoi utilizzare lo strumento di profilazione XProf per misurare il rendimento dei tuoi carichi di lavoro di machine learning. XProf ti consente di acquisire tracce dettagliate dell'esecuzione del modello sui dispositivi XLA. Queste tracce possono aiutarti a identificare i colli di bottiglia delle prestazioni, comprendere l'utilizzo del dispositivo e ottimizzare il codice.

Questa guida descrive il processo di acquisizione programmatica di una traccia dallo script PyTorch XLA e la visualizzazione utilizzando XProf e TensorBoard.

Acquisire una traccia

Puoi acquisire una traccia aggiungendo alcune righe di codice allo script di addestramento esistente. Lo strumento principale per acquisire una traccia è il modulo torch_xla.debug.profiler, che in genere viene importato con l'alias xp.

1. Avvia il server del profiler

Prima di poter acquisire una traccia, devi avviare il server del profiler. Questo server viene eseguito in background nello script e raccoglie i dati di tracciamento. Puoi avviarlo chiamando xp.start_server() vicino all'inizio del blocco di esecuzione principale.

2. Definisci la durata della traccia

Racchiudi il codice che vuoi profilare all'interno delle chiamate xp.start_trace() e xp.stop_trace(). La funzione start_trace accetta un percorso a una directory in cui vengono salvati i file di traccia.

È prassi comune racchiudere il ciclo di addestramento principale per acquisire le operazioni più pertinenti.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Aggiungere etichette di traccia personalizzate

Per impostazione predefinita, le tracce acquisite sono funzioni Pytorch XLA di basso livello e può essere difficile da navigare. Puoi aggiungere etichette personalizzate a sezioni specifiche del codice utilizzando il gestore di contesto xp.Trace(). Queste etichette verranno visualizzate come blocchi denominati nella visualizzazione della sequenza temporale del profiler, semplificando notevolmente l'identificazione di operazioni specifiche come la preparazione dei dati, il forward pass o il passaggio dell'ottimizzatore.

L'esempio seguente mostra come aggiungere contesto a diverse parti di un passaggio di addestramento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Esempio completo

L'esempio seguente mostra come acquisire una traccia da uno script PyTorch XLA, in base al file mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualizzare la traccia

Al termine dello script, i file di traccia vengono salvati nella directory che hai specificato (ad esempio /root/logs/). Puoi visualizzare questa traccia utilizzando XProf e TensorBoard.

  1. Installa TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Avvia TensorBoard. Indirizza TensorBoard alla directory dei log che hai utilizzato in xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Visualizza il profilo. Apri l'URL fornito da TensorBoard nel browser web (di solito http://localhost:6006). Vai alla scheda PyTorch XLA - Profile per visualizzare la traccia interattiva. Potrai visualizzare le etichette personalizzate che hai creato e analizzare il tempo di esecuzione delle diverse parti del modello.

Se utilizzi Google Cloud per eseguire i tuoi carichi di lavoro, ti consigliamo lo strumento cloud-diagnostics-xprof. Offre un'esperienza semplificata di raccolta e visualizzazione dei profili utilizzando le VM che eseguono TensorBoard e XProf.