
Announcing PyTorch/XLA 2.4: A better Pallas and developer experience, plus “eager mode”

July 30, 2024
Bhavya Bahl

Software Engineer

Duncan Campbell

Developer Advocate


For deep learning researchers and practitioners, the open-source PyTorch machine learning (ML) library and the XLA ML compiler provide flexible, powerful model training, fine-tuning, and serving. Today, the PyTorch/XLA team is excited to announce the release of PyTorch/XLA 2.4, with several notable improvements that build on the last release to address developer challenges. Here we discuss some of the latest features that make it easier to work with PyTorch/XLA:

  • Support for Pallas, a custom kernel language originally developed for JAX that supports both TPUs and GPUs

  • New API calls

  • Introduction of an experimental “eager mode” 

  • New TPU command line interface

Let’s dive into the new usability enhancements.

Additional Pallas support

The XLA compiler itself is capable of optimizing your existing models, but there are times when model creators can get better performance by writing custom kernel code. Originally developed as a part of JAX, Pallas supports both TPUs and GPUs, letting you write more performant code in Python that’s closer to the hardware, without having to use a lower-level, more complicated language such as C++. Pallas is similar to the Triton library, but because it works on both TPUs and GPUs, it’s easier to port your model from one machine learning accelerator to the other.
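
To give a feel for the programming model, here’s a minimal sketch of a Pallas kernel written with JAX’s jax.experimental.pallas module; the element-wise add kernel, shapes, and function names below are purely illustrative and not part of the PyTorch/XLA release:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    # A Pallas kernel reads from and writes to references (Refs) into
    # accelerator memory rather than operating on arrays directly.
    def add_kernel(x_ref, y_ref, o_ref):
        o_ref[...] = x_ref[...] + y_ref[...]

    @jax.jit
    def add(x, y):
        # pallas_call lowers the kernel for the target accelerator (TPU or GPU).
        return pl.pallas_call(
            add_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        )(x, y)

    x = jnp.arange(8.0)
    print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]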

The recent PyTorch/XLA 2.4 release introduces additional features that bolster Pallas’s integration with PyTorch (a usage sketch follows the list):

  • Added support for Flash Attention, fully integrated with PyTorch autograd (automatic gradient calculation)

  • Built-in support for Paged Attention for inference

  • Support for Megablocks’ block sparse kernels for group matrix multiplication as an Autograd function, removing the need to manually perform backpropagation
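
As a sketch of what this looks like in practice, the snippet below calls the Pallas flash attention kernel through torch_xla.experimental.custom_kernel; the shapes, dtype, and causal flag here are illustrative assumptions rather than requirements:

    import torch
    import torch_xla
    from torch_xla.experimental.custom_kernel import flash_attention

    device = torch_xla.device()

    # Attention inputs, shaped [batch, num_heads, seq_len, head_dim].
    q = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
    k = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
    v = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)

    # The kernel is registered as an autograd function, so gradients
    # flow through it like any other PyTorch op.
    out = flash_attention(q, k, v, causal=True)
    torch_xla.sync()  # materialize the result on the device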

API changes

PyTorch/XLA 2.4 introduces a few new calls that make it easier to integrate it into your existing PyTorch workflow. For instance:

    device = torch_xla.device()

is now an option where traditionally you would have had to call

    device = torch_xla.core.xla_model.xla_device()

And now, instead of having to call xm.mark_step(), you can call torch_xla.sync() instead. These improvements make it easier to convert your code over to PyTorch/XLA and improve the developer workflow. For more changes to API calls, check out the release notes.
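
Put together, a minimal lazy-mode loop using the new top-level calls might look like this sketch; the model, data, and hyperparameters are placeholders:

    import torch
    import torch_xla

    device = torch_xla.device()  # new top-level device API
    model = torch.nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(3):
        x = torch.randn(8, 16, device=device)
        loss = model(x).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch_xla.sync()  # replaces xm.mark_step()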

Experimental eager mode

If you’ve been working with PyTorch/XLA for a while, you know that we refer to models as being “lazily executed.” That means PyTorch/XLA builds a compute graph of operations before sending it to the XLA device (the target hardware) for execution. With the new eager mode, operations are compiled and then immediately executed on the target hardware.

The catch is that TPUs themselves do not have a true eager mode, since instructions are not sent to the TPU right away by default. On TPUs, we achieve eager semantics by adding a “mark_step” call after each PyTorch operation to force compilation and execution. This provides the functionality of eager mode as an emulation rather than as a native feature.
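
Here’s a minimal sketch of trying it out with the experimental toggle that ships in this release:

    import torch
    import torch_xla

    # Enable the experimental eager mode: each op is compiled and run immediately.
    torch_xla.experimental.eager_mode(True)

    device = torch_xla.device()
    a = torch.randn(3, 3, device=device)
    b = a @ a  # executes right away instead of extending a lazy graph
    print(b)   # values are already materialized, which helps when debugging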

Our intent with eager mode in this release is not that you run it in your production environment, but rather in your own local environment. We hope eager mode makes it easier to debug your models locally on your own machine, without having to deploy them to a larger fleet of devices, as is the case with most production systems.

Cloud TPU info command line interface

If you’ve used NVIDIA GPUs before, you may be familiar with the nvidia-smi tool, which you can use to debug your GPU workloads, identify which cores are being utilized, and see how much memory a given workload is consuming. Now there’s a similar command-line utility for Cloud TPUs that makes it easier to surface utilization and device information: tpu-info.

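If you want to try it, the utility is distributed as a pip package; assuming a Cloud TPU VM with a workload already running:

    pip install tpu-info
    tpu-info

Running it prints the TPU chips it finds on the host along with their utilization and memory usage.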

Get started with PyTorch/XLA 2.4 today

Even though PyTorch/XLA 2.4 includes some API changes, the best part is that your existing code is still compatible with this latest version, and the new API calls will ease your future development. So what are you waiting for? Give the latest version a try. For more information, visit the project's GitHub repository.
