Criar perfil de cargas de trabalho do PyTorch XLA

A otimização de performance é uma parte crucial da criação de modelos de machine learning eficientes. Use a ferramenta de criação de perfil XProf para medir o desempenho das cargas de trabalho de machine learning. O XProf permite capturar rastreamentos detalhados da execução do modelo em dispositivos XLA. Esses rastreamentos podem ajudar você a identificar gargalos de desempenho, entender a utilização do dispositivo e otimizar seu código.

Este guia descreve o processo de captura programática de um rastreamento do seu script do PyTorch XLA e visualização usando o XProf e o Tensorboard.

Capturar um rastreamento

Para capturar um rastreamento, adicione algumas linhas de código ao script de treinamento atual. A principal ferramenta para capturar um rastreamento é o módulo torch_xla.debug.profiler, que geralmente é importado com o alias xp.

1. Iniciar o servidor do profiler

Antes de capturar um rastreamento, é necessário iniciar o servidor do profiler. Esse servidor é executado em segundo plano no seu script e coleta os dados de rastreamento. Para iniciar, chame xp.start_server() perto do início do bloco de execução principal.

2. Definir a duração do rastreamento

Encapsule o código que você quer criar um perfil dentro das chamadas xp.start_trace() e xp.stop_trace(). A função start_trace usa um caminho para um diretório em que os arquivos de rastreamento são salvos.

É uma prática comum encapsular o loop de treinamento principal para capturar as operações mais relevantes.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Adicionar rótulos de trace personalizados

Por padrão, os rastreamentos capturados são funções Pytorch XLA de baixo nível e podem ser difíceis de navegar. É possível adicionar rótulos personalizados a seções específicas do seu código usando o gerenciador de contexto xp.Trace(). Esses rótulos aparecem como blocos nomeados na visualização da linha do tempo do criador de perfil, facilitando a identificação de operações específicas, como preparação de dados, transmissão direta ou etapa do otimizador.

O exemplo a seguir mostra como adicionar contexto a diferentes partes de uma etapa de treinamento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Exemplo completo

O exemplo a seguir mostra como capturar um rastreamento de um script do PyTorch XLA com base no arquivo mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualizar o trace

Quando o script terminar, os arquivos de rastreamento serão salvos no diretório especificado (por exemplo, /root/logs/). É possível visualizar esse rastreamento usando o XProf e o TensorBoard.

  1. Instale o TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Inicie o TensorBoard. Direcione o TensorBoard para o diretório de registros usado em xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Acesse o perfil. Abra o URL fornecido pelo TensorBoard no navegador da Web (geralmente http://localhost:6006). Navegue até a guia PyTorch XLA - Profile para ver o rastreamento interativo. Você poderá conferir os rótulos personalizados que criou e analisar o tempo de execução de diferentes partes do modelo.

Se você usa Google Cloud para executar suas cargas de trabalho, recomendamos a ferramenta cloud-diagnostics-xprof. Ele oferece uma experiência simplificada de coleta e visualização de perfis usando VMs que executam o Tensorboard e o XProf.