Crear perfiles de cargas de trabajo de PyTorch XLA

La optimización del rendimiento es una parte fundamental del desarrollo de modelos de aprendizaje automático eficientes. Puedes usar la herramienta de creación de perfiles XProf para medir el rendimiento de tus cargas de trabajo de aprendizaje automático. XProf te permite capturar trazas detalladas de la ejecución de tu modelo en dispositivos XLA. Estos registros pueden ayudarte a identificar cuellos de botella en el rendimiento, comprender el uso de los dispositivos y optimizar tu código.

En esta guía se describe el proceso para capturar una traza de forma programática desde tu secuencia de comandos de PyTorch XLA y visualizarla con XProf y TensorBoard.

Capturar una traza

Para capturar un rastreo, añade unas líneas de código a tu secuencia de comandos de entrenamiento. La herramienta principal para capturar un rastreo es el módulo torch_xla.debug.profiler, que normalmente se importa con el alias xp.

1. Iniciar el servidor del generador de perfiles

Antes de poder capturar un rastreo, debes iniciar el servidor del generador de perfiles. Este servidor se ejecuta en segundo plano en tu secuencia de comandos y recoge los datos de la traza. Puedes iniciarlo llamando a xp.start_server() cerca del principio del bloque de ejecución principal.

2. Definir la duración de la traza

Encierra el código que quieras analizar entre llamadas xp.start_trace() y xp.stop_trace(). La función start_trace toma una ruta a un directorio donde se guardan los archivos de seguimiento.

Es una práctica habitual envolver el bucle de entrenamiento principal para capturar las operaciones más relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Añadir etiquetas de traza personalizadas

De forma predeterminada, las trazas capturadas son funciones de Pytorch XLA de bajo nivel y puede ser difícil desplazarse por ellas. Puede añadir etiquetas personalizadas a secciones específicas de su código con el gestor de contexto xp.Trace(). Estas etiquetas aparecerán como bloques con nombre en la vista de línea de tiempo del generador de perfiles, lo que facilitará la identificación de operaciones específicas, como la preparación de datos, el pase hacia delante o el paso del optimizador.

En el siguiente ejemplo se muestra cómo añadir contexto a diferentes partes de un paso de formación.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Ejemplo completo

En el siguiente ejemplo se muestra cómo capturar un rastreo de una secuencia de comandos de PyTorch XLA, basado en el archivo mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualizar el rastreo

Cuando finalice la secuencia de comandos, los archivos de seguimiento se guardarán en el directorio que hayas especificado (por ejemplo, /root/logs/). Puedes visualizar este seguimiento con XProf y TensorBoard.

  1. Instala TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicia TensorBoard. Indica a TensorBoard el directorio de registro que has usado en xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Ver el perfil. Abre la URL proporcionada por TensorBoard en tu navegador web (normalmente, http://localhost:6006). Ve a la pestaña PyTorch XLA - Profile para ver el seguimiento interactivo. Podrás ver las etiquetas personalizadas que has creado y analizar el tiempo de ejecución de las diferentes partes de tu modelo.

Si usas Google Cloud para ejecutar tus cargas de trabajo, te recomendamos la herramienta cloud-diagnostics-xprof. Ofrece una experiencia optimizada de recogida y visualización de perfiles mediante máquinas virtuales que ejecutan TensorBoard y XProf.