PyTorch XLA performance profiling


This guide walks you through how to use Cloud TPU performance tools and the metrics auto-analysis feature with PyTorch. These tools help you debug and optimize training workload performance.

For more information about Cloud TPU performance with PyTorch, see the following blog posts:

If you are new to PyTorch/XLA, refer to the PyTorch API_GUIDE and troubleshooting docs. For Cloud TPU, refer to the concepts document.

TPU VM + PyTorch/XLA profiling

Use this section to profile PyTorch/XLA using the TPU VM architecture.

Export environment variables

  1. Create variables for your project ID and the zone to use for your TPU resources.

    export PROJECT_ID=PROJECT_ID
    export ZONE=ZONE
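If you script the later steps, it helps to fail early when these variables are unset or still hold the literal placeholder text. The following is a minimal Python sketch that assumes the variable names PROJECT_ID and ZONE used throughout this guide. `load_tpu_env` is a hypothetical helper and is not part of gcloud or PyTorch/XLA.

```python
import os


def load_tpu_env(names=("PROJECT_ID", "ZONE")):
    """Return the requested environment variables as a dict.

    Hypothetical helper: raises RuntimeError if a variable is unset, empty,
    or still equal to its own name (for example ZONE=ZONE, the unedited
    placeholder from the export commands).
    """
    values = {}
    for name in names:
        value = os.environ.get(name, "").strip()
        if not value or value == name:
            raise RuntimeError(f"Set {name} to a real value before continuing")
        values[name] = value
    return values
```

Call `load_tpu_env()` at the top of a script to get both values before you run any gcloud command.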

Create a Cloud TPU

Refer to Manage TPUs for setup. Then create a v3-8 TPU VM, which comes with torch, torch_xla, torchvision, and tensorboard preinstalled.

  1. Create a TPU resource.

    gcloud compute tpus tpu-vm create profiler-tutorial-tpu-vm \
     --project=${PROJECT_ID} \
     --zone=${ZONE} \
     --version=v2-alpha \
     --accelerator-type=v3-8
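If you drive these steps from Python instead of typing them, building the argument list explicitly avoids shell-quoting problems. The sketch below mirrors the create command, using the v3-8 accelerator type this guide calls for. `tpu_vm_create_cmd` is a hypothetical helper, and the project and zone values in the example call are placeholders.

```python
import shlex


def tpu_vm_create_cmd(name, project, zone,
                      accelerator_type="v3-8", version="v2-alpha"):
    """Build the argv for `gcloud compute tpus tpu-vm create` as a list."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", name,
        f"--project={project}",
        f"--zone={zone}",
        f"--version={version}",
        f"--accelerator-type={accelerator_type}",
    ]


# Placeholder values: substitute your own project ID and zone.
cmd = tpu_vm_create_cmd("profiler-tutorial-tpu-vm", "my-project", "my-zone")
# Print the equivalent shell command so you can review it before running it,
# for example with subprocess.run(cmd, check=True).
print(shlex.join(cmd))
```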

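TensorBoard server startup

  1. SSH into the VM, install tensorboard-plugin-profile, and start a TensorBoard server. The --ssh-flag option forwards port 9001 on your local machine to port 9001 on the VM, so you can reach TensorBoard locally.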

      gcloud compute tpus tpu-vm ssh profiler-tutorial-tpu-vm \
       --project ${PROJECT_ID} \
       --zone ${ZONE} \
       --ssh-flag="-4 -L 9001:localhost:9001"
      pip3 install tf-nightly==2.6.0.dev20210511 tb-nightly==2.6.0a20210512 tbp-nightly==2.5.0a20210511
      tensorboard --logdir ./tensorboard --port 9001
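Before you open a browser, you can check from your local machine that the tunnel and server respond. The following Python sketch assumes that TensorBoard answers a plain HTTP GET on its root URL. `tensorboard_up` is a hypothetical helper name. The helper bypasses any configured HTTP proxy because the target is localhost.

```python
import http.client
import urllib.request

# Opener with an empty ProxyHandler so localhost checks never go through a proxy.
_NO_PROXY_OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def tensorboard_up(url="http://localhost:9001/", timeout=5.0):
    """Return True if an HTTP server answers url with a 2xx status.

    Through the SSH tunnel, a False result can mean either that the tunnel
    is down or that TensorBoard is not yet running on the VM.
    """
    try:
        with _NO_PROXY_OPENER.open(url, timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except (OSError, http.client.HTTPException):
        # URLError and HTTPError are OSError subclasses; refused or reset
        # connections also end up here.
        return False
```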

When you view the TensorBoard output at http://localhost:9001 on your local machine, you should see the TensorBoard profile page.

You can also reach the profile page by selecting PROFILE from the dropdown in the top-right corner, next to the UPLOAD button.


Profile the model

In a new terminal window in your development environment, export the same environment variables as above and SSH into your TPU VM:

  1. In the new terminal window, export your project ID and zone variables again, since this is a new shell.

    export PROJECT_ID=PROJECT_ID
    export ZONE=ZONE
  2. SSH into the VM:

      gcloud compute tpus tpu-vm ssh profiler-tutorial-tpu-vm \
       --project ${PROJECT_ID} \
       --zone ${ZONE}
  3. Clone the PyTorch/XLA repository and run the end-to-end (e2e) test:

      git clone -b r1.8
      export XRT_TPU_CONFIG="localservice;0;localhost:51011"
      python3 xla/test/  # takes <1 min
  4. Before starting the training, edit the following lines in xla/test/


        # Change this line:
        accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True)
        # to:
        accuracy = train_mnist(flags, dynamic_graph=False, fetch_often=False)

    The two train_mnist arguments, dynamic_graph and fetch_often, artificially cause dynamic graphs and tensor fetches, which are explored later in the Auto-metrics Analysis section. Because you are only profiling the TPU for now, setting both to False lets the example run with nominal performance.

  5. Start a training run:

     XLA_HLO_DEBUG=1 python3 xla/test/ --num_epochs 1000 --fake_data
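A toy model can illustrate the effect of the two train_mnist flags. XLA compiles each distinct graph once and reuses the compiled program while the graph, including its tensor shapes, stays the same. A graph whose shapes change from step to step keeps triggering new compilations. Every fetch of a tensor value to the host (for example with `.item()`) forces the pending graph to execute. The pure-Python sketch below is not torch_xla code. `ToyXlaRuntime` and `simulate` are hypothetical. The sketch models a dynamic graph as a batch size that cycles through eight values, which is an assumption made only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class ToyXlaRuntime:
    """Toy model (not torch_xla code): compiled graphs are cached by shape."""
    cache: set = field(default_factory=set)
    compiles: int = 0
    host_fetches: int = 0

    def run_step(self, batch_shape):
        # A new input shape means a new graph signature, so "compile" it once
        # and reuse it for every later step with the same shape.
        if batch_shape not in self.cache:
            self.cache.add(batch_shape)
            self.compiles += 1

    def fetch_to_host(self):
        # Stands in for a call such as loss.item(), which in PyTorch/XLA forces
        # the pending graph to run and copies the value to the host.
        self.host_fetches += 1


def simulate(steps, dynamic_graph, fetch_often, batch_size=128):
    """Return (compilations, host fetches) for a toy training loop."""
    rt = ToyXlaRuntime()
    for step in range(steps):
        # Hypothetical: a "dynamic graph" modeled as a batch size that changes.
        size = batch_size - (step % 8) if dynamic_graph else batch_size
        rt.run_step((size, 1, 28, 28))
        if fetch_often:
            rt.fetch_to_host()
    return rt.compiles, rt.host_fetches


print(simulate(100, dynamic_graph=True, fetch_often=True))    # → (8, 100)
print(simulate(100, dynamic_graph=False, fetch_often=False))  # → (1, 0)
```

With both flags off, the loop compiles once and never syncs with the host. This is the nominal-performance configuration the profiling run uses.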

TPU + Client Profiling

Once training is running, open the TensorBoard output at http://localhost:9001 and capture a profile from the profile page.


After the capture completes, the page reloads with the captured profile.


Currently, the TPU VM setup supports only the trace viewer tool, so select trace_viewer from the Tools dropdown and inspect the traces. In the TPU VM setup, the trace viewer shows both the client-side and the TPU device-side traces in one full view.
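The TensorBoard UI is one way to capture a profile. You can also capture one from a script. The sketch below assumes that your torch_xla build provides `torch_xla.debug.profiler.trace(service_addr, logdir, duration_ms=...)`. It also assumes that the training script has started a profiler server, for example with `xp.start_server(9012)`. The address `localhost:9012` and the `./tensorboard` log directory are assumptions; the log directory matches the `--logdir` used earlier. `capture_profile` is a hypothetical wrapper, and its `trace_fn` parameter exists only so the function can be exercised without a TPU.

```python
def capture_profile(service_addr="localhost:9012", logdir="./tensorboard",
                    duration_ms=2000, trace_fn=None):
    """Capture a profile from a running profiler server into logdir.

    Assumes torch_xla.debug.profiler.trace is available on the TPU VM;
    pass trace_fn to substitute another callable (for example in a test).
    """
    if trace_fn is None:
        # Imported lazily so this module loads even where torch_xla is absent.
        import torch_xla.debug.profiler as xp
        trace_fn = xp.trace
    trace_fn(service_addr, logdir, duration_ms=duration_ms)
    return logdir
```

On the TPU VM, call `capture_profile()` while training is running, then refresh the TensorBoard profile page to find the new capture.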



Clean up

  1. Exit from your VM:

    (vm)$ exit

Delete the TPU VM you created:

  1. Delete your Cloud TPU and Compute Engine resources.

    $ gcloud compute tpus tpu-vm delete profiler-tutorial-tpu-vm \
      --project ${PROJECT_ID} --zone=${ZONE}
  2. Verify the resources have been deleted by running the following command. The deletion might take several minutes. A response like the one below indicates your instances have been successfully deleted.

    $ gcloud compute tpus tpu-vm list --project ${PROJECT_ID} --zone=${ZONE}
    Listed 0 items.
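Deletion can take several minutes, so a script may need to poll the list command until it reports no TPU VMs. This sketch mirrors the check above by looking for the `Listed 0 items.` message. It searches both stdout and stderr because gcloud may write that status line to stderr. It assumes the zone contains no other TPU VMs, because any remaining TPU VM prevents that message. `wait_until_deleted` is a hypothetical helper. `run` and `sleep` are injectable so you can test the loop without calling gcloud.

```python
import subprocess
import time


def wait_until_deleted(project, zone, timeout_s=600, poll_s=30,
                       run=subprocess.run, sleep=time.sleep):
    """Poll `gcloud compute tpus tpu-vm list` until it lists no TPU VMs.

    Returns True once "Listed 0 items." appears, or False after timeout_s.
    """
    cmd = ["gcloud", "compute", "tpus", "tpu-vm", "list",
           f"--project={project}", f"--zone={zone}"]
    waited = 0
    while True:
        result = run(cmd, capture_output=True, text=True, check=True)
        # Check both streams; the status line may be printed on stderr.
        if "Listed 0 items." in result.stdout + result.stderr:
            return True
        if waited >= timeout_s:
            return False
        sleep(poll_s)
        waited += poll_s
```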