Frequently Asked Questions - Cloud TPU on TensorFlow

What models are supported on Cloud TPU? A list of supported models is available in the Official supported models documentation.

Can I use a Cloud TPU for inference? Yes. You can run inference on a model trained on Cloud TPU using the tf.keras.Model.predict method.
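As a minimal sketch, the tiny model and random inputs below are placeholders for a model actually trained on Cloud TPU (a real workflow would restore the trained model, for example with tf.keras.models.load_model):

```python
import numpy as np
import tensorflow as tf

# Placeholder for a model previously trained on Cloud TPU; in practice
# you would load the trained weights from a checkpoint or SavedModel.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# Run inference on a batch of inputs with Model.predict.
inputs = np.random.rand(4, 8).astype("float32")
predictions = model.predict(inputs, verbose=0)
print(predictions.shape)  # (4, 3)
```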

Are there built-in TensorFlow ops that are not available on Cloud TPU?

There are a few built-in TensorFlow ops that are not currently available on the Cloud TPU. See available TensorFlow Ops, which details the current workarounds.

Can I train a reinforcement learning (RL) model with a Cloud TPU?

Reinforcement learning covers a wide array of techniques, some of which are not currently compatible with the software abstractions for TPUs. Some reinforcement learning configurations require executing a black-box "simulation environment" on a CPU as part of the training loop. We have found that these simulation environments cannot keep up with the Cloud TPU, resulting in significant inefficiencies.

Can I use embeddings with a Cloud TPU?

Yes. You can use the TPUEmbedding layer to support embeddings in Keras models. In custom layers and models, you can use tf.nn.embedding_lookup().
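A minimal sketch of the custom-layer case with tf.nn.embedding_lookup() (the table dimensions here are arbitrary):

```python
import tensorflow as tf

# An embedding table: a vocabulary of 10 tokens, each mapped to a
# 4-dimensional vector.
embedding_table = tf.Variable(tf.random.uniform([10, 4]))

# Look up the embedding vectors for a batch of token ids.
token_ids = tf.constant([1, 5, 7])
embedded = tf.nn.embedding_lookup(embedding_table, token_ids)
print(embedded.shape)  # (3, 4)
```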

Can I use variable-length sequences with Cloud TPU?

TensorFlow offers several methods for representing variable-length sequences, including padding, tf.while_loop(), inferred tensor dimensions, and bucketing. The current Cloud TPU execution engine supports only a subset of these: variable-length sequences must be implemented using padding, bucketing, sequence concatenation, or tf.while_loop() (for example, via tf.compat.v1.nn.dynamic_rnn()).
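Padding is the simplest of the supported approaches; a short sketch:

```python
import tensorflow as tf

# Three sequences of different lengths.
sequences = [[1, 2, 3], [4, 5], [6]]

# Pad every sequence to a fixed length so each example has the same
# static shape, which is what the TPU execution engine requires.
padded = tf.keras.utils.pad_sequences(
    sequences, maxlen=4, padding="post", value=0)
print(padded)
# [[1 2 3 0]
#  [4 5 0 0]
#  [6 0 0 0]]
```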

Can I train a Recurrent Neural Network (RNN) on Cloud TPU?

Yes. To train an RNN with TensorFlow, use Keras RNN layers such as tf.keras.layers.LSTM or tf.keras.layers.GRU.
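A minimal sketch using a Keras LSTM layer on random placeholder data; on a Cloud TPU the model would be built inside a tf.distribute.TPUStrategy scope, omitted here for brevity:

```python
import numpy as np
import tensorflow as tf

# A small sequence model built from Keras RNN layers.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10, 8)),  # 10 timesteps, 8 features each
    tf.keras.layers.LSTM(16),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Train briefly on random placeholder data.
x = np.random.rand(32, 10, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
print(model.predict(x, verbose=0).shape)  # (32, 1)
```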

Can I train a generative adversarial network (GAN) with Cloud TPU?

Training GANs typically requires frequently alternating between training the generator and training the discriminator. The current TPU execution engine only supports a single execution graph. Alternating between graphs requires a complete re-compilation, which can take 30 seconds or more.

One potential workaround is to always compute the sum of both losses, each multiplied by a scalar input tensor: g_w for the generator loss and d_w for the discriminator loss. In batches where the generator should be trained, pass g_w=1.0 and d_w=0.0; reverse the weights for batches where the discriminator should be trained.
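A sketch of that weighting trick; gen_loss and disc_loss below are constants standing in for real generator and discriminator losses:

```python
import tensorflow as tf

@tf.function
def combined_loss(gen_loss, disc_loss, g_w, d_w):
    # A single graph computes both losses every step; the scalar
    # weights select which one actually contributes this batch.
    return g_w * gen_loss + d_w * disc_loss

gen_loss = tf.constant(0.7)
disc_loss = tf.constant(0.3)

# Generator step: g_w=1.0, d_w=0.0 -> only the generator loss counts.
gen_step = combined_loss(gen_loss, disc_loss,
                         tf.constant(1.0), tf.constant(0.0))
# Discriminator step: the weights flip.
disc_step = combined_loss(gen_loss, disc_loss,
                          tf.constant(0.0), tf.constant(1.0))
print(float(gen_step), float(disc_step))
```

Because both losses appear in every step, the execution graph never changes shape and no recompilation is triggered.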

Can I train a multi-task learning model with Cloud TPU?

If the tasks can be represented as one large graph with an aggregate loss function, then no special support is needed for multi-task learning. However, the TPU execution engine currently only supports a single execution graph. Therefore, it is not possible to quickly alternate between multiple execution graphs which share variables but have different structure. Changing execution graphs requires re-running the graph compilation step, which can take 30 seconds or more.

Does Cloud TPU support TensorFlow eager mode?

Yes. In TensorFlow, you can wrap TPU-bound computations with the @tf.function decorator so they are compiled with XLA and run on the TPU, while the rest of your program continues to execute in eager mode.
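A small sketch; jit_compile=True explicitly requests XLA compilation, which here runs on CPU but uses the same compiler path as Cloud TPU:

```python
import tensorflow as tf

# jit_compile=True asks TensorFlow to compile the function with XLA.
@tf.function(jit_compile=True)
def scaled_sum(x, y):
    return tf.reduce_sum(x * 2.0 + y)

# The surrounding code still runs eagerly.
x = tf.ones([4])
y = tf.ones([4])
result = scaled_sum(x, y)
print(result.numpy())  # 12.0
```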

Does Cloud TPU support model parallelism?

Cloud TPU supports Single Program Multiple Data (SPMD)-based model parallelism:

  • in TensorFlow 2.x, spatial partitioning is supported through TPUStrategy: create the strategy with experimental_spmd_xla_partitioning=True, then call experimental_split_to_logical_devices() on the tensor you want to split.
  • in JAX via pjit.

How can I inspect the actual value of intermediate tensors on Cloud TPU, as with tf.Print or tfdbg?

This capability is currently not supported on Cloud TPU. A good practice is to debug your models on the CPU/GPU using TensorBoard, and then switch to the Cloud TPU when your model is ready for full-scale training.

My training requirements are too complex or specialized for the Keras compile/fit API. Is there a lower-level API that I can use?

If you need lower-level control when using TensorFlow, you can use custom training loops. The TensorFlow documentation describes how to use custom training loops specifically with TPUs and for the more general case using tf.distribute.
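A minimal custom training loop with tf.GradientTape; on a Cloud TPU, the train_step below would be invoked through strategy.run() inside a tf.distribute.TPUStrategy scope, which is omitted here so the sketch runs anywhere:

```python
import tensorflow as tf

# A tiny model and optimizer for a toy regression problem.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y):
    # Record the forward pass, then apply gradients manually.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.uniform([8, 4])
y = tf.random.uniform([8, 1])
first_loss = float(train_step(x, y))
for _ in range(20):
    last_loss = float(train_step(x, y))
print(last_loss < first_loss)  # the loss decreases on this toy problem
```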