PyTorch/XLA: Performance debugging on Cloud TPU VM: Part I
Vaibhav Singh
Group Product Manager
In this three-part series we explore the performance debugging ecosystem of PyTorch/XLA on Google Cloud TPU VM. The TPU VM architecture was introduced last year (2021). It allows ML practitioners to work directly on the host to which the TPU hardware is attached. With the TPU profiler, debugging your PyTorch training on TPU VM is simpler than ever before. While the process of analyzing performance has changed, the fundamentals of PyTorch/XLA that you acquired with the network-attached TPU architecture (aka TPU Node architecture) still apply.
In this (first) part we will briefly lay out the conceptual framework for PyTorch/XLA in the context of training performance. Please note that training performance in the current scope refers to training throughput, i.e. samples/sec, images/sec or equivalent. We use a case study to make sense of preliminary profiler logs and identify the corrective actions. The fix for the performance bottleneck is left as an exercise for the reader.
In part II of this series we will discuss the solution left as an exercise in part I and analyze the performance further to identify other improvement opportunities.
Finally, in part-III, we introduce the user defined code annotation. We will see how to visualize these annotations in the form of a trace and introduce some basic concepts to understand the trace.
By the end of this series, we aim to give you a better understanding of how to analyze performance of your PyTorch code on Cloud TPUs and things to consider when working with Cloud TPUs.
Pre-Reading
An understanding of the inner workings of XLA Tensors can make the following content more accessible and useful. We encourage you to review this talk from PyTorch Developers Day 2020 and this talk from Google Cloud Next for a quick primer on XLA Tensors. You may also find this article helpful if you are new to PyTorch/XLA. This article also assumes that the reader is familiar with the Google Cloud Platform SDK and has access to a Google Cloud project with permissions to create resources such as virtual machines and Cloud TPU instances. Most of the profiler concepts will be explained here; however, an introductory reading of the TPU VM Profiler guide is also recommended.
Client-Server Terminology for PyTorch/XLA
As in the TPU Node architecture (before TPU VM), PyTorch/XLA still uses the lazy tensor paradigm, i.e. when you are using XLA Tensors, any operations performed on these tensors are simply recorded in an intermediate representation (IR) graph. When a step is marked (xm.mark_step() call), this graph is converted to XLA (HLO format - High Level Operations) and dispatched for execution to the TPU runtime (server).
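The record-then-dispatch behavior described above can be sketched with a toy model in plain Python. This is not the PyTorch/XLA implementation: the names ToyGraph and ToyLazyTensor are hypothetical, and the "graph" is just a list of op names. It only illustrates that operations are recorded when they are called and that nothing is dispatched until the step is marked.

```python
class ToyGraph:
    """Toy stand-in for the client side: records ops instead of running them."""

    def __init__(self):
        self.pending = []      # op names recorded since the last dispatch
        self.dispatched = []   # one list of op names per dispatched graph

    def mark_step(self):
        # Hand the recorded graph over for execution and start a new one.
        if self.pending:
            self.dispatched.append(list(self.pending))
            self.pending.clear()


class ToyLazyTensor:
    def __init__(self, graph, thunk):
        self.graph = graph
        self.thunk = thunk     # deferred computation, not run when recorded

    def _record(self, name, thunk):
        self.graph.pending.append(name)
        return ToyLazyTensor(self.graph, thunk)

    def __add__(self, other):
        return self._record("add", lambda: self.thunk() + other.thunk())

    def __mul__(self, other):
        return self._record("mul", lambda: self.thunk() * other.thunk())


graph = ToyGraph()
a = ToyLazyTensor(graph, lambda: 2.0)
b = ToyLazyTensor(graph, lambda: 3.0)

c = a * b + a                            # only recorded, nothing is computed
recorded_before = list(graph.pending)    # ops waiting for a step marker

graph.mark_step()                        # the recorded graph is dispatched
print(recorded_before, graph.dispatched, graph.pending)
```

In the real library the dispatched graph is compiled (with caching) and run by the TPU runtime; the toy model skips both.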
Note that the TPU runtime is part of the TPU server-side functionality, and all the work done up to the generation of the HLO graph is part of (and henceforth referred to as) the client-side functionality. Unlike the previous generation, where the TPU runtime (server) was started automatically when you created a TPU instance, in the case of TPU VM the PyTorch/XLA library takes care of starting the server when you submit a training job. You can also start the XRT (XLA Runtime) server manually on the desired port. The XRT_TPU_CONFIG set in the code snippets later in the post refers to the default port where PyTorch/XLA starts the XRT server. Also unlike the previous generation, the client and server run on the same host; however, the abstractions still hold and are helpful for understanding performance (more details here).
Case Study
Context
We will examine UniT (Unified Transformer) training on the GLUE/QNLI task using the MMF framework for multi-modal learning from Facebook Research. We will discover an interesting aspect of the Multihead Attention implementation (observed in PyTorch 1.8) that incidentally results in sub-optimal training performance with PyTorch/XLA, and discuss a potential corrective action.
Environment Setup
The case study uses TPU VM. In the following steps we create a TPU VM. The following commands can be run from Google Cloud Shell or any machine with the Google Cloud SDK installed and the correct credentials provisioned. (For more details please refer to TPU VM user guide.)
Once the TPU VM is created and is in the READY state, log in (ssh) to the TPU VM host, install the TensorBoard profiler plugin and start the TensorBoard server. Please refer to the instructions in the TPU VM profiler user guide to set up the environment.
Training Setup
We will use two PyTorch environments in the case study, beginning with PyTorch 1.8.1 and moving to PyTorch 1.9 as we progress. To ensure PyTorch 1.8.1 is the starting point, please execute the following instructions on the TPU VM created in the previous section.
Update alternative (make python3 default):
Configure environment variables:
MMF Training Environment
The MMF (Multimodal Training Framework) library, developed by Meta Research, is built to help researchers easily experiment with models for multi-modal (text/image/audio) learning problems. As described in the case study context, we will use the Unified Transformer (UniT) model for this case study. We will begin by cloning and installing the mmf library (a specific hash is chosen for reproducibility).
Before we install the mmf library in developer mode, please make the following modifications to requirements.txt so that the existing PyTorch environment is not overridden when mmf is installed. To apply the patch, copy the text in the following box into a file, e.g. patch-1.txt, and run git apply patch-1.txt from the mmf directory:
Apply the following patch (using git apply as explained above) for the validate_batch_sizes method (specific to the commit selected for this article):
Install the mmf library in developer mode:
Debugging Basics
In order to understand the slow training we try to answer the following three questions:
Does the number of XLA compilations grow linearly with the number of training steps?
Does the number of device-to-host context switches grow linearly with the number of training steps?
Does the model use any op which does not have an XLA lowering?
To answer these questions, PyTorch/XLA provides a few tools. The quickest way to find these metrics/counters is to enable client side profiling. For a more detailed report you can print metrics_report as explained on the PyTorch/XLA troubleshooting page. PyTorch/XLA client side profiler will often mention one of these metrics in the summary log. Here is an example metrics log:
It’s helpful to establish an understanding of these metrics and counters. Let’s get to know them.
Debug Metrics
CompileTime Metric
A few important fields to notice here are TotalSamples, Counter, and the 50% compilation time. TotalSamples indicates how many times XLA compilation happened, Counter indicates the overall time spent on compilation, and 50%= indicates the median compilation time.
aten::_local_scalar_dense Counter
This counter indicates the number of device-to-host transfers. Once XLA compilation is complete, the graph is executed on the device, and the tensors continue to live on the device until something in the user’s code requires the value of a tensor, which causes a device-to-host transfer. Common examples include .item() calls or a control structure that requires the value, such as if (...).any() statements. If compilation and execution have not yet been done when such a call is encountered, it results in early compilation and evaluation, slowing training further.
aten::<op_name> Counter
This counter indicates the number of times the said op was seen. The prefix aten:: indicates that the default cpu/aten implementation of this op is being used and that an XLA implementation is not available. Since the Intermediate Representation (IR) graph is to be converted to XLA format and executed on the device, the IR graph needs to be truncated at each instance of these ops in the forward pass. The inputs to the op are evaluated on the device and brought to the host, and the op is executed with those inputs. The output of the op is then plugged into the remainder of the graph and execution continues. The performance impact depends on the number of instances and the location of such ops.
TransferFromServerTime Metric
The total number of samples of this metric indicates the number of device-to-host transfers. The detailed metrics report (torch_xla.debug.metrics.metrics_report()) also gives the total time spent in device-to-host transfers (Accumulator value) and various quantiles. Client-side profiling logs report only the count (number of samples). If this value scales rapidly (rate >= 1) with the number of training steps, there are one or more unlowered ops (aten::*) or constructs fetching tensor values in the model or training code.
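The metrics and counters above can also be turned into per-step rates. The sketch below is a minimal helper under stated assumptions: it assumes torch_xla.debug.metrics exposes metric_data(name), returning a (total_samples, accumulator, samples) tuple or None, and counter_value(name), returning an integer or None. The helper name debug_summary and the keys it returns are hypothetical. The metrics module is passed in as an argument so that the logic can be exercised without a TPU.

```python
def debug_summary(met, steps):
    """Express the debugging counters as per-step rates.

    `met` is assumed to behave like the torch_xla.debug.metrics module.
    `steps` is the number of training steps run so far.
    """
    def total_samples(metric_name):
        data = met.metric_data(metric_name)   # assumed: tuple or None
        return data[0] if data else 0

    return {
        # Should flatten out once all graphs are compiled and cached.
        "compiles_per_step": total_samples("CompileTime") / steps,
        # A rate >= 1 suggests unlowered ops or value fetches in the step.
        "transfers_per_step": total_samples("TransferFromServerTime") / steps,
        # Number of scalar reads seen so far (e.g. from .item() calls).
        "scalar_reads": met.counter_value("aten::_local_scalar_dense") or 0,
    }


# Intended use on a TPU VM (not executed here):
#   import torch_xla.debug.metrics as met
#   print(debug_summary(met, steps=1500))
```

Calling the helper at two different step counts and comparing the rates is one way to check whether a quantity grows linearly with the number of steps.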
Interested readers can find the full list of PyTorch/XLA performance metrics and counters as follows:
With the fundamentals discussed thus far, now we are ready to start some experiments to apply the concepts we have learnt.
Experiment-0: Default Run
Once mmf is installed we are ready to start our training of the UniTransformer model on glue/qnli dataset.
Best Practice
Notice that we are using only a single TPU core for the debug run. Notice also that training.log_interval is set to 100. Logging usually involves accessing one or more tensor values, and accessing a tensor value involves a graph evaluation and a device-to-host transfer. If done too frequently, it can add unnecessary overhead to the training time. Therefore, beyond the debug/development stage, higher logging intervals are recommended.
Observation
Once you execute this training, you will notice logs similar to the following snippet:
Notice that for 1500 steps the training takes over 33 minutes, and the updates per sec reported for the final 100 steps is 1.06. Let’s assume you are not impressed with the training speed and would like to investigate. Here’s where the PyTorch/XLA profiler can help.
Experiment-1: Enable Client-Side Profiling
The PT_XLA_DEBUG environment variable enables the client-side debugging functionality. When it is enabled, any part of the user code that can cause frequent recompilations or device-to-host transfers is reported during the training and summarized at the end.
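If you prefer to enable this from inside a Python entry point rather than on the command line, a minimal sketch looks like the following. It assumes that "1" is the value that turns the functionality on and, as a conservative choice, sets the variable before torch_xla is imported.

```python
import os

# PT_XLA_DEBUG is the variable named in the text; "1" is the assumed "on" value.
os.environ["PT_XLA_DEBUG"] = "1"

# Import torch_xla (or launch the training entry point) only after this line, e.g.:
#   import torch_xla.core.xla_model as xm
```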
Observation
Once the client side profiling is enabled, you will notice the following messages starting to appear in your training log:
Notice the logs tagged pt-xla-profiler. The profiler reports too-frequent TransferFromServerTime events (translation: device-to-host transfers). Since PyTorch/XLA uses the lazy tensor approach, execution of the PyTorch operation graphs it builds and optimizes is deferred until either a step marker is seen or a tensor value is fetched (a device-to-host transfer). As noted earlier, too many of these transfers add overhead. 36K occurrences in just 1500 steps is certainly worth investigating. Note that this count is expected to grow linearly, by some factor, with the number of steps divided by log_interval. If this factor is greater than the number of tensors being captured in the log, there are device-to-host transfers beyond what is being logged, and they can be reduced to gain performance. This is also why increasing the log interval almost always helps.
Notice also that at the completion of the training (or when training is interrupted) the profiler provides further summaries, including a stack trace and frame count (which refers to graph diffs). The “equal” operator with count 9000 appears multiple times and therefore seems to be correlated with the high number of device-to-host transfers. Notice also “_local_scalar_dense”; since it has fewer occurrences, we will investigate it after examining the “equal” op.
The stack trace points to the following code from the Multihead Attention implementation:
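As a rough sanity check on these numbers, the arithmetic can be written out. The step count, log interval and (rounded) transfer count come from the run above; tensors_per_log is a hypothetical value chosen only for illustration, since the number of tensors actually read per log line depends on the training code.

```python
def transfers_per_step(total_transfers, steps):
    """Average number of device-to-host transfers per training step."""
    return total_transfers / steps


def expected_logging_transfers(steps, log_interval, tensors_per_log):
    """Transfers that logging alone would account for."""
    return (steps // log_interval) * tensors_per_log


observed = 36_000        # "36K occurrences" from the profiler log (rounded)
steps = 1_500
log_interval = 100
tensors_per_log = 10     # hypothetical: tensors read each time a log line is written

rate = transfers_per_step(observed, steps)
budget = expected_logging_transfers(steps, log_interval, tensors_per_log)
unexplained = observed - budget

print(f"{rate:.1f} transfers/step, {budget} explained by logging, "
      f"{unexplained} unexplained")
```

Even with a generous allowance for logging, almost all of the observed transfers remain unexplained, which is why the investigation moves from the logging code to the model code.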
Analysis
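Looking up the torch.equal manual page reveals that torch.equal(input, other) returns a bool: True if the two tensors have the same size and elements, False otherwise.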


This operator returns a scalar (boolean) value. Using this op in an if statement forces PyTorch/XLA to execute the subgraph leading to this scalar value before it can infer the graph where the boolean value is used. Since this code snippet is part of the forward path, the subgraph gets evaluated at every step, once per instance of the torch.equal call. The result of the execution then needs to be transferred to the host (device-to-host transfer) to enable the client to build the rest of the graph. It therefore creates a big bottleneck, not only because of the overhead of the device-to-host transfer but also because it slows down the graph building and compilation pipeline. We call such forced evaluations early or premature evaluations.
Note that the == operator with tensor operands results in a tensor, so the == operator itself can become a part of the graph. Therefore, using the == operator does not result in early evaluation. However, if the value of the resulting tensor affects the resulting graph, i.e. creates a dynamic graph, it can quickly diminish the advantages of the graph compilation approach with caching (to understand more watch this video).
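The difference between the two comparisons can be illustrated with a toy model in plain Python. This is a sketch, not PyTorch/XLA code: ToyRuntime and ToyTensor are hypothetical names, and the "device" is only a pair of counters. The equal method returns a Python bool and therefore has to evaluate immediately, while == returns another lazy object that stays in the recorded graph.

```python
class ToyRuntime:
    """Toy stand-in for the device; counts forced (early) evaluations."""

    def __init__(self):
        self.pending_ops = 0        # ops recorded but not yet executed
        self.early_evaluations = 0  # graph executions forced before mark_step

    def mark_step(self):
        self.pending_ops = 0        # the recorded graph is dispatched here


class ToyTensor:
    def __init__(self, runtime, thunk):
        self.runtime = runtime
        self.thunk = thunk          # deferred computation of the value

    def __eq__(self, other):
        # Like ==: the comparison is recorded and the result stays lazy.
        self.runtime.pending_ops += 1
        return ToyTensor(self.runtime, lambda: self.thunk() == other.thunk())

    def equal(self, other):
        # Like torch.equal: a Python bool is needed now, so the pending
        # graph must be executed and the result brought back to the host.
        self.runtime.early_evaluations += 1
        self.runtime.pending_ops = 0
        return self.thunk() == other.thunk()


def run(steps, use_equal):
    runtime = ToyRuntime()
    query = ToyTensor(runtime, lambda: [1.0, 2.0])
    key = ToyTensor(runtime, lambda: [1.0, 2.0])
    for _ in range(steps):
        if use_equal:
            if query.equal(key):    # host-side branch forces evaluation
                pass
        else:
            _ = query == key        # stays in the graph, no branch on it
        runtime.mark_step()
    return runtime.early_evaluations


forced_with_equal = run(steps=5, use_equal=True)
forced_with_eq = run(steps=5, use_equal=False)
print(forced_with_equal, forced_with_eq)
```

In the toy model the number of forced evaluations grows by one per step for every equal call in the forward path, which mirrors the linear growth in device-to-host transfers reported by the profiler.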
Potential Corrective Action
We also note that this implementation choice of torch.equal for MHA (MultiHead Attention) serves to optimize resulting GPU kernels, and therefore a good solution to allow this optimization without creating a bottleneck for TPUs is to qualify torch.equal calls with some (non-trainable/configuration) parameter. One potential solution is an implementation similar to this one. However, PyTorch 1.9 code fixed the issue by simplifying the implementation by moving to tensor comparison (is op).
Next Steps
In this post we introduced the basic concepts needed to understand PyTorch/XLA performance. We also walked through an experiment with a performance bottleneck due to forced execution caused by the torch.equal op. A potential solution in this instance involves either a change in the PyTorch core code or an upgrade to PyTorch release 1.9. The reader may find these instructions helpful. After the environment is updated, please re-run Experiment-1 and note the new performance log. In the next part of this series we will review the results and develop further insights into the performance.
Until next time, Happy Hacking! Have a question or want to chat? Find me on LinkedIn.




