TPUEstimator API on Cloud TPU
This document covers the usage of the TPUEstimator API with Cloud TPU. TPUEstimator simplifies running models on a Cloud TPU by handling numerous low-level, hardware-specific details.
Models written using TPUEstimator work across CPUs, GPUs, single TPU devices, and TPU pods, generally with no code changes. TPUEstimator also makes it easier to achieve maximum performance by automatically performing some optimizations on your behalf.
To learn how machine learning workloads operate on TPU hardware in general, read the System Architecture documentation.
Standard TensorFlow Estimator API
At a high level, the standard TensorFlow Estimator API provides:

- `Estimator.train()` - train a model on a given input for a fixed number of steps.
- `Estimator.evaluate()` - evaluate the model on a test set.
- `Estimator.predict()` - run inference using the trained model.
- `Estimator.export_savedmodel()` - export your model for serving.
In addition, `Estimator` includes default behavior common to training jobs, such as saving and restoring checkpoints, creating summaries for TensorBoard, and so on.

`Estimator` requires you to write a `model_fn` and an `input_fn` that correspond to the model and input portions of your TensorFlow graph.
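For example, a minimal `Estimator` setup wires these two functions together. The following is a sketch only; the toy data, layer sizes, and `model_dir` are placeholders, not part of the Cloud TPU sample:

import tensorflow as tf

def my_input_fn():
  """Input portion of the graph: a toy tf.data pipeline."""
  features = tf.random_uniform([1000, 10])
  labels = tf.random_uniform([1000], maxval=2, dtype=tf.int32)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  return dataset.repeat().batch(32)

def my_model_fn(features, labels, mode):
  """Model portion of the graph: logits, loss, and a train op."""
  logits = tf.layers.dense(features, units=2)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir="/tmp/my_model")
estimator.train(input_fn=my_input_fn, max_steps=1000)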
TPUEstimator programming model
`TPUEstimator` wraps the computation (the `model_fn`) and distributes it to all available Cloud TPU cores. The learning rate must be tuned with the batch size.

- The `input_fn` function models the input pipeline running on the remote host CPU. Use `tf.data` to program the input ops, as described in the programmer's guide. Each invocation handles input of the global batch onto one device. The shard batch size is retrieved from `params['batch_size']` (see the sketch following this list). Pro tip: return a dataset instead of tensors for optimal performance.
- The `model_fn` function models the computation that is replicated and distributed to the TPUs. The computation must contain only ops supported by Cloud TPU; the available ops are listed in the TensorFlow ops documentation.
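As a concrete sketch of the batch-size arithmetic (the numbers and the linear learning-rate scaling heuristic below are illustrative assumptions, not something `TPUEstimator` applies for you):

# Hypothetical values for illustration.
num_shards = 8               # e.g. the 8 cores of a single Cloud TPU device
global_batch_size = 1024     # passed to TPUEstimator as train_batch_size
shard_batch_size = global_batch_size // num_shards  # 128, what input_fn sees
                                                    # in params['batch_size']

# One common heuristic when growing the batch: scale the learning rate
# linearly against a known-good baseline. (An assumption, shown only to
# illustrate "the learning rate must be tuned with the batch size".)
base_learning_rate, base_batch_size = 0.01, 256
learning_rate = base_learning_rate * global_batch_size / base_batch_size  # 0.04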
A training example using TPUEstimator
The following code demonstrates training MNIST with `TPUEstimator`:
import tensorflow as tf

# TPU-specific modules (TensorFlow 1.x contrib).
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

def model_fn(features, labels, mode, params):
  """A simple CNN."""
  del params  # unused
  input_layer = tf.reshape(features, [-1, 28, 28, 1])
  conv1 = tf.layers.conv2d(
      inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same",
      activation=tf.nn.relu)
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
  conv2 = tf.layers.conv2d(
      inputs=pool1, filters=64, kernel_size=[5, 5],
      padding="same", activation=tf.nn.relu)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
  dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu)
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
  logits = tf.layers.dense(inputs=dropout, units=10)
  onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels, logits=logits)
  learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate, tf.train.get_global_step(), 100000, 0.96)
  # CrossShardOptimizer aggregates gradients across all TPU shards.
  optimizer = tpu_optimizer.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def make_input_fn(filename):
  """Returns an `input_fn` for train and eval."""

  def input_fn(params):
    """An input_fn to parse 28x28 images from filename using tf.data."""
    batch_size = params["batch_size"]  # per-shard batch size set by TPUEstimator

    def parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = tf.parse_single_example(
          serialized_example,
          features={
              "image_raw": tf.FixedLenFeature([], tf.string),
              "label": tf.FixedLenFeature([], tf.int64),
          })
      image = tf.decode_raw(features["image_raw"], tf.uint8)
      image.set_shape([28 * 28])
      # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
      label = tf.cast(features["label"], tf.int32)
      return image, label

    dataset = tf.contrib.data.TFRecordDataset(
        filename, buffer_size=FLAGS.dataset_reader_buffer_size)
    dataset = dataset.repeat()
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            parser, batch_size=batch_size,
            num_parallel_batches=8,
            drop_remainder=True))
    return dataset

  return input_fn
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config.TPUConfig(FLAGS.iterations))

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      config=run_config)

  estimator.train(input_fn=make_input_fn(FLAGS.train_file),
                  max_steps=FLAGS.train_steps)
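The sample references `FLAGS` values that it does not define. A minimal sketch of the flag definitions it assumes (flag names are taken from the code above; the defaults are placeholders):

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("master", "", "Address of the TPU master, e.g. grpc://<tpu-ip>:8470.")
tf.app.flags.DEFINE_string("model_dir", None, "Directory for checkpoints and summaries.")
tf.app.flags.DEFINE_string("train_file", "", "Path to the MNIST training TFRecord file.")
tf.app.flags.DEFINE_bool("use_tpu", True, "If False, fall back to CPU/GPU.")
tf.app.flags.DEFINE_integer("iterations", 50, "Iterations per TPU training loop.")
tf.app.flags.DEFINE_integer("batch_size", 1024, "Global batch size.")
tf.app.flags.DEFINE_integer("train_steps", 1000, "Total number of training steps.")
tf.app.flags.DEFINE_float("learning_rate", 0.05, "Initial learning rate.")
tf.app.flags.DEFINE_integer("dataset_reader_buffer_size", 256 * 1024 * 1024,
                            "Read buffer size for the TFRecord dataset.")

if __name__ == "__main__":
  tf.app.run(main)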
The following section covers the new concepts introduced in the preceding sample, to help you use Cloud TPU effectively.
TPUEstimator concepts
`TPUEstimator` uses an in-graph replication approach to running TensorFlow programs. In-graph (single-session) replication differs from the between-graph (multi-session) replication typically used in distributed TensorFlow. The major differences include:

- In `TPUEstimator`, the TensorFlow session initiator is not local. Your Python program creates a single graph that is replicated across all cores of the Cloud TPU. A typical configuration sets the TensorFlow session initiator to be the first worker (see the sketch following this list).
- The input pipeline is placed on remote hosts (instead of local) to ensure that training examples are fed to the Cloud TPU as fast as possible. A dataset (`tf.data`) is required.
- Cloud TPU workers operate synchronously, with each worker performing the same step at the same time.
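As a sketch, an in-graph configuration boils down to a single `TPUEstimator` whose session initiator is a remote master and whose one graph is replicated across all shards (the address, bucket path, and shard count below are placeholders):

run_config = tf.contrib.tpu.RunConfig(
    master="grpc://10.240.1.2:8470",    # remote session initiator, not localhost
    model_dir="gs://my-bucket/model",   # hypothetical Cloud Storage path
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=100,
        num_shards=8))                  # one graph, replicated to all 8 cores

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=1024,  # divided among the 8 synchronous shards
    config=run_config)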
Converting from TensorFlow Estimator to TPUEstimator
We recommend that you port a small model first and test its behavior. Doing so helps solidify your familiarity with the basic concepts of `TPUEstimator`. When your model runs, gradually add more functionality.
See the tutorials for a set of sample models and instructions for running them with Cloud TPU. Additional models are available on GitHub.
To convert your code from the `tf.estimator.Estimator` class to `tf.contrib.tpu.TPUEstimator`, change the following:
- Change `tf.estimator.RunConfig` to `tf.contrib.tpu.RunConfig`.
- Set `TPUConfig` (part of the `tf.contrib.tpu.RunConfig`) to specify `iterations_per_loop`, the number of iterations to run on the Cloud TPU for one `session.run` call (per training loop). Cloud TPU runs the specified number of iterations of the training loop before returning to the host; no checkpoints or summaries are saved until all Cloud TPU iterations have run.
- In `model_fn`, use `tf.contrib.tpu.CrossShardOptimizer` to wrap your optimizer. For example:

      optimizer = tf.contrib.tpu.CrossShardOptimizer(
          tf.train.GradientDescentOptimizer(learning_rate=learning_rate))

- Change `tf.estimator.Estimator` to `tf.contrib.tpu.TPUEstimator` (a complete before-and-after sketch follows this list).
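Putting the steps together, a minimal before-and-after conversion looks like this sketch (flag values are placeholders; `model_fn` is assumed to already wrap its optimizer in `CrossShardOptimizer`):

# Before: standard Estimator.
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

# After: TPUEstimator.
config = tf.contrib.tpu.RunConfig(
    master=FLAGS.master,
    model_dir=FLAGS.model_dir,
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.iterations))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.batch_size,
    config=config)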
The default `RunConfig` saves summaries for TensorBoard every 100 steps and writes checkpoints every 10 minutes.
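To override these defaults, pass the corresponding arguments when constructing the `RunConfig`; a sketch with placeholder values:

run_config = tf.contrib.tpu.RunConfig(
    master=FLAGS.master,
    model_dir=FLAGS.model_dir,
    save_summary_steps=500,         # default is every 100 steps
    save_checkpoints_secs=30 * 60,  # default is every 600 seconds (10 minutes)
    tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations))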
FAQ
Why is `tf.data` required for the input pipeline?
There are two reasons:

- Your application code runs on the client, while the TPU computation is executed on the worker. Input pipeline ops must run on the worker for good performance; `tf.data` runs the ops on the worker.
- In order to amortize the TPU launch cost, the model training step is wrapped in a `tf.while_loop`, where one `Session.run` runs many iterations for a single training loop. Currently only `tf.data` can be wrapped in a `tf.while_loop`.
How do I profile model training performance?
You can profile model training performance using the profiler provided for TensorBoard.