PyTorch XLA performance profiling
Overview
This guide walks you through how to use Cloud TPU performance tools and the metrics auto-analysis feature with PyTorch. These tools help you debug and optimize training workload performance.
For more information about Cloud TPU performance with PyTorch, see the following blog posts:
- Scaling deep learning PyTorch workloads
- PyTorch XLA performance debugging on TPU VMs - part 1
- PyTorch XLA performance debugging on TPU VMs - part 2
- PyTorch XLA performance debugging on TPU VMs - part 3
- Lazy tensor performance with PyTorch XLA
If you are new to PyTorch / XLA, refer to the PyTorch/XLA API guide and troubleshooting docs. For Cloud TPU, refer to the concepts document.
TPU Node + PyTorch/XLA profiling
Create Cloud TPU-related resources
Create and initialize required Cloud TPU-related resources.
Create variables for your project ID, your Cloud Storage bucket, and the zone to use for your TPU resources.
export PROJECT_ID=PROJECT_ID
export BUCKET_NAME=BUCKET_NAME
export ZONE=ZONE

gcloud --project=$PROJECT_ID compute project-info add-metadata \
  --metadata BUCKET_NAME=$BUCKET_NAME
Create a Compute Engine VM instance. This is where all your Python scripts and models are stored.
gcloud compute instances create profiler-tutorial-vm \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --machine-type=n1-standard-16 \
  --image-project=ml-images \
  --image-family=torch-xla \
  --boot-disk-size=300GB \
  --scopes=https://www.googleapis.com/auth/cloud-platform
Create a TPU resource.
gcloud compute tpus create profiler-tutorial-tpu \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --network=default \
  --version=pytorch-1.8 \
  --accelerator-type=v3-8
Create a Cloud Storage bucket.
First install the gsutil CLI if you do not have it installed already: installation instructions. Use the gsutil mb command to create a Cloud Storage bucket where all the profiling artifacts are stored. Replace REGION and BUCKET_NAME with the values you will use for your training.

gsutil mb -p ${PROJECT_ID} -c standard -l REGION gs://${BUCKET_NAME}

where:
- REGION is the region where you created the Cloud TPU, for example europe-west4.
Create a Service Account for the Cloud TPU project.
gcloud beta services identity create --service tpu.googleapis.com --project $PROJECT_ID
The command returns a Cloud TPU Service Account with the following format:
service-PROJECT_NUMBER@cloud-tpu.iam.gserviceaccount.com
For example: service-164006649440@cloud-tpu.iam.gserviceaccount.com
Export the service account and grant service account permissions on the storage bucket. Replace ACCOUNT_NUMBER with the PROJECT_NUMBER returned in the service account creation output.

export SERVICE_ACCOUNT=service-ACCOUNT_NUMBER@cloud-tpu.iam.gserviceaccount.com

gsutil acl ch -u $SERVICE_ACCOUNT:READER gs://${BUCKET_NAME}
gsutil acl ch -u $SERVICE_ACCOUNT:WRITER gs://${BUCKET_NAME}
Set up TensorBoard
ssh to your VM, forwarding port 9001 on your VM to port 9001 on your local machine. This port is used to open the TensorBoard UI in your local browser.

gcloud compute ssh profiler-tutorial-vm \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --ssh-flag="-4 -L 9001:localhost:9001"
Create a conda environment dedicated to the TensorBoard installation:
conda create -y -n tensorboard python=3.6
conda activate tensorboard
pip install tf-nightly==2.6.0.dev20210511 tb-nightly==2.6.0a20210512 tbp-nightly==2.5.0a20210511
Test your installation by running the TensorBoard server on your Compute Engine VM and then trying to connect to the server by visiting http://localhost:9001/#profile on your local machine:
# Get bucket name
BUCKET_NAME=$(curl "http://metadata.google.internal/computeMetadata/v1/project/attributes/BUCKET_NAME" -H "Metadata-Flavor: Google")
tensorboard --logdir gs://${BUCKET_NAME} --port 9001
When you visit http://localhost:9001/#profile on your local machine, you should see something like this:
If you visit http://localhost:9001 you can also access the above profile page by selecting the PROFILE option on the dropdown on the top-right corner next to the UPLOAD button:
Profile the model
To keep the TensorBoard server alive, from your local machine, start a new terminal window and ssh onto the GCE VM again (this time without using the -L option for port forwarding).
In the new terminal window, export your project ID, storage bucket environment, and zone variables again, since this is in a new shell.
export PROJECT_ID=PROJECT_ID export ZONE=ZONE
ssh into the VM:

gcloud compute ssh profiler-tutorial-vm \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
conda activate torch-xla-1.8
PROJECT_ID=$(curl "http://metadata.google.internal/computeMetadata/v1/project/project-id" -H "Metadata-Flavor: Google")
export TPU_IP_ADDRESS=$(gcloud compute tpus describe profiler-tutorial-tpu --zone=${ZONE} --project=${PROJECT_ID} \
  --format="value(ipAddress)")
echo TPU_IP_ADDRESS=${TPU_IP_ADDRESS}
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
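Optionally, before running the full test below, you can sanity-check that PyTorch/XLA can see the TPU. This quick check is not part of the original tutorial; it just confirms the XRT_TPU_CONFIG setup:

python <<EOF
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
print(dev)  # for example: xla:1
# Run a tiny computation on the TPU to confirm the device is reachable.
print(torch.ones(2, 2, device=dev) + 1)
EOF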
Verify that integration tests are working end-to-end in your environment:
python /usr/share/torch-xla-1.8/pytorch/xla/test/test_profiler.py # takes <1 min
Before starting to train, edit the following lines in /usr/share/torch-xla-1.8/pytorch/xla/test/test_profile_mp_mnist.py:

Change:

accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)

To:

accuracy = train_mnist(flags, dynamic_graph=False, fetch_often=False)
The two arguments to train_mnist artificially cause dynamic graphs and tensor fetches, which are explored later in the Auto-metrics analysis section. For now, since you are just profiling the TPU, the following example runs with nominal performance.

Start a training run that is used for server profiling:
PT_XLA_DEBUG=1 XLA_HLO_DEBUG=1 python /usr/share/torch-xla-1.8/pytorch/xla/test/test_profile_mp_mnist.py --num_epochs 1000 --fake_data
TPU (server) profiling
Once the training is running, visit http://localhost:9001/#profile and capture a profile using the following instructions:
The following page is automatically reloaded:
Supported tools are shown under the Tools dropdown on the left pane:
- Overview page (this does not include the input pipeline analyzer, because the input pipeline is fundamentally different between TensorFlow/TPU and PyTorch / XLA TPU).
- Memory viewer
- Op Profile
- Pod viewer
- TensorFlow stats (framework-level stats, that is, PyTorch stats)
- Trace viewer (requires the Chrome browser)
Overview page
This page shows an overview of a captured profile. This example shows very high idle time because it is training a tiny model on the MNIST dataset.
Memory viewer
Shows the device memory (HBM) used per tensor and HLO PyTorch op. The memory viewer captures a view of memory per HLO module, so there will be some additional modules, such as the device data allocation graphs (for the input and label batches sent to the device). To view the memory usage of a particular HLO module, select the module from the Hosts dropdown on the left:
Once you are viewing your selected HLO module, you get a holistic view of the module execution HBM footprint timeline. This is ordered by allocation size, program execution size, and padding size.
Each of the buffer allocations can then be further inspected by hovering over it. For example, to see which allocation is taking up most of the device HBM:
In the above example, (1) corresponds to the torch_xla.debug.profiler.Trace annotations added by the user code. Inspecting the test_profile_mp_mnist.py code, it corresponds to this line:

class MNIST(nn.Module):
  ...
  def forward(self, x):
    with xp.Trace('conv1'):
      x = F.relu(F.max_pool2d(self.conv1(x), 2))
      x = self.bn1(x)
    ...
Also, from the test_mnist op namespace, you can tell that this HLO module corresponds to the eval loop, as it is wrapped in the xp.Trace('test_mnist') context manager.
XLA Op Profile
Op profile is a Cloud TPU tool that displays performance statistics of XLA operations executed during a profiling period. The op profile shows:
- How well your application uses the Cloud TPU as a percentage of time spent on operations by category and of TPU FLOPS utilization.
- The most time-consuming operations. These operations are potential targets for optimization.
- Details of individual operations, including shape, padding and expressions that use the operation.
You can use the op profile to find good targets for optimization. For example, if your model achieves only 5% of the TPU peak FLOPS, you can use the tool to identify which XLA operations are taking the longest time to execute and how many TPU FLOPS they consume.
The op profile page contains the following elements:
- Overview section. Shows Cloud TPU utilization and provides suggestions for optimization.
- Control panel. Contains controls that allow you to set the number of operations displayed in the table, which operations are displayed, and how they are sorted.
- Op table. A table that lists the top TensorFlow operation categories associated with the XLA ops. These operations are sorted by percentage of Cloud TPU usage.
- Op details cards. Details about the op that appear when you hover over an op in the table. These include the FLOPS utilization, the expression in which the op is used, and the op layout (fit).
Pod viewer
See TPU tools for a full description of the Pod viewer tool.
Framework stats (TensorFlow/PyTorch stats)
Framework stats provides a detailed breakdown of PyTorch and XRT op statistics running on TPU devices and TPU hosts.
Trace viewer
Trace viewer is a Cloud TPU performance analysis tool. The trace viewer uses the Chrome trace event profiling viewer so it requires use of the Chrome browser.
Trace viewer displays a timeline that shows:
- Durations for the operations that were executed by your model.
- Which part of the system (TPU or host machine) executed an operation. For PyTorch / XLA typically, the host machine primarily works on the compilation and buffer allocation/deallocation, whereas the TPU executes the actual model training.
Trace viewer allows you to identify performance problems in your model, then take steps to resolve them. Drilling down, you can identify which PyTorch / XLA operations are taking the longest to execute.
Note that you can directly add traces to measure how long certain parts of your model take to execute by adding xp.Trace(NAME) annotations; a minimal sketch follows the list below.
For example, the following trace shows:
- Generated by explicit user annotations in the model code of test_profile_mp_mnist.py.
- PyTorch ops executed (pre-lowering).
- PyTorch / XLA Auto-generated HLO module name.
- XLA Ops executed on device (fused).
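As noted above, wrapping parts of your model or training step in xp.Trace annotations makes them appear as named spans like these in the trace viewer. The following is a minimal, self-contained sketch; the layer size and trace names ('forward', 'loss') are illustrative and not taken from the tutorial script:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

dev = xm.xla_device()
model = nn.Linear(128, 10).to(dev)
data = torch.randn(32, 128).to(dev)

# Each annotated region shows up as a named span in the trace viewer
# when a profile is being captured.
with xp.Trace('forward'):
  out = model(data)
with xp.Trace('loss'):
  loss = out.sum()
xm.mark_step()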
For more detailed information, refer to the generic TPU documentation for the trace viewer, but ignore sections around input pipeline and other TensorFlow specific parts as they are not relevant in the context of this document.
PyTorch / XLA client profiling
Similar to when you profiled the TPU side while the model execution was ongoing, now you will profile the PyTorch / XLA client side while training. The main monitoring tool used on the client side is the Trace viewer.
You must start up the profiling server in your training script (the test_profile_mp_mnist.py script used in this tutorial does this); you can then query the server from TensorBoard to capture a trace.
To capture traces from multiple processes, each process can start a profiling server on a different port (for example, by adding xm.get_ordinal() to a base port number); you then provide a list of localhost:port entries separated by commas. TensorBoard does not support viewing traces from multiple processes at one time, so you will see a different Host dropdown for each process.
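A minimal sketch of per-process server startup under xmp.spawn is shown below. The base port 9012 is an arbitrary choice, and the only_on_master flag is an assumption; check the xp.start_server signature in your torch_xla version:

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # Give each process its own profiler port: 9012, 9013, 9014, ...
  port = 9012 + xm.get_ordinal()
  # Keep a reference to the server object so it stays alive.
  server = xp.start_server(port, only_on_master=False)  # assumed kwarg; drop it if your version has no such flag
  device = xm.xla_device()
  # ... training loop for this process, using `device` ...

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

You would then point TensorBoard at, for example, localhost:9012,localhost:9013,... when capturing.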
The following diagram shows a sample trace:
Similar to how namespace traces were added for the TPU-side traces, you can use the same API (xp.Trace(NAME)) to add them to the client-side traces. Note that because this model is small and uses small MNIST images, the step times are short and not necessarily uniform. As an exercise, you can try adding traces and starting up a profiler server, similar to the one in our example, in test_train_mp_imagenet.py --fake_data to get ResNet50 traces.
The traces have additional metadata that can be inspected. For example, TransferToServer and TransferFromServer traces show the exact number of tensors being sent and received and their total size:
For XLA graph compilations, you can see the graph hash that can be helpful in diagnosing problems:
Additionally, instead of profiling through the TensorBoard UI, PyTorch / XLA also provides an API, xp.trace(), for programmatically capturing profiles of both the TPU and the client.
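For example, with a profiler server already running in the training process (port 9012 here matches the assumption in the earlier sketch), a separate Python process could capture a trace programmatically. The log directory and duration below are illustrative; check the xp.trace docstring in your torch_xla version for the full signature:

import torch_xla.debug.profiler as xp

# Capture roughly 10 seconds of profile data from the server at
# localhost:9012 and write the artifacts where your TensorBoard
# instance reads its logdir (for example, the Cloud Storage bucket
# used earlier in this tutorial).
xp.trace('localhost:9012', '/tmp/tensorboard', duration_ms=10000)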
Auto-metrics analysis
In this section, you see how to use debugging mode to detect performance issues, such as:
- Dynamic graphs / continuous compilations
- Very slow graph compilation
- Very slow graph execution
- Frequent XLA→CPU transfers
- Repeated device HBM to host RAM swapping
- Repeated HBM defragmentation
- Unlowered aten:: ops
Before starting training, revert the following lines in /usr/share/torch-xla-1.8/pytorch/xla/test/test_profile_mp_mnist.py:

Change:

accuracy = train_mnist(flags, dynamic_graph=False, fetch_often=False)

To:

accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)
These changes artificially cause compilations and tensor fetches. dynamic_graph=True changes the batch size at each step, causing the XLA lowered graphs to differ at every step and triggering recompilation. fetch_often=True inserts loss.item() calls at every step, which fetches tensor values from the device at each step and slows performance.
Run the example training script:
PT_XLA_DEBUG=1 python /usr/share/torch-xla-1.8/pytorch/xla/test/test_profile_mp_mnist.py --fake_data --num_cores=1
When debugging, best practice is to run with --num_cores=1
as it simplifies the debugging process. Some of the sample output looks
like this:
Epoch 1 train begin 01:18:05
| Training Device=xla:1/0 Step=0 Loss=0.00000 Rate=1905.00 GlobalRate=1904.72 Time=01:18:05
pt-xla-profiler: TransferFromServerTime too frequent: 3 counts during 3 steps
pt-xla-profiler: TransferFromServerTime too frequent: 4 counts during 4 steps
pt-xla-profiler: TransferFromServerTime too frequent: 5 counts during 5 steps
pt-xla-profiler: TransferFromServerTime too frequent: 6 counts during 6 steps
pt-xla-profiler: TransferFromServerTime too frequent: 7 counts during 7 steps
pt-xla-profiler: TransferFromServerTime too frequent: 8 counts during 8 steps
pt-xla-profiler: TransferFromServerTime too frequent: 9 counts during 9 steps
pt-xla-profiler: TransferFromServerTime too frequent: 10 counts during 10 steps
pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromServerTime too frequent: 11 counts during 11 steps
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
Lines with the prefix pt-xla-profiler correspond to the auto-metrics analysis output. In this example, you can see that TransferFromServerTime was seen too frequently, once per step. This is due to the training loop retrieving the value of loss.item() at every step. You might also see a CompileTime too frequent warning when your model has to be recompiled repeatedly because of dynamic shapes in the graph. The following code snippet causes this type of problem:
test_profile_mp_mnist.py
for step, (data, target) in enumerate(loader):
  if dynamic_graph:
    # The batch dimension is different every step.
    index = max(-step, -flags.batch_size + 1)  # non-empty
    data, target = data[:-index, :, :, :], target[:-index]
  ...
  if fetch_often:
    # Fetch tensor value from XLA:TPU to CPU every step.
    loss_i = loss.item()
You can then press Ctrl+C to exit the training script, and at the end you should see a summary of the unlowered ops. Note that aten::_local_scalar_dense is a special op that corresponds to retrieving XLA tensors back to the CPU context.
In this report, you can see that there are two main places where the aten::_local_scalar_dense op is being called; both correspond to the source code of loss.item():

- test/test_profile_mp_mnist.py:158
- test/test_profile_mp_mnist.py:61
pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=27):
pt-xla-profiler: Unlowered Op: "_local_scalar_dense"
pt-xla-profiler: Python Frames:
pt-xla-profiler:   train_loop_fn (test/test_profile_mp_mnist.py:158)
pt-xla-profiler:   train_mnist (test/test_profile_mp_mnist.py:184)
pt-xla-profiler:   _mp_fn (test/test_profile_mp_mnist.py:206)
pt-xla-profiler:   _start_fn (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py:323)
pt-xla-profiler:   spawn (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py:386)
pt-xla-profiler:   (test/test_profile_mp_mnist.py:216)
pt-xla-profiler:
pt-xla-profiler:
pt-xla-profiler: FRAME (count=2):
pt-xla-profiler: Unlowered Op: "_local_scalar_dense"
pt-xla-profiler: Python Frames:
pt-xla-profiler:   _train_update (test/test_profile_mp_mnist.py:61)
pt-xla-profiler:   (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/core/xla_model.py:700)
pt-xla-profiler:   _run_step_closures (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/core/xla_model.py:709)
pt-xla-profiler:   mark_step (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/core/xla_model.py:723)
pt-xla-profiler:   __exit__ (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/debug/profiler.py:153)
pt-xla-profiler:   train_loop_fn (test/test_profile_mp_mnist.py:162)
pt-xla-profiler:   train_mnist (test/test_profile_mp_mnist.py:184)
pt-xla-profiler:   _mp_fn (test/test_profile_mp_mnist.py:206)
pt-xla-profiler:   _start_fn (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py:323)
pt-xla-profiler:   spawn (/home/jysohn/git/jysohn23/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py:386)
pt-xla-profiler:   (test/test_profile_mp_mnist.py:216)
pt-xla-profiler:
pt-xla-profiler:
pt-xla-profiler: ================================================================================
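Since both call sites come from loss.item(), one common mitigation is to fetch the loss only occasionally and to do so through a step closure, which defers the host transfer until the step's graph has been dispatched. The following is a minimal, self-contained sketch; the model, data, and 20-step logging interval are illustrative and not part of the tutorial script:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

def _log_loss(loss_tensor, step):
  # Runs as a step closure after xm.mark_step(), so the .item() call
  # does not force an extra graph execution in the middle of the step.
  print('step {} loss {:.4f}'.format(step, loss_tensor.item()))

dev = xm.xla_device()
model = nn.Linear(10, 2).to(dev)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(100):
  data = torch.randn(8, 10).to(dev)
  target = torch.randint(0, 2, (8,)).to(dev)
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()
  if step % 20 == 0:
    # Only fetch the loss every 20 steps, and via a step closure.
    xm.add_step_closure(_log_loss, args=(loss, step))
  xm.mark_step()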
Now run auto-metrics analysis on the script below, which contains an op that was unlowered as of the 1.8 release, _ctc_loss:
PT_XLA_DEBUG=1 python <<EOF
import torch
import torch_xla.core.xla_model as xm
dev = xm.xla_device()
t = torch.randn(50, 16, 20).log_softmax(2).to(dev)
target = torch.randint(low=1, high=20, size=(16, 30), dtype=torch.long).to(dev)
input_lengths = torch.full(size=(16,), fill_value=50, dtype=torch.long).to(dev)
target_lengths = torch.randint(low=10, high=30, size=(16,), dtype=torch.long).to(dev)
for _ in range(10):
loss = torch.nn.CTCLoss()(t, target, input_lengths, target_lengths)
xm.mark_step()
EOF
When you run the above script with PT_XLA_DEBUG=1, the output should look something like the following:
…
pt-xla-profiler: TransferFromServerTime too frequent: 30 counts during 10 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=10):
pt-xla-profiler: Unlowered Op: "_ctc_loss"
pt-xla-profiler: Python Frames:
pt-xla-profiler:   ctc_loss (/anaconda3/envs/torch-xla-1.8/lib/python3.6/site-packages/torch/nn/functional.py:2305)
pt-xla-profiler:   forward (/anaconda3/envs/torch-xla-1.8/lib/python3.6/site-packages/torch/nn/modules/loss.py:1593)
pt-xla-profiler:   _call_impl (/anaconda3/envs/torch-xla-1.8/lib/python3.6/site-packages/torch/nn/modules/module.py:889)
pt-xla-profiler:   (<stdin>:11)
pt-xla-profiler:
pt-xla-profiler:
pt-xla-profiler: ================================================================================
The auto-metrics analyzer shows that line 11 of STDIN is causing the unlowered op (that is, the line with torch.nn.CTCLoss()). Currently, the ctc_loss op has not been lowered, which is why you see the above report. You can also see some warnings for TransferFromServerTime, because the tensors initially reside on the XLA:TPU device, but since the op is not lowered, the XLA tensors must first be transferred back to the CPU, the aten:: op executed on the CPU, and the result transferred back to the device.
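Conceptually, the fallback for an unlowered op is equivalent to doing that round trip yourself, which is why each call incurs device-to-host and host-to-device transfers. A rough, illustrative equivalent for the script above (this is not literally what PyTorch / XLA executes internally):

# Move the inputs to CPU, run the unlowered op there, then move the
# result back to the XLA device.
loss = torch.nn.CTCLoss()(
    t.cpu(), target.cpu(), input_lengths.cpu(), target_lengths.cpu()).to(dev)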
If you would like to write the pt-xla-profiler output to a file instead, set PT_XLA_DEBUG=1 and PT_XLA_DEBUG_FILE=$PATH_TO_FILE.
Cleanup
Exit from your VM and then delete the TPU, VM, and Cloud Storage bucket by running the following commands:
(vm)$ exit
gcloud compute instances delete profiler-tutorial-vm \
  --zone=${ZONE} \
  --project=${PROJECT_ID}
gcloud compute tpus delete profiler-tutorial-tpu \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --async
gsutil rm -fr gs://${BUCKET_NAME}
TPU VM + PyTorch/XLA profiling
Use this section to profile PyTorch/XLA using the TPU VM architecture.
Export Environment Variables
Create variables for your project ID and the zone to use for your TPU resources.
export PROJECT_ID=PROJECT_ID
export ZONE=ZONE
Create a Cloud TPU
Refer to the TPU VM user guide. After setup, create a v3-8 TPU VM, which comes with torch, torch_xla, torchvision, and tensorboard preinstalled.
Create a TPU resource.
gcloud compute tpus tpu-vm create profiler-tutorial-tpu-vm \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --version=v2-alpha \
  --accelerator-type=v3-8
TensorBoard server startup
SSH onto the VM, install tensorboard-plugin-profile, and start up a TensorBoard server.

gcloud compute tpus tpu-vm ssh profiler-tutorial-tpu-vm \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --ssh-flag="-4 -L 9001:localhost:9001"
pip3 install tf-nightly==2.6.0.dev20210511 tb-nightly==2.6.0a20210512 tbp-nightly==2.5.0a20210511
tensorboard --logdir ./tensorboard --port 9001
When you view the TensorBoard output at http://localhost:9001 on your local machine, you should see something like this:
If you view the TensorBoard output at http://localhost:9001 you can also access the above profile page by selecting the PROFILE option on the dropdown on the top-right corner next to the UPLOAD button:
Profile the model
In a new terminal window in your development environment, export the same environment variables as above and ssh onto your TPU VM.

Export your project ID and zone variables again, since this is a new shell.

export PROJECT_ID=PROJECT_ID
export ZONE=ZONE

ssh into the VM:

gcloud compute tpus tpu-vm ssh profiler-tutorial-tpu-vm \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Clone the PyTorch/XLA repository and run our e2e test:
git clone -b r1.8 https://github.com/pytorch/xla
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
python3 xla/test/test_profiler.py  # takes <1 min
Before starting the training, edit the following lines in xla/test/test_profile_mp_mnist.py:

Change:

accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)

To:

accuracy = train_mnist(flags, dynamic_graph=False, fetch_often=False)
The two arguments to train_mnist artificially cause dynamic graphs and tensor fetches, which are explored later in the Auto-metrics analysis section. For now, since you are just profiling the TPU, the following example runs with nominal performance.

Start a training run:
XLA_HLO_DEBUG=1 python3 xla/test/test_profile_mp_mnist.py --num_epochs 1000 --fake_data
TPU + Client Profiling
Once the training is running, view the TensorBoard output at http://localhost:9001 and capture a profile using the following instructions:
You should then see the following page reloaded:
Currently, in the TPU VM setup, only the trace viewer tool is available, so under the Tools dropdown, select trace_viewer and inspect the traces. Note that in the TPU VM setup, you see both the "client" side and the TPU device side traces in one full view:
Cleanup
Exit from your VM and then delete the TPU VM by running the following commands:
(vm)$ exit
Delete the TPU VM you created:
$ gcloud compute tpus tpu-vm delete profiler-tutorial-tpu-vm \
  --project ${PROJECT_ID} \
  --zone=${ZONE}
Verify the resources have been deleted by running the following command. The deletion might take several minutes. A response like the one below indicates your instances have been successfully deleted.
$ gcloud compute tpus tpu-vm list --project ${PROJECT_ID} --zone=${ZONE}
Listed 0 items.