Using the TPU Estimator API on Cloud TPU

This document covers the usage of the TPUEstimator API with Cloud TPU. TPUEstimator simplifies running models on a Cloud TPU by handling numerous low-level, hardware-specific details.

Models written using TPUEstimator work across CPUs, GPUs, single TPU devices, and whole TPU pods, generally with no code changes. TPUEstimator also makes it easier to achieve maximum performance by automatically performing some optimizations on your behalf.

Standard TensorFlow Estimator API

At a high level, the standard TensorFlow Estimator API provides:

  • Estimator.train() - train a model on a given input for a fixed number of steps.
  • Estimator.evaluate() - evaluate the model on a test set.
  • Estimator.predict() - run inference using the trained model.
  • Estimator.export_savedmodel() - export your model for serving.

In addition, Estimator includes default behavior common to training jobs, such as saving and restoring checkpoints, creating summaries for TensorBoard, etc.

Estimator requires you to write a model_fn and an input_fn that correspond to the model and input portions of your TensorFlow graph.

TPUEstimator programming model

The TPUEstimator wraps the computation (the model_fn) and distributes it to all available Cloud TPU cores. The learning rate must be tuned with the batch size.

  • The input_fn function models the input pipeline running on the remote host CPU. Use tf.data to program the input ops as described in the programmer's guide. Each invocation handles input of the global batch onto one device. The shard batch size is retrieved from params['batch_size']. Pro tip: return a dataset instead of tensors for optimal performance.

  • The model_fn function models the computation being replicated and distributed to the TPUs. The computation should contain only ops supported by Cloud TPU. The TensorFlow ops page lists the available ops.
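
The relationship between the global batch and the shard batch size can be sketched in plain Python. This is only an illustration, not TPUEstimator code: the helper name per_shard_batch_size is hypothetical, the values 1024 and 8 are example numbers, and the sketch assumes the global batch is split evenly across shards.

```python
def per_shard_batch_size(global_batch_size, num_shards):
    """Splits a global batch evenly across TPU shards (cores)."""
    if global_batch_size % num_shards != 0:
        raise ValueError(
            "global batch size %d is not divisible by %d shards"
            % (global_batch_size, num_shards))
    return global_batch_size // num_shards


# Example values only: a global batch of 1024 split across 8 shards.
params = {"batch_size": per_shard_batch_size(1024, 8)}
print(params["batch_size"])  # 128
```

Under this assumption, an input_fn that reads params['batch_size'] batches 128 examples per shard, and a global batch size that does not divide evenly is rejected up front.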

A training example using TPUEstimator

The following code demonstrates training a MNIST model using TPUEstimator:

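import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

# Flags such as learning_rate and dataset_reader_buffer_size are assumed to be
# defined elsewhere with tf.flags.DEFINE_*.
FLAGS = tf.flags.FLAGS

def model_fn(features, labels, mode, params):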
  """A simple CNN."""
  del params  # unused

  input_layer = tf.reshape(features, [-1, 28, 28, 1])
  conv1 = tf.layers.conv2d(
      inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same",
      activation=tf.nn.relu)
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
  conv2 = tf.layers.conv2d(
      inputs=pool1, filters=64, kernel_size=[5, 5],
      padding="same", activation=tf.nn.relu)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
  dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu)
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
  logits = tf.layers.dense(inputs=dropout, units=10)
  onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)

  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels, logits=logits)

  learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate, tf.train.get_global_step(), 100000, 0.96)

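  # Wrap a standard optimizer so gradients are aggregated across TPU shards.
  optimizer = tpu_optimizer.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=learning_rate))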

  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def make_input_fn(filename):
  """Returns an `input_fn` for train and eval."""

  def input_fn(params):
    """An input_fn to parse 28x28 images from filename using tf.data."""
    batch_size = params["batch_size"]

    def parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = tf.parse_single_example(
          serialized_example,
          features={
              "image_raw": tf.FixedLenFeature([], tf.string),
              "label": tf.FixedLenFeature([], tf.int64),
          })
      image = tf.decode_raw(features["image_raw"], tf.uint8)
      image.set_shape([28 * 28])
      # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
      label = tf.cast(features["label"], tf.int32)
      return image, label

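    dataset = tf.data.TFRecordDataset(
        filename, buffer_size=FLAGS.dataset_reader_buffer_size)
    dataset = dataset.repeat()
    # Fuse parsing and batching. drop_remainder=True keeps the batch
    # dimension static, which Cloud TPU requires.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            parser, batch_size=batch_size, drop_remainder=True))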
    return dataset

  return input_fn

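def main(unused_argv):
  # The FLAGS.* names used below (master, model_dir, iterations, use_tpu,
  # batch_size, train_file, train_steps) are assumed; define them to match
  # your setup.
  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config.TPUConfig(iterations_per_loop=FLAGS.iterations))

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      config=run_config)

  estimator.train(
      input_fn=make_input_fn(FLAGS.train_file), max_steps=FLAGS.train_steps)

if __name__ == "__main__":
  tf.app.run()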

The following section covers the new concepts introduced in the above sample, to help you use Cloud TPU effectively.

TPUEstimator concepts

TPUEstimator uses an in-graph replication approach to running TensorFlow programs. In-graph (single-session) replication differs from the between-graph (multi-session) replication training typically used in distributed TensorFlow. The major differences include:

  1. In TPUEstimator, the TensorFlow session master is not local. Your Python program creates a single graph that is replicated across all of the cores in the Cloud TPU. A typical configuration sets the TensorFlow session master to be the first worker.

  2. The input pipeline is placed on remote hosts (instead of local) to ensure that training examples can be fed to the Cloud TPU as fast as possible. A dataset (tf.data) is required.

  3. Cloud TPU workers operate synchronously, with each worker performing the same step at the same time.
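
To make the synchronous behavior concrete, here is a minimal plain-Python sketch of one data-parallel training step. It is a conceptual illustration, not how TPUEstimator is implemented: the function names, the toy linear model y = weight * x, and the two data shards are all hypothetical. Every shard computes a gradient on its own slice of the batch, the gradients are averaged, and every replica applies the same update, which is roughly the effect that CrossShardOptimizer provides in the sample above.

```python
def shard_gradient(weight, examples):
    """Gradient of mean squared error for y = weight * x on one shard."""
    return sum(2.0 * (weight * x - y) * x for x, y in examples) / len(examples)


def synchronous_step(weight, shards, learning_rate):
    """One step: each shard computes a gradient, then all apply the mean."""
    gradients = [shard_gradient(weight, shard) for shard in shards]
    mean_gradient = sum(gradients) / len(gradients)
    return weight - learning_rate * mean_gradient


# Two hypothetical shards of (x, y) pairs drawn from y = 3 * x.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
weight = 0.0
for _ in range(50):
    weight = synchronous_step(weight, shards, learning_rate=0.05)
print(round(weight, 3))  # 3.0
```

Because the update is computed once from the averaged gradient, all replicas hold identical weights after every step, so no replica can run ahead of the others.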

Converting from TensorFlow Estimator to TPUEstimator

We recommend that you port a small, simple model first and test end-to-end behavior. Doing so helps solidify your familiarity with the basic concepts of TPUEstimator. When your simple model runs, gradually add more functionality.

See the tutorials for a set of sample models and instructions for running them with Cloud TPU. Additional models are available on GitHub.

To convert your code from the tf.estimator.Estimator class to tf.contrib.tpu.TPUEstimator, change the following:

  • Change tf.estimator.RunConfig to tf.contrib.tpu.RunConfig.
  • Set TPUConfig (part of the tf.contrib.tpu.RunConfig) to specify the iterations_per_loop. iterations_per_loop is the number of iterations to run on the Cloud TPU for one call (per training loop).

Cloud TPU runs a specified number of iterations of the training loop before returning to the host. No checkpoints or summaries are saved until all Cloud TPU iterations are run.

  • In model_fn, use tf.contrib.tpu.CrossShardOptimizer to wrap your optimizer. For example:

     optimizer = tf.contrib.tpu.CrossShardOptimizer(
         tf.train.GradientDescentOptimizer(learning_rate=learning_rate))

  • Change tf.estimator.Estimator to tf.contrib.tpu.TPUEstimator.

The default RunConfig saves summaries for TensorBoard every 100 steps and writes checkpoints every 10 minutes.
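
One way to picture the effect of iterations_per_loop is to list the global steps at which control returns to the host, since those are the only points where checkpoints or summaries can be written. The sketch below is plain Python rather than TensorFlow. The helper name host_return_steps is hypothetical, and it assumes the last loop is shortened so that training stops exactly at the requested step count.

```python
def host_return_steps(train_steps, iterations_per_loop):
    """Returns the global steps at which control comes back to the host."""
    if train_steps <= 0 or iterations_per_loop <= 0:
        raise ValueError("train_steps and iterations_per_loop must be positive")
    steps = list(range(iterations_per_loop, train_steps + 1, iterations_per_loop))
    if not steps or steps[-1] != train_steps:
        steps.append(train_steps)  # final, shorter loop
    return steps


print(host_return_steps(1000, 300))  # [300, 600, 900, 1000]
```

In this model, a checkpoint that becomes due at step 100 cannot be written before step 300, so a larger iterations_per_loop trades checkpoint and summary granularity for fewer round trips to the host.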


Why is tf.data required for the input pipeline?

There are two reasons:

  1. Your application code runs on the client while the TPU computation is executed on the worker. Input pipeline ops must be placed on the remote worker for good performance. Only tf.data (Dataset) supports this.

  2. To amortize the TPU launch cost, the model training step is wrapped in a tf.while_loop, so that one Session.run call actually runs many iterations for a single training loop. Currently only tf.data can be wrapped by a tf.while_loop.
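
The amortization argument in the second point can be illustrated with simple arithmetic. The sketch below is a rough cost model in plain Python, not a measurement: the function name is hypothetical, and the 50 ms launch cost and 10 ms step cost are made-up numbers chosen only to show the trend.

```python
def launch_overhead_fraction(steps, iterations_per_loop,
                             launch_cost_ms, step_cost_ms):
    """Fraction of total time spent on per-loop launch overhead."""
    loops = -(-steps // iterations_per_loop)  # ceiling division
    overhead_ms = loops * launch_cost_ms
    return overhead_ms / (overhead_ms + steps * step_cost_ms)


# Hypothetical costs: 50 ms to launch a loop, 10 ms per training step.
for iterations in (1, 10, 100):
    fraction = launch_overhead_fraction(1000, iterations, 50, 10)
    print(iterations, round(fraction, 3))
```

With these assumed costs, launching once per step spends most of the time on overhead, while running 100 iterations per launch reduces the overhead to a few percent of the total.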

Where can I learn to run my model on Google Cloud Platform Cloud TPUs?

You can learn about running models on TPUEstimator by working through the MNIST tutorial and other tutorials on this site.

How do I profile a worker?

You can profile a worker by following the instructions in the Cloud TPU tools documentation.
