Criar perfil de cargas de trabalho do PyTorch XLA
A otimização de performance é uma parte crucial da criação de modelos de machine learning eficientes. Use a ferramenta de criação de perfil XProf para medir o desempenho das cargas de trabalho de machine learning. O XProf permite capturar rastreamentos detalhados da execução do modelo em dispositivos XLA. Esses rastreamentos podem ajudar você a identificar gargalos de desempenho, entender a utilização do dispositivo e otimizar seu código.
Este guia descreve o processo de captura programática de um rastreamento do seu script do PyTorch XLA e visualização usando o XProf e o Tensorboard.
Capturar um rastreamento
Para capturar um rastreamento, adicione algumas linhas de código ao script de treinamento
atual. A principal ferramenta para capturar um rastreamento é o módulo torch_xla.debug.profiler
, que geralmente é importado com o alias xp
.
1. Iniciar o servidor do profiler
Antes de capturar um rastreamento, é necessário iniciar o servidor do profiler. Esse servidor é executado em segundo plano no seu script e coleta os dados de rastreamento. Para iniciar, chame xp.start_server()
perto do início do bloco de execução principal.
2. Definir a duração do rastreamento
Encapsule o código que você quer criar um perfil dentro das chamadas xp.start_trace()
e xp.stop_trace()
. A função start_trace
usa um caminho para um diretório
em que os arquivos de rastreamento são salvos.
É uma prática comum encapsular o loop de treinamento principal para capturar as operações mais relevantes.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Adicionar rótulos de trace personalizados
Por padrão, os rastreamentos capturados são funções Pytorch XLA de baixo nível e podem ser difíceis de navegar. É possível adicionar rótulos personalizados a seções específicas do seu código
usando o gerenciador de contexto xp.Trace()
. Esses rótulos aparecem como blocos nomeados na visualização da linha do tempo do criador de perfil, facilitando a identificação de operações específicas, como preparação de dados, transmissão direta ou etapa do otimizador.
O exemplo a seguir mostra como adicionar contexto a diferentes partes de uma etapa de treinamento.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Exemplo completo
O exemplo a seguir mostra como capturar um rastreamento de um script do PyTorch XLA com base no arquivo mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualizar o trace
Quando o script terminar, os arquivos de rastreamento serão salvos no diretório especificado (por exemplo, /root/logs/
). É possível visualizar esse rastreamento usando o XProf e o TensorBoard.
Instale o TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Inicie o TensorBoard. Direcione o TensorBoard para o diretório de registros usado em
xp.start_trace()
:tensorboard --logdir /root/logs/
Acesse o perfil. Abra o URL fornecido pelo TensorBoard no navegador da Web (geralmente
http://localhost:6006
). Navegue até a guia PyTorch XLA - Profile para ver o rastreamento interativo. Você poderá conferir os rótulos personalizados que criou e analisar o tempo de execução de diferentes partes do modelo.
Se você usa Google Cloud para executar suas cargas de trabalho, recomendamos a ferramenta cloud-diagnostics-xprof. Ele oferece uma experiência simplificada de coleta e visualização de perfis usando VMs que executam o Tensorboard e o XProf.