This guide provides troubleshooting help for users who want to run their own TensorFlow models on Cloud TPU. For a more general guide to getting started with Cloud TPU, see the quickstart or the MNIST tutorial.
Overview
The recommended strategy for running TensorFlow models on Cloud TPU is to use the TPUEstimator API. If you are currently using TensorFlow's Estimator API, switching to TPUEstimator typically requires changing only a few lines of code. The recommended way to load data into TPUEstimator is with the tf.data API. See the MNIST example in Using TPUEstimator API with Cloud TPU for a real-world example that implements TPUEstimator and tf.data. See also MNIST Estimator to TPUEstimator.
After you convert your model to TPUEstimator, make sure the model works with the flag use_tpu=False. Setting the flag to false causes TensorFlow to fall back to the Estimator API and not use any code related to the TPU. Any issues you encounter while running your model with use_tpu=False are unrelated to the TPU and are out of scope for this guide. Instead, see the TensorFlow programmer's guide.
Ideally, once a model runs successfully using TPUEstimator and use_tpu=False, running it on the TPU is simply a matter of setting use_tpu=True and pointing master to a TPU server URL (typically through the use of a cluster resolver).
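As a rough sketch of what this switch can look like (assuming TensorFlow 1.x; the flag names tpu_name, tpu_zone, gcp_project, and iterations are illustrative, not taken from any particular tutorial):

# Resolve the TPU address only when running on the TPU.
if FLAGS.use_tpu:
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  master = tpu_cluster_resolver.get_master()
else:
  master = ''  # fall back to the standard Estimator execution engine

run_config = tf.contrib.tpu.RunConfig(
    master=master,
    model_dir=FLAGS.model_dir,
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.iterations))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.batch_size)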
However, because TensorFlow models can be very complex and the TPU uses its own
execution engine, you can run into issues that are specific to the TPU. These
issues fall into the following broad categories:
The training script is not able to connect to the TPU server at all.
The TPU returns an error when attempting to execute the model.
The model can run on the TPU, but the training speed is not as fast as expected.
Additionally, this guide contains a FAQ about general functionality available on TPUs.
For more specialized help porting particular types of neural networks to the TPU, see the Cloud TPU Tutorials.
Trouble connecting to the TPU server
When running a model on the TPU, you must pass a remote TPU server URL to the master parameter in RunConfig. Under the hood, TensorFlow creates a remote tf.Session with this server. This section provides troubleshooting for situations where TensorFlow stops responding or prints an error when connecting to the TPU server. Note that the TPU graph compilation step can take a long time for large models, so let the script execute for at least 5 minutes before concluding that it has stopped responding.
The first step is to verify whether the issue is with the server itself, or with your TensorFlow training pipeline. To do this, run the MNIST tutorial using your TPU server URL and verify that it works correctly. If there are still connection issues with the MNIST tutorial, this confirms that it is an issue with the TPU server. In this case:
- Run the following command to list the available TPUs:

(vm)$ gcloud compute tpus list

You may need to also set your zone and project, as shown in the MNIST tutorial. This prints output such as:

NAME      ZONE           ACCELERATOR_TYPE  NETWORK_ENDPOINT  NETWORK  RANGE       STATUS
demo-tpu  us-central1-b  v2-8              10.240.1.2:8470   default  10.240.1.0  READY

- Verify that you are passing the correct value to --tpu (demo-tpu in the above example), and that this TPU is listed as READY. Also make sure that your zone and project have been set with:

(vm)$ gcloud config set project your-project-name
(vm)$ gcloud config set compute/zone us-central1-b

- If your TPU is not listed as READY or you are still having trouble connecting, manually restart the server with gcloud compute tpus stop $TPU_SERVER_NAME && gcloud compute tpus start $TPU_SERVER_NAME. In the above example, $TPU_SERVER_NAME is demo-tpu. This may take several minutes.
- Re-run the above ... tpus list command and wait for the TPU to be in the READY state. This may take several minutes.
- Try to run the MNIST tutorial again.
If you are still having trouble running the MNIST tutorial, ask for help using one of the mechanisms described in Getting Support.
If the MNIST example runs correctly but your model still stops responding, then the issue is likely with your training pipeline. First, make sure that your model is using the TPUEstimator API, since this not only handles the complex processing pipeline, but also allows effortless switching between TPU and non-TPU execution with the use_tpu flag. See the TPU tutorials for several examples of how to use TPUEstimator. Once your model is using the TPUEstimator API, verify that it runs correctly when use_tpu=False is set. If your model does not run correctly when use_tpu=False is set, the issue is unrelated to the TPU.
Debugging common errors
Cannot use local filesystem
Error Message
InvalidArgumentError: Unimplemented: File system scheme '[local]' not
implemented
Details
All input files and the model directory must use a Cloud Storage bucket path (gs://bucket-name/...), and this bucket must be accessible from the TPU server. Note that all data processing and model checkpointing is performed on the TPU server, not the local machine. For information on how to properly configure Cloud Storage for use with the TPU, see the guide Connecting to Cloud Storage Buckets.
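For example, a hypothetical training invocation with both paths pointing at a bucket (the script and flag names follow the MNIST tutorial; the paths are placeholders, and TPU connection flags are omitted for brevity):

python mnist_tpu.py \
  --use_tpu=True \
  --data_dir=gs://bucket-name/data/ \
  --model_dir=gs://bucket-name/output \
  --train_steps=1000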
tf.data.Dataset.cache() cannot cache to the local filesystem
Error Message
tensorflow.python.framework.errors_impl.UnimplementedError: File system
scheme '[local]' not implemented (file: '[filename].lockfile')
Details
A tf.data.Dataset can be cached with the .cache() call, which has two implementations:
(1) in memory, if no argument is passed.
(2) on a file system, if a file path is passed as an argument.
On Cloud TPU, (1) works (as long as the dataset fits in available memory), but (2) does not work when saving to the local file system, and results in the error above.
The following code snippets illustrate the two situations:
Snippet (1) runs to completion:

import tensorflow as tf

def main():
  print('Hello world!')
  ds = tf.data.Dataset.range(10)
  ds = ds.cache()

Snippet (2) generates the error:

import tensorflow as tf

def main():
  print('Hello world!')
  ds = tf.data.Dataset.range(10)
  ds = ds.cache('/tmp/foo')
The API guide contains more detailed information on tf.data.Dataset.cache().
Unsupported data type
Error Message
TypeError: DataType is not a supported TPU infeed type.
Details
Currently, only the tf.float32, tf.int32, tf.bfloat16, and tf.bool data types are supported on the TPU. Other common data types, such as tf.uint8, tf.string, and tf.int64, must be converted to one of the supported data types during data pre-processing (that is, in the input_fn of TPUEstimator).
As an example, this code snippet from the MNIST tutorial converts an image tensor stored as a tf.uint8 byte sequence to a tf.float32 tensor:
image = tf.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
This snippet converts a label tensor stored as tf.int64 to a tf.int32 tensor:
label = tf.cast(label, tf.int32)
Dynamic shapes not supported
Error Message
ValueError: shape [Shape] must have a fixed size for dimension
d that is known at graph construction time.
Details
To execute a model on the TPU, TensorFlow compiles the model using the XLA framework. While this compilation step significantly improves training speed and memory usage, the shapes (dimension sizes) of all tensors in the graph must be static, that is, their values must be known at graph compilation time. If any shapes cannot be determined at compile time, TPU compilation fails with an error like the one above.
One common op that returns a dynamic shape is dataset.batch(batch_size), since the number of samples remaining in a stream might be less than the batch size. Therefore, when training on the TPU, use tf.contrib.data.batch_and_drop_remainder(batch_size). This potentially drops the last few samples from a file to ensure that every batch has a static shape of batch_size. For example:
dataset = ...
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
Unavailable TensorFlow op
Error Message
NotFoundError: No registered 'OpName' OpKernel for XLA_TPU_JIT
devices compatible with node
Details
The model uses a TensorFlow op which is not currently available on the TPU.
For a list of ops available on the TPU, along with plans for future support and suggestions for workarounds, see the guide to available TensorFlow Ops.
Out-of-memory error message
Error Message
ResourceExhaustedError: Ran out of memory in memory space hbm; used:
YYY; limit: 7.48G.
Details
Each Cloud TPU is made of eight TPU cores, each of which has 8 GB of RAM (or HBM, High Bandwidth Memory). This memory is used to store the weight (variable) tensors, as well as the intermediate result tensors needed for gradient computation. If the model is too large to fit into TPU RAM, initialization fails and the above error message is printed. See the section on reducing memory usage for more help.
Not using CrossShardOptimizer
Error Message
ValueError: CrossShardOptimizer must be used for model training on TPUs.
Details
When defining a model using the TensorFlow Python API, the vast majority of code written by the user does not need to be specialized for the TPU. The most significant exception is the optimizer, which must be wrapped in tf.contrib.tpu.CrossShardOptimizer() as shown below:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op=optimizer.minimize(loss, tf.train.get_global_step())
Each Cloud TPU is made of 8 TPU cores, which are independent processing units. For each training step (i.e., weight update), each TPU core runs the forward pass and gradient computation on an independent mini-batch of data, and then all of the cores exchange gradients with one another. In most cases, this is mathematically equivalent to computing the gradients on one large batch, although there are some caveats explained in Understanding Data Sharding.
CrossShardOptimizer is the op responsible for this gradient exchange. By default, CrossShardOptimizer computes gradients of the mean loss across the cores, but it can be configured to compute the sum loss by passing reduction=losses.Reduction.SUM.
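For example, a minimal sketch of requesting a summed loss (assuming the tf.losses.Reduction enum from TensorFlow 1.x):

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu:
  # Sum, rather than average, the loss across the eight cores.
  optimizer = tf.contrib.tpu.CrossShardOptimizer(
      optimizer, reduction=tf.losses.Reduction.SUM)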
Unable to connect to TPU server
Error Message
An error was raised while a session was being created. This may be due to
a preemption of a connected worker or parameter server. A new session is
created.
Details
This error is printed when TensorFlow cannot connect to the TPU server URL that is passed to master. For help, see the section on trouble connecting to the TPU server.
Errors in the middle of training
If a model cannot be executed successfully on the TPU, any errors related to
this are designed to be caught during initialization. Therefore, it is
rare for a model to fail in the middle of training. If this does happen, the
most likely cause is an issue in the data pre-processing function. For example,
when using the Dataset API, you typically need to call dataset = dataset.repeat(); otherwise, training fails after making one pass through the data. Dynamic execution ops like tf.while_loop() can also fail in ways that depend on the input data. There is also the rare possibility of spurious hardware or network failures.
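A minimal sketch of where the repeat() call fits in an input pipeline (the file path is a placeholder):

dataset = tf.data.TFRecordDataset("gs://bucket-name/data/train.tfrecords")
dataset = dataset.repeat()  # loop over the data indefinitely instead of stopping after one pass
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))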
Problems stopping execution
If TensorFlow encounters an error during TPU execution, the script sometimes seems to stop responding rather than exit to the shell. If this happens, press CTRL+\ on the keyboard to trigger a SIGQUIT, which causes Python to exit immediately.

Similarly, pressing CTRL+C during TPU execution does not shut down TensorFlow immediately, but instead waits until the end of the current iteration loop to exit cleanly. Pressing CTRL+\ causes Python to exit immediately.
If you encounter any new errors like DeadlineExceededError when re-connecting to the TPU server after exiting in this manner, manually reset the TPU server with the command gcloud compute tpus stop $TPU_SERVER_NAME && gcloud compute tpus start $TPU_SERVER_NAME, where $TPU_SERVER_NAME is taken from the first column of the gcloud compute tpus list command.
Reducing memory usage
If you encounter an out-of-memory error when executing your model on the TPU, you must take steps to reduce the model's memory usage. This section describes several root causes of memory issues and provides guidelines for fixing them.
Large number of model weights
Possible Cause of Memory Issue
Each float32 model weight requires 4 bytes. These weights are replicated on each TPU core. Therefore, a model with hundreds of millions of weights is likely to be too large to fit on the TPU.
How to Reduce Memory Usage
- Certain optimizers require extra memory per weight to store update statistics. Notably, AdamOptimizer and AdadeltaOptimizer both require an extra 8 bytes per weight. AdagradOptimizer and MomentumOptimizer require an extra 4 bytes per weight. The standard GradientDescentOptimizer requires no extra storage, although it may not perform as well as other optimizers in terms of final model accuracy. The experimental AdafactorOptimizer requires almost no extra memory and performs as well as the baseline Adam optimizer when training Transformer models.
- If the majority of weights are word embeddings, techniques such as WordPiece have been shown to substantially reduce vocabulary size while increasing accuracy across a variety of tasks.
- An upcoming release of TensorFlow will have experimental support for 16-bit floating point weights and gradients, which will reduce the memory requirements by half.
Excessive tensor padding
Possible Cause of Memory Issue
Tensors in TPU memory are padded, that is, the TPU rounds up the sizes of tensors stored in memory to perform computations more efficiently. This padding happens transparently at the hardware level and does not affect results. However, in certain cases the padding can result in significantly increased memory use and execution time.
How to Reduce Memory Usage
The TPU software attempts to lay out tensors in memory to maximize computational efficiency and minimize padding. This memory layout process is complex; for the best results, the model should obey the following rule of thumb. To minimize memory overhead and maximize computational efficiency, one of the following must be true:
The total batch size should be a multiple of 64 (8 per TPU core), and feature dimensions should be a multiple of 128,
or
The total batch size should be a multiple of 1024 (128 per TPU core), and feature dimensions should be a multiple of 8.
Using a batch size of 1024 and feature dimensions that are a multiple of 128 results in the best efficiency, although this may not be possible for all models. For clarity, "feature dimension" refers to the hidden size of a fully-connected layer or the number of output channels in a convolution. Not all layers can conform to this rule, especially the first and last layers of the network. This is fine, and it is expected that most models require some amount of padding.
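As a rough worked example of the cost of padding (this assumes per-core padding to a multiple of 8 in the batch dimension and 128 in the feature dimension, following the rule of thumb above):

import math

batch, features = 10, 100  # per-core activation of shape [10, 100]
padded_batch = int(math.ceil(batch / 8.0)) * 8            # padded to 16
padded_features = int(math.ceil(features / 128.0)) * 128  # padded to 128
overhead = (padded_batch * padded_features) / float(batch * features)
print(overhead)  # ~2.05x the memory of the unpadded tensor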
Batch size too large
Possible Cause of Memory Issue
When training a neural network on a CPU, GPU, or TPU, the memory use comes from two places:
- Storing the weights, the weight gradients, and optimizer-specific statistics such as momentum. The memory use is directly proportional to the number of weights in the model, but not the batch size.
- Storing intermediate activations from the forward pass necessary to compute the backward pass. The memory use is directly proportional to the batch size, layer sizes, and number of layers.
Therefore, the memory required by a model is largely dependent on the batch size.
How to Reduce Memory Usage
Try to slowly reduce the batch size until it fits in memory, making sure that the total batch size is a multiple of 64 (the per-core batch size should be a multiple of 8). Keep in mind that larger batch sizes are more efficient on the TPU. A total batch size of 1024 (128 per core) is generally a good starting point.
Model too large
Possible Cause of Memory Issue
The memory required by a model is highly dependent on the number of operators in the graph (that is, the number of layers in the network). This storage requirement is separate from the number of weights. For example, computing the gradient of an operator like tf.nn.conv2d() may increase memory use, in addition to any memory used to store weights.

The TPU engine attempts to strategically re-compute certain operators to fit the model in memory (a technique called rematerialization, similar to gradient checkpointing), but it is not always able to do this.
How to Reduce Memory Usage
If the model cannot be run on the TPU even with a small batch size (for example, 64), try reducing the number of layers or the layer sizes. An upcoming release of TensorFlow will support "model parallelism" on the TPU, which will allow significantly larger models to be run on Cloud TPU by running different parts of the model on different TPU cores.
Improving training speed
If your model is able to run successfully on the TPU, but the training speed is less than expected, this section outlines several potential ways to improve the speed.
Too few iterations per loop
Description of Performance Issue
The iterations_per_loop parameter to TPUConfig controls how many batches of data are sent to the TPU in a single "training loop." Each training loop requires significant communication between the local machine and the TPU server, so setting iterations_per_loop too small can substantially slow down training.
How to Know if Your Model is Affected
If the logging message Enqueue next (X) batch(es) of data to infeed is printed very frequently (for example, every 3 seconds), then your training might have significant overhead from the training loop.
How to Mitigate
Set iterations_per_loop to a larger value. In the MNIST tutorial, this is controlled by the --iterations flag. As long as the Enqueue next (X) batch(es) of data to infeed message is not printed more than a few times a minute, then the current value should be sufficient. Note that iterations_per_loop can be set to a very large value, with the only downside being that logging messages and checkpointing can only occur at the end of a loop.
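For example, a minimal sketch of setting iterations_per_loop through TPUConfig (the values and flag names are illustrative):

run_config = tf.contrib.tpu.RunConfig(
    master=FLAGS.master,
    model_dir=FLAGS.model_dir,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,  # run 1000 steps on the TPU per training loop
        num_shards=8))             # use all 8 cores of the Cloud TPU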
Input processing bottleneck
Description of Performance Issue
While the TPU is training on a particular chunk of data, the input processing function prepares the next chunk of data on the CPU. Thus, if the input function takes less time than the model function, the cost of input processing is effectively zero. However, an input function that takes longer than the model function creates a bottleneck.
How to Know if Your Model is Affected
Follow the instructions in the Cloud TPU Tools: Input Pipeline Analyzer for viewing the input pipeline analysis in TensorBoard:
The input pipeline analysis page displays a clear summary which shows if your model is bottlenecked by input processing. The same page also shows per-op execution time, which allows you to pinpoint problematic ops.
How to Mitigate
There are several possible mitigations when loading data with the Dataset API (see also the sketch after this list):
- Store your data as a collection of tf.train.Example structures in TFRecord files, and load them with TFRecordDataset. See the Dataset API tutorial or the ResNet tutorial for examples.
- Use dataset.cache() and/or dataset.prefetch() to buffer the input data. This prevents sporadic slowdowns in file access from creating a bottleneck.
- Specify the num_parallel_calls parameter of the dataset.map() function to enable multi-threaded map() ops.
- Perform expensive data pre-processing offline as a one-time cost, rather than incurring the cost on every epoch of every training run.
All input processing is performed on CPUs located on the TPU server, not on the local machine, so the speed of the local machine is not a factor.
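The following sketch combines several of these mitigations (assuming TensorFlow 1.x; the file path and the parse_fn parser are placeholders):

def input_fn(params):
  # TPUEstimator passes the per-shard batch size in params.
  batch_size = params["batch_size"]
  dataset = tf.data.TFRecordDataset("gs://bucket-name/data/train.tfrecords")
  # parse_fn is a hypothetical tf.train.Example parser.
  dataset = dataset.map(parse_fn, num_parallel_calls=64)
  dataset = dataset.cache().repeat()  # cache parsed records in memory, then repeat
  dataset = dataset.apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))
  dataset = dataset.prefetch(4)  # overlap input processing with TPU execution
  return dataset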
Too many non-matrix multiplication ops
Description of Performance Issue
The Cloud TPU can perform matrix multiplications and convolutions at incredibly high speeds. The majority of other TensorFlow ops do have efficient implementations on the TPU, but these are not the TPU's primary strength relative to other hardware. Therefore, a model should be dominated by matrix multiplications or convolutions to fully take advantage of the TPU.
How to Know if Your Model is Affected
The guide Cloud TPU Tools: Op Profile describes how to generate a performance profile for your model broken down by op type. In general, the vast majority of modern neural network architectures are dominated by matrix multiplications and convolutions.
How to Mitigate
If your model avoids matrix multiplications primarily because of training speed issues on other hardware, you are encouraged to re-benchmark those models on the TPU, where they may run faster. If the lack of matrix multiplications is a fundamental property of the model, then the TPU might not be the optimal hardware choice.
Excessive tensor padding
Description of Performance Issue
The TPU pads tensors in memory so that the TPU can use its computational units efficiently. The padding can increase usage of both memory and memory bandwidth. See the section on tensor padding for help understanding and fixing tensor padding issues.
Batch size too small
Description of Performance Issue
As a general rule, using larger batch sizes results in greater training speed on the TPU, in terms of samples/second.
How to Know if Your Model is Affected
The batch size of any model should always be at least 64 (8 per TPU core), since the TPU always pads the tensors to this size. The ideal batch size when training on the TPU is 1024 (128 per TPU core), since this eliminates inefficiencies related to memory transfer and padding.
How to Mitigate
It is recommended to use the largest batch size which fits into memory and is a multiple of 64. The easiest way to achieve this is to start with 1024, and if this causes an out-of-memory error, try reducing the batch size until the model runs successfully. Changing the batch size of a model may require adjusting other hyperparameters, such as the learning rate, to achieve the same model accuracy, but this must be evaluated on a case-by-case basis.
Layer sizes too small
Description of Performance Issue
Even when a model is dominated by matrix multiplications or convolutions, the TPU may not run at full efficiency if the input tensors are small. When compared to other hardware, the TPU runs most efficiently when both the batch size and layer sizes are large (for example, dimension >= 512).
How to Know if Your Model is Affected
As a general rule, layer sizes smaller than 128 achieve poor efficiency on the TPU, since 128 is the native dimension of the TPU matrix multiplication unit. For fully-connected layers, a minimum hidden size of 512 is recommended in order to achieve high efficiency. Note that convolutional layers typically do not need to be as large as fully connected layers to achieve an equal efficiency level. For example, a 3 × 3 convolution of size 256 achieves similar (high) efficiency compared to a fully-connected layer of size 2048, since 3 × 3 × 256 = 2304.
How to Mitigate
If the primary motivation for small layer sizes in your model is training speed, you are encouraged to re-benchmark your models with larger layers on the TPU. For example, increasing the output size of a layer from 256 to 512 may only increase the training time by 20% even though the model is performing 2x the computation.
Op-level model profiling
It is often useful to measure op-level execution time and memory usage in order
to identify performance bottlenecks. For instructions on how to do this,
see the guide Cloud TPU Tools: Trace
Viewer.
Debugging drops in model accuracy
One of the goals of the Cloud TPU ecosystem is that any model that is currently being trained on a CPU or GPU achieves a very similar accuracy when it is trained on the TPU, with perhaps minor adjustments to hyperparameters like the batch size and learning rate. Occasionally, however, users can observe a degradation in accuracy when training models on the TPU. Debugging such issues can be extremely frustrating due to the random nature of neural network training. This section provides guidance on how to pinpoint the root cause of any drops in model accuracy when porting a model to the TPU.
Understanding data sharding (data parallelism)
One of TensorFlow's primary goals is that each op should produce nearly identical results whether it is executed on the CPU, GPU, or TPU. There are certain exceptions to this, such as random ops. In general, if you find any significant difference between the output of non-random ops on the TPU and CPU, report it as a bug.
However, for the training pipeline as a whole, there is a significant difference between training on the CPU/GPU and the TPU: when using TPUEstimator and use_tpu=False, TensorFlow falls back to its standard execution engine. This engine trains with one batch per step. However, when training on the actual TPU, TensorFlow performs data sharding, also known as "data parallelism with synchronous SGD". The reason is that each Cloud TPU is made of 8 TPU cores which operate as independent processing units. So, for each step in the training, each TPU core is passed a batch of data, computes the weight gradients, exchanges the gradients with the other cores, and then computes the weight update. By default, the loss is averaged across the cores, but it can instead be summed by changing the parameter of CrossShardOptimizer.
If the total loss of the model can be computed as the average (or sum) of independent per-sample losses, then this procedure is mathematically equivalent to training on a single large batch. The most common op which is not independent per-sample is batch normalization, which runs over each per-core batch separately. For example, if the total batch size is 128, then the per-core batch size is 16, and each of the 8 cores performs batch norm over its own 16 samples. In some cases, performing batch normalization over small batches (for example, less than 32) has been found to cause degradations in accuracy. In the ideal scenario, the total batch size when training on the TPU can be large (for example, 256 to 1024), so batches of that size are not a major issue. However, if such a batch size is too large to fit into memory, the effect of sharding must be evaluated on a case-by-case basis.
Because of the complexities introduced by sharding, the first step in debugging drops in model accuracy is to run a deterministic, single-core TPU training, and compare it to a model trained on the CPU/GPU. Generally, this can be done quickly as it does not require training a model to convergence.
Deterministic training
One reason why it is difficult to debug differences in model accuracy is that TensorFlow uses different weight initialization and data shuffling each time a model is trained. It is beneficial to modify the training procedure to be deterministic, so that multiple runs produce nearly identical models. This section demonstrates how to run the MNIST tutorial deterministically:
- Generate an initial checkpoint file by running for a single step on the CPU. The step is used to achieve deterministic weight initialization. This can also be achieved by seeding the variable initializers, but that is more difficult.
# Run training for 1 step to create an initial checkpoint.
python mnist_tpu.py \
  --use_tpu=False \
  --data_dir=${STORAGE_BUCKET}/data/ \
  --model_dir=${STORAGE_BUCKET}/init_output \
  --random_seed=12345 \
  --iterations=1 \
  --train_steps=1
- Modify any data shuffling functions in your input function to use a random seed. This has already been done in the MNIST tutorial. This works for the input data processing ops because those always run on the CPU. Random ops in the model function may not be deterministic between the TPU and CPU. For example:
# In the flag definitions
tf.flags.DEFINE_integer("random_seed", None, "Random seed for training")

# In the input_fn
if FLAGS.random_seed is not None:
  # The buffer_size value shown here is illustrative.
  dataset = dataset.shuffle(buffer_size=1024, seed=FLAGS.random_seed)
- Run the same model twice on the CPU, to verify that the training is deterministic. Note that the training must be run for a reasonable number of steps (for example, 1000), but it does not need to be run to convergence, as this can be very slow on the CPU.
Since the CPU training is compared to a single-core TPU training, use a batch size that can fit on a single TPU core (typically, the full batch size divided by 8). TensorFlow does not guarantee bit-for-bit determinism between runs, but the loss should be very close:
# Copy the initial weights
gsutil mkdir ${STORAGE_BUCKET}/cpu_output_1
gsutil cp -f ${STORAGE_BUCKET}/init_output/* ${STORAGE_BUCKET}/cpu_output_1
gsutil mkdir ${STORAGE_BUCKET}/cpu_output_2
gsutil cp -f ${STORAGE_BUCKET}/init_output/* ${STORAGE_BUCKET}/cpu_output_2

# Run 1
python mnist_tpu.py \
  --use_tpu=False \
  --data_dir=${STORAGE_BUCKET}/data/ \
  --model_dir=${STORAGE_BUCKET}/cpu_output_1 \
  --batch_size=128 \
  --random_seed=12345 \
  --train_steps=2000 \
  --eval_steps=10

# Output 1
accuracy = 0.9910644, global_step = 1000, loss = 0.025323588

# Run 2
python mnist_tpu.py \
  --use_tpu=False \
  --data_dir=${STORAGE_BUCKET}/data/ \
  --model_dir=${STORAGE_BUCKET}/cpu_output_2 \
  --batch_size=128 \
  --random_seed=12345 \
  --train_steps=2000 \
  --eval_steps=10

# Output 2
accuracy = 0.9910644, global_step = 1000, loss = 0.025323414
Single-core TPU training
Once you can run the MNIST tutorial deterministically, the next step is to replicate the CPU-trained results on the TPU, using a single TPU core to pinpoint whether the issue is related to data sharding or to the TPU execution engine itself.
Here's how to execute single-core training and evaluation on the MNIST tutorial:
# Use the same weight initialization as the CPU
gsutil cp -f ${STORAGE_BUCKET}/init_output/* ${STORAGE_BUCKET}/tpu_output

# Run training for 1000 steps
python mnist.py \
  --use_tpu=True \
  --master=$GRPC_SERVER \
  --train_file=${STORAGE_BUCKET}/data/train.tfrecords \
  --model_dir=${STORAGE_BUCKET}/tpu_output \
  --random_seed=12345 \
  --batch_size=128 \
  --train_steps=1000 \
  --eval_steps=10

accuracy = 0.9910644, global_step = 1000, loss = 0.02514153
The loss will not exactly match the CPU-trained model, but it should be close. If it isn't close for your model, this might indicate that you have found a bug in the TPU execution engine. Before submitting a bug report, double check the following:
- You are passing num_shards=1 to TPUConfig.
- You do not have any random ops in your model function, and any random ops in your input function are being seeded correctly.
- You are using the same initial checkpoint file for the CPU and TPU training.
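For reference, a minimal sketch of requesting single-core execution (flag names are illustrative):

run_config = tf.contrib.tpu.RunConfig(
    master=FLAGS.master,
    model_dir=FLAGS.model_dir,
    tpu_config=tf.contrib.tpu.TPUConfig(num_shards=1))  # train on a single TPU core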
Debugging multi-core TPU training
If your model does achieve the same loss on the CPU and single-core TPU, then the issue is likely one of the following:
(a) The degradation is due to the natural random variance when training neural models with different initializations.
(b) The degradation is due to an issue related to data sharding on the TPU.
To determine whether (a) is the issue, it might be useful to re-train the full model on the CPU/GPU and multi-core TPU using the same weight initialization, as above.
If you are confident that the drop in accuracy is statistically significant, then the most likely issues related to data sharding are:
- If your model computes the loss as the sum of per-sample errors, you probably want to pass reduction=losses.Reduction.SUM to CrossShardOptimizer. By default, CrossShardOptimizer computes the mean of the losses, rather than the sum.
- If your model uses batch normalization, a total batch size of less than 256 (for example, less than 32 per core) might reduce accuracy.
- If your model has a batch-wise loss function, then this will be affected by sharding. Such loss functions are typically quite specialized. For example, Karras et al. 2017 uses a batch discriminator when training a generative adversarial network.