Membuat profil workload PyTorch XLA
Pengoptimalan performa adalah bagian penting dalam membangun model machine learning yang efisien. Anda dapat menggunakan alat pembuatan profil XProf untuk mengukur performa beban kerja machine learning Anda. XProf memungkinkan Anda merekam rekaman aktivitas mendetail tentang eksekusi model di perangkat XLA. Rekaman aktivitas ini dapat membantu Anda mengidentifikasi bottleneck performa, memahami penggunaan perangkat, dan mengoptimalkan kode.
Panduan ini menjelaskan proses pengambilan rekaman aktivitas secara terprogram dari skrip PyTorch XLA dan visualisasi menggunakan XProf dan Tensorboard.
Merekam aktivitas
Anda dapat merekam rekaman aktivitas dengan menambahkan beberapa baris kode ke skrip pelatihan yang ada. Alat utama untuk merekam aktivitas adalah modul torch_xla.debug.profiler
, yang biasanya diimpor dengan alias xp
.
1. Mulai server profiler
Sebelum dapat merekam aktivitas, Anda harus memulai server profiler. Server
ini berjalan di latar belakang skrip Anda dan mengumpulkan data rekaman aktivitas. Anda
dapat memulainya dengan memanggil xp.start_server()
di dekat awal blok
eksekusi utama.
2. Menentukan durasi rekaman aktivitas
Gabungkan kode yang ingin Anda buat profilnya dalam panggilan xp.start_trace()
dan
xp.stop_trace()
. Fungsi start_trace
mengambil jalur ke direktori tempat file rekaman aktivitas disimpan.
Praktik umumnya adalah membungkus loop pelatihan utama untuk merekam operasi yang paling relevan.
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Menambahkan label rekaman aktivitas kustom
Secara default, rekaman aktivitas yang diambil adalah fungsi Pytorch XLA tingkat rendah dan mungkin sulit dinavigasi. Anda dapat menambahkan label kustom ke bagian tertentu pada kode menggunakan pengelola konteks xp.Trace()
. Label ini akan muncul sebagai blok bernama
dalam tampilan linimasa profiler, sehingga mempermudah identifikasi operasi
tertentu seperti penyiapan data, penerusan, atau langkah pengoptimal.
Contoh berikut menunjukkan cara menambahkan konteks ke berbagai bagian langkah pelatihan.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Contoh lengkap
Contoh berikut menunjukkan cara merekam aktivitas dari skrip PyTorch XLA, berdasarkan file mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Memvisualisasikan rekaman aktivitas
Setelah skrip Anda selesai, file rekaman aktivitas akan disimpan di direktori yang Anda
tentukan (misalnya, /root/logs/
). Anda dapat memvisualisasikan rekaman aktivitas ini menggunakan
XProf dan TensorBoard.
Instal TensorBoard.
pip install tensorboard_plugin_profile tensorboard
Luncurkan TensorBoard. Arahkan TensorBoard ke direktori log yang Anda gunakan di
xp.start_trace()
:tensorboard --logdir /root/logs/
Lihat Profil. Buka URL yang diberikan oleh TensorBoard di browser web Anda (biasanya
http://localhost:6006
). Buka tab PyTorch XLA - Profile untuk melihat rekaman aktivitas interaktif. Anda dapat melihat label kustom yang dibuat dan menganalisis waktu eksekusi berbagai bagian model.
Jika Anda menggunakan Google Cloud untuk menjalankan workload, sebaiknya gunakan alat cloud-diagnostics-xprof. Fitur ini memberikan pengalaman pengumpulan dan penayangan profil yang disederhanakan menggunakan VM yang menjalankan TensorBoard dan XProf.