Membuat profil workload PyTorch XLA

Pengoptimalan performa adalah bagian penting dalam membangun model machine learning yang efisien. Anda dapat menggunakan alat pembuatan profil XProf untuk mengukur performa beban kerja machine learning Anda. XProf memungkinkan Anda merekam rekaman aktivitas mendetail tentang eksekusi model di perangkat XLA. Rekaman aktivitas ini dapat membantu Anda mengidentifikasi bottleneck performa, memahami penggunaan perangkat, dan mengoptimalkan kode.

Panduan ini menjelaskan proses pengambilan rekaman aktivitas secara terprogram dari skrip PyTorch XLA dan visualisasi menggunakan XProf dan Tensorboard.

Merekam aktivitas

Anda dapat merekam rekaman aktivitas dengan menambahkan beberapa baris kode ke skrip pelatihan yang ada. Alat utama untuk merekam aktivitas adalah modul torch_xla.debug.profiler, yang biasanya diimpor dengan alias xp.

1. Mulai server profiler

Sebelum dapat merekam aktivitas, Anda harus memulai server profiler. Server ini berjalan di latar belakang skrip Anda dan mengumpulkan data rekaman aktivitas. Anda dapat memulainya dengan memanggil xp.start_server() di dekat awal blok eksekusi utama.

2. Menentukan durasi rekaman aktivitas

Gabungkan kode yang ingin Anda buat profilnya dalam panggilan xp.start_trace() dan xp.stop_trace(). Fungsi start_trace mengambil jalur ke direktori tempat file rekaman aktivitas disimpan.

Praktik umumnya adalah membungkus loop pelatihan utama untuk merekam operasi yang paling relevan.

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Menambahkan label rekaman aktivitas kustom

Secara default, rekaman aktivitas yang diambil adalah fungsi Pytorch XLA tingkat rendah dan mungkin sulit dinavigasi. Anda dapat menambahkan label kustom ke bagian tertentu pada kode menggunakan pengelola konteks xp.Trace(). Label ini akan muncul sebagai blok bernama dalam tampilan linimasa profiler, sehingga mempermudah identifikasi operasi tertentu seperti penyiapan data, penerusan, atau langkah pengoptimal.

Contoh berikut menunjukkan cara menambahkan konteks ke berbagai bagian langkah pelatihan.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Contoh lengkap

Contoh berikut menunjukkan cara merekam aktivitas dari skrip PyTorch XLA, berdasarkan file mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Memvisualisasikan rekaman aktivitas

Setelah skrip Anda selesai, file rekaman aktivitas akan disimpan di direktori yang Anda tentukan (misalnya, /root/logs/). Anda dapat memvisualisasikan rekaman aktivitas ini menggunakan XProf dan TensorBoard.

  1. Instal TensorBoard.

    pip install tensorboard_plugin_profile tensorboard
  2. Luncurkan TensorBoard. Arahkan TensorBoard ke direktori log yang Anda gunakan di xp.start_trace():

    tensorboard --logdir /root/logs/
  3. Lihat Profil. Buka URL yang diberikan oleh TensorBoard di browser web Anda (biasanya http://localhost:6006). Buka tab PyTorch XLA - Profile untuk melihat rekaman aktivitas interaktif. Anda dapat melihat label kustom yang dibuat dan menganalisis waktu eksekusi berbagai bagian model.

Jika Anda menggunakan Google Cloud untuk menjalankan workload, sebaiknya gunakan alat cloud-diagnostics-xprof. Fitur ini memberikan pengalaman pengumpulan dan penayangan profil yang disederhanakan menggunakan VM yang menjalankan TensorBoard dan XProf.