Profile PyTorch XLA workloads
=============================

Last updated 2025-09-04 UTC.

Performance optimization is a crucial part of building efficient machine
learning models. You can use the [XProf](https://openxla.org/xprof) profiling
tool to measure the performance of your machine learning workloads. XProf lets
you capture detailed traces of your model's execution on XLA devices. These
traces can help you identify performance bottlenecks, understand device
utilization, and optimize your code.

This guide describes how to programmatically capture a trace from your
PyTorch XLA script and visualize it with XProf and TensorBoard.

Capture a trace
---------------

You can capture a trace by adding a few lines of code to your existing training
script. The primary tool for capturing a trace is the `torch_xla.debug.profiler`
module, which is typically imported with the alias `xp`.

### 1. Start the profiler server

Before you can capture a trace, you need to start the profiler server. This
server runs in the background of your script and collects the trace data. You
can start it by calling `xp.start_server()` near the beginning of your main
execution block.

### 2. Define the trace duration

Place the code you want to profile between calls to `xp.start_trace()` and
`xp.stop_trace()`. The `start_trace` function takes the path to a directory
where the trace files are saved.

It's common practice to wrap the main training loop to capture the most relevant
operations.

    # The directory where the trace files are stored.
    log_dir = '/root/logs/'

    # Start tracing
    xp.start_trace(log_dir)

    # ... your training loop or other code to be profiled ...
    train_mnist()

    # Stop tracing
    xp.stop_trace()

### 3. Add custom trace labels

By default, a trace contains only low-level PyTorch XLA functions, which can
make it hard to navigate. You can add custom labels to specific sections of your
code using the `xp.Trace()` context manager. These labels appear as named blocks
in the profiler's timeline view, making it much easier to identify specific
operations like data preparation, the forward pass, or the optimizer step.

The following example shows how you can add context to different parts of a
training step.

    def forward(self, x):
        # This entire block will be labeled 'forward' in the trace
        with xp.Trace('forward'):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2(x), 2))
            x = x.view(-1, 7*7*64)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # You can also nest context managers for more granular detail
    for batch_idx, (data, target) in enumerate(train_loader):
        with torch_xla.step():
            with xp.Trace('train_step_data_prep_and_forward'):
                optimizer.zero_grad()
                data, target = data.to(device), target.to(device)
                output = model(data)

            with xp.Trace('train_step_loss_and_backward'):
                loss = loss_fn(output, target)
                loss.backward()

            with xp.Trace('train_step_optimizer_step_host'):
                optimizer.step()

Complete example
----------------

The following example shows how to capture a trace from a PyTorch XLA script,
based on the
[mnist_xla.py](https://gist.github.com/vfdev-5/70f695e462443685a0922e79ce0ee899#file-mnist_xla-py)
file.

    import torch
    import torch.nn.functional as F  # used by the model's forward() shown above
    import torch.optim as optim
    from torchvision import datasets, transforms

    # PyTorch/XLA specific imports
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.profiler as xp

    def train_mnist():
        # ... (model definition and data loading code) ...
        print("Starting training...")
        # ... (training loop as defined in the previous section) ...
        print("Training finished!")

    if __name__ == '__main__':
        # 1. Start the profiler server
        server = xp.start_server(9012)

        # 2. Start capturing the trace and define the output directory
        xp.start_trace('/root/logs/')

        # Run the training function that contains custom trace labels
        train_mnist()

        # 3. Stop the trace
        xp.stop_trace()

Visualize the trace
-------------------

When your script has finished, the trace files are saved in the directory you
specified (for example, `/root/logs/`). You can visualize this trace using
[XProf and TensorBoard](https://github.com/openxla/xprof).

1. **Install TensorBoard.**

    ```bash
    pip install tensorboard_plugin_profile tensorboard
    ```
2. **Launch TensorBoard.** Point TensorBoard to the log directory you used in
   `xp.start_trace()`:

    ```bash
    tensorboard --logdir /root/logs/
    ```
3. **View the profile.** Open the URL provided by TensorBoard in your web
   browser (usually `http://localhost:6006`). Navigate to the
   **PyTorch XLA - Profile** tab to view the interactive trace. There you can
   see the custom labels you created and analyze the execution time of
   different parts of your model.

If you use Google Cloud to run your workloads, we recommend the
[cloud-diagnostics-xprof tool](https://github.com/AI-Hypercomputer/cloud-diagnostics-xprof).
It provides a streamlined profile collection and viewing experience using VMs
running TensorBoard and XProf.
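Step 2 of capturing a trace pairs `xp.start_trace()` with `xp.stop_trace()`. If
the profiled code raises an exception between the two calls, the
`xp.stop_trace()` call is skipped. One way to guard against that is a small
context manager, sketched here. The name `trace_to` is hypothetical and is not
part of `torch_xla`; the sketch assumes only that the profiler object exposes
`start_trace(log_dir)` and `stop_trace()` as described in this guide.

```python
import contextlib


@contextlib.contextmanager
def trace_to(profiler, log_dir):
    """Trace the enclosed block and stop the trace even if the block raises.

    `profiler` is any object exposing start_trace(log_dir) and stop_trace(),
    for example the torch_xla.debug.profiler module imported as `xp`.
    `trace_to` is a hypothetical helper, not a torch_xla API.
    """
    profiler.start_trace(log_dir)
    try:
        yield log_dir
    finally:
        # Runs on normal exit and on exceptions, so stop_trace() is always called.
        profiler.stop_trace()


# Hypothetical usage inside the main block of the complete example:
#
#     server = xp.start_server(9012)
#     with trace_to(xp, '/root/logs/'):
#         train_mnist()
```

Passing the profiler in as an argument keeps the helper independent of
`torch_xla`, so it can be exercised with a stand-in object on a machine without
an XLA device.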
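Before launching TensorBoard, it can help to confirm that the script actually
wrote trace files. The following sketch uses only the Python standard library to
search the log directory recursively. The `.xplane.pb` suffix is an assumption
about how the profiler names its output, and `find_trace_files` and
`describe_traces` are hypothetical helper names; adjust the suffix if your files
are named differently.

```python
from pathlib import Path


def find_trace_files(log_dir, suffix=".xplane.pb"):
    """Return files under log_dir whose names end with suffix, sorted by path.

    The default suffix is an assumption about the profiler's output naming;
    pass a different one if your trace files are named otherwise.
    """
    root = Path(log_dir)
    if not root.is_dir():
        return []
    return sorted(p for p in root.rglob(f"*{suffix}") if p.is_file())


def describe_traces(log_dir):
    """Build a short human-readable summary of the trace files found."""
    files = find_trace_files(log_dir)
    if not files:
        return f"No trace files found under {log_dir}"
    lines = [f"{len(files)} trace file(s) under {log_dir}:"]
    for p in files:
        # Show each path relative to the log directory, with its size in bytes.
        lines.append(f"  {p.relative_to(log_dir)} ({p.stat().st_size} bytes)")
    return "\n".join(lines)
```

For example, calling `print(describe_traces('/root/logs/'))` after the script
finishes lists each matching file with its size, or reports that none were
found.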