Crea perfiles de cargas de trabajo de PyTorch/XLA

La optimización del rendimiento es una parte fundamental de la creación de modelos de aprendizaje automático eficientes. Puedes usar la herramienta de generación de perfiles XProf para medir el rendimiento de tus cargas de trabajo de aprendizaje automático. XProf te permite capturar registros detallados de la ejecución de tu modelo en dispositivos XLA. Estos registros pueden ayudarte a identificar cuellos de botella en el rendimiento, comprender el uso del dispositivo y optimizar tu código.

En esta guía, se describe el proceso para capturar de forma programática un registro de tu secuencia de comandos de PyTorch XLA y visualizarlo con XProf y TensorBoard.

Cómo capturar un registro

Para capturar un registro, agrega algunas líneas de código a tu secuencia de comandos de entrenamiento existente. La herramienta principal para capturar un registro es el módulo torch_xla.debug.profiler, que suele importarse con el alias xp.

1. Inicia el servidor del profiler

Antes de capturar un registro, debes iniciar el servidor del generador de perfiles. Este servidor se ejecuta en segundo plano en tu secuencia de comandos y recopila los datos de seguimiento. Puedes iniciarla llamando a xp.start_server() cerca del comienzo de tu bloque de ejecución principal.

2. Cómo definir la duración del registro

Encapsula el código que deseas analizar dentro de las llamadas xp.start_trace() y xp.stop_trace(). La función start_trace toma una ruta de acceso a un directorio en el que se guardan los archivos de registro.

Es una práctica habitual encapsular el bucle de entrenamiento principal para capturar las operaciones más relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Agrega etiquetas de seguimiento personalizadas

De forma predeterminada, los registros capturados son funciones de Pytorch XLA de bajo nivel y pueden ser difíciles de navegar. Puedes agregar etiquetas personalizadas a secciones específicas de tu código con el administrador de contexto xp.Trace(). Estas etiquetas aparecerán como bloques con nombre en la vista de la línea de tiempo del generador de perfiles, lo que facilitará la identificación de operaciones específicas, como la preparación de datos, el pase hacia adelante o el paso del optimizador.

En el siguiente ejemplo, se muestra cómo puedes agregar contexto a diferentes partes de un paso de entrenamiento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Ejemplo completo

En el siguiente ejemplo, se muestra cómo capturar un registro de un script de PyTorch XLA, basado en el archivo mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualiza el registro

Cuando finalice el script, los archivos de registro se guardarán en el directorio que especificaste (por ejemplo, /root/logs/). Puedes visualizar este registro con XProf y TensorBoard.

  1. Instala TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicia TensorBoard. Apunta TensorBoard al directorio de registros que usaste en xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Ver el perfil Abre la URL que proporciona TensorBoard en tu navegador web (generalmente http://localhost:6006). Navega a la pestaña PyTorch XLA - Profile para ver el registro interactivo. Podrás ver las etiquetas personalizadas que creaste y analizar el tiempo de ejecución de diferentes partes de tu modelo.

Si usas Google Cloud para ejecutar tus cargas de trabajo, te recomendamos la herramienta cloud-diagnostics-xprof. Proporciona una experiencia optimizada de recopilación y visualización de perfiles con VMs que ejecutan TensorBoard y XProf.