Supported reference models
Cloud TPU provides a set of reference models that are optimized for fast and accurate training.
Cloud TPU supports the following major and minor framework releases of TensorFlow, PyTorch, and JAX/FLAX.
TensorFlow release numbering changed with release 2.5.0. Since then, major TensorFlow release numbers end in '0', and patch release numbers end in a number greater than '0'. For example, TF 2.10.0 is a major release and TF 2.10.1 is a patch release.
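The numbering rule above can be written as a small check. This is a minimal sketch that assumes version strings of the form `MAJOR.MINOR.PATCH` (for example `2.10.1`) with no pre-release suffix. The function name `release_kind` is hypothetical and is not part of any TensorFlow or Cloud TPU API.

```python
def release_kind(version: str) -> str:
    """Classify a TensorFlow version string (2.5.0 or later) as 'major' or 'patch'.

    Hypothetical helper: assumes a plain 'MAJOR.MINOR.PATCH' string with no
    pre-release suffix such as 'rc0'.
    """
    major, minor, patch = (int(part) for part in version.split("."))
    if (major, minor) < (2, 5):
        raise ValueError(f"numbering rule applies from 2.5.0 onward, got {version}")
    # A final number of 0 marks a major release; anything greater is a patch release.
    return "major" if patch == 0 else "patch"


print(release_kind("2.10.0"))  # major
print(release_kind("2.10.1"))  # patch
```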
To run the latest supported framework version, check whether any patch releases exist for the major release. If they do, run the latest supported patch release instead of the major release.
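Choosing the latest patch release amounts to taking the highest version that shares its first two components with the major release. The following is a minimal sketch that assumes you already have a list of available version strings, for example copied from the Supported TensorFlow versions page. The function `latest_patch` and the sample list are hypothetical and for illustration only. They do not reflect the actual set of supported releases.

```python
from __future__ import annotations


def latest_patch(major_release: str, available: list[str]) -> str:
    """Return the newest release in `available` that belongs to `major_release`.

    Hypothetical helper: `major_release` is a version ending in '.0', such as
    '2.10.0'; `available` holds plain 'MAJOR.MINOR.PATCH' strings.
    """

    def parse(v: str) -> tuple[int, int, int]:
        a, b, c = (int(x) for x in v.split("."))
        return a, b, c

    base = parse(major_release)
    if base[2] != 0:
        raise ValueError(f"{major_release} is not a major release")
    # Keep releases with the same first two components. Include the major
    # release itself so it is returned when no patch release is listed.
    candidates = [v for v in available if parse(v)[:2] == base[:2]] + [major_release]
    # Compare integer tuples so that 2.10.10 ranks above 2.10.9.
    return max(candidates, key=parse)


# Hypothetical list, not the real set of supported versions.
versions = ["2.10.0", "2.10.1", "2.11.0", "2.9.3"]
print(latest_patch("2.10.0", versions))  # 2.10.1
```

Comparing parsed integer tuples rather than raw strings is the important design choice here: as plain strings, "2.10.10" would sort before "2.10.9".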
Framework | Major version | Model category | Reference models | Supported versions |
---|---|---|---|---|
TensorFlow | 2.x | Image classification | ResNet-2.x, MNIST-2.x, EfficientNet-2.x | See Supported TensorFlow versions. |
TensorFlow | 2.x | Language modeling | Transformer-2.x, BERT-2.x | See Supported TensorFlow versions. |
TensorFlow | 2.x | Object detection | RetinaNet-2.x | See Supported TensorFlow versions. |
TensorFlow | 2.x | Image segmentation | Mask-RCNN-2.x, ShapeMask-2.x | See Supported TensorFlow versions. |
TensorFlow | 2.x | Recommendation systems | DLRM-2.x, DCN-2.x, NCF-2.x | See Supported TensorFlow versions. |
PyTorch | 2.x | Image classification | ResNet-PyTorch | 1.13, 2.0 |
PyTorch | 2.x | Image generation | Stable Diffusion | 1.13, 2.0 |
JAX | latest | Large language models | MaxText | latest |
JAX/FLAX | latest | Image classification | ResNet50 | latest |