Supported reference models

Cloud TPU provides a set of reference models that are optimized for fast and accurate training.

Cloud TPU supports the following major and minor framework releases of TensorFlow, PyTorch, and JAX/FLAX.

TensorFlow release numbering changed starting with release 2.5.0. Major TensorFlow release numbers end with '0', and all patch release numbers end with a number greater than '0'. For example, TF 2.10.0 is a major release and TF 2.10.1 is a patch release.
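As an illustration of the numbering rule above, the following minimal sketch classifies a version string as a major or patch release. The function name `release_type` is hypothetical, and the sketch assumes a plain "MAJOR.MINOR.PATCH" string with no suffix such as "rc0".

```python
def release_type(version: str) -> str:
    """Classify a TensorFlow version (2.5.0 or later) as "major" or "patch".

    Hypothetical helper; assumes a plain "MAJOR.MINOR.PATCH" string.
    """
    major, minor, patch = (int(part) for part in version.split("."))
    if (major, minor) < (2, 5):
        # The scheme described here only applies from 2.5.0 onward.
        raise ValueError(f"numbering scheme applies from 2.5.0 onward: {version}")
    # A trailing 0 marks a major release; anything greater is a patch release.
    return "major" if patch == 0 else "patch"


print(release_type("2.10.0"))  # major
print(release_type("2.10.1"))  # patch
```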

To run the latest supported framework version, check whether any patch releases exist for the major release. If they do, run the latest supported patch release instead of the major release.
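The check described above can be sketched as follows. This is a minimal illustration, not an official tool: the function name `latest_patch` is hypothetical, and the list of versions passed to it is illustrative only, not the actual list of versions supported on Cloud TPU.

```python
def latest_patch(major_release: str, available: list[str]) -> str:
    """Return the newest release in the same MAJOR.MINOR series as major_release.

    `available` is a hypothetical list of supported version strings.
    """

    def key(version: str) -> tuple[int, int, int]:
        # Compare numerically so that, e.g., 2.10.1 sorts after 2.9.3.
        major, minor, patch = (int(part) for part in version.split("."))
        return (major, minor, patch)

    series = key(major_release)[:2]
    # Keep only releases that share MAJOR.MINOR with the requested major release.
    candidates = [v for v in available if key(v)[:2] == series]
    if not candidates:
        # No patch releases are listed, so use the major release itself.
        return major_release
    return max(candidates, key=key)


versions = ["2.9.0", "2.10.0", "2.10.1", "2.11.0"]  # illustrative list only
print(latest_patch("2.10.0", versions))  # 2.10.1
print(latest_patch("2.11.0", versions))  # 2.11.0
```

Versions are compared as integer tuples rather than as strings, because a string comparison would incorrectly rank "2.9.0" above "2.10.1".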

Framework   Major version  Model category          Reference models                         Supported versions
TensorFlow  2.x            Image classification    ResNet-2.x, MNIST-2.x, EfficientNet-2.x  See Supported TensorFlow versions.
TensorFlow  2.x            Language modeling       Transformer-2.x, BERT-2.x                See Supported TensorFlow versions.
TensorFlow  2.x            Object detection        RetinaNet-2.x                            See Supported TensorFlow versions.
TensorFlow  2.x            Image segmentation      Mask-RCNN-2.x, ShapeMask-2.x             See Supported TensorFlow versions.
TensorFlow  2.x            Recommendation systems  DLRM-2.x, DCN-2.x, NCF-2.x               See Supported TensorFlow versions.
PyTorch     2.x            Image classification    ResNet-PyTorch                           1.13, 2.0
PyTorch     2.x            Image generation        Stable Diffusion                         1.13, 2.0
JAX         latest         Large language models   MaxText                                  latest
JAX/FLAX    latest         Image classification    ResNet50                                 latest