Improve your model's performance with bfloat16
By default, TPUs perform matrix multiplication operations with bfloat16 values and accumulations with IEEE float32 values. Using reduced-precision floating point numbers decreases time to convergence without losing accuracy.
The dynamic range of bfloat16 is equivalent to that of float32, but bfloat16 uses half the memory space. For more information about bfloat16 performance, see A Study of BFLOAT16 for Deep Learning Training.
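To make this default concrete, here is a minimal JAX sketch (shapes and values are arbitrary placeholders): with float32 inputs, the TPU matrix unit multiplies in bfloat16 and accumulates in float32, and the precision argument lets you request full float32 multiplies where that matters.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((256, 256), dtype=jnp.float32)
y = jnp.ones((256, 256), dtype=jnp.float32)

# Default on TPU: bfloat16 multiplies with float32 accumulation.
z_default = jnp.dot(x, y)

# Ask for full float32 multiplies if the reduced precision is not
# acceptable for a particular operation.
z_exact = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
```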
Use bfloat16 explicitly
While automatic format conversion in TPUs lets you avoid thinking about numerical precision, you can achieve performance improvements by explicitly casting values to bfloat16. There are two reasons to explicitly cast values to bfloat16:
- Storing values in bfloat16 format saves on-chip memory, enabling Cloud TPUs to train larger models or use larger batch sizes.
- Some operations are memory-bandwidth-bound, which means the amount of time it takes to load data from memory can slow down the overall time spent performing the computation. Storing operands and outputs of those operations in bfloat16 format reduces the amount of data that must be transferred, improving overall speed.
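As a minimal sketch of explicit casting, assuming a JAX model (the layer and parameter names below are hypothetical), you can keep master weights in float32 while storing and moving the bandwidth-heavy tensors in bfloat16:

```python
import jax.numpy as jnp

def dense_layer(params, x):
    # Hypothetical layer: cast operands to bfloat16 so that less data moves
    # through memory, then upcast the output where float32 is needed.
    w = params["w"].astype(jnp.bfloat16)
    x = x.astype(jnp.bfloat16)
    y = jnp.dot(x, w)              # bfloat16 operands; result is bfloat16
    return y.astype(jnp.float32)   # upcast only where full precision matters
```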
To get started, we recommend getting some experience with one of the Cloud TPU reference models. After that, the profiling tools guide and the troubleshooting guide provide in-depth technical information to help you create and optimize machine learning models on your own.
Format conversion details
The format conversion from float32 to bfloat16 is automatically inserted by the XLA compiler. On TPU, the rounding scheme in the conversion is round to nearest even, and values that overflow go to inf. Also, bfloat16 on Cloud TPU does not support subnormals, so all subnormals are flushed to zero during the conversion. Special values, such as NaN and inf, are preserved in the conversion.
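The rounding rule and the handling of special values can be checked with a small sketch (shown here with JAX; note that the host-side conversion may treat subnormals differently from the TPU hardware path):

```python
import jax.numpy as jnp

# 1 + 2**-8 lies exactly halfway between the two nearest bfloat16 values,
# 1.0 and 1 + 2**-7. Round-to-nearest-even picks 1.0 (even mantissa).
halfway = jnp.asarray(1.0 + 2**-8, dtype=jnp.float32)
print(halfway.astype(jnp.bfloat16))                            # -> 1.0

# Special values survive the conversion.
print(jnp.asarray(jnp.inf, jnp.float32).astype(jnp.bfloat16))  # -> inf
print(jnp.asarray(jnp.nan, jnp.float32).astype(jnp.bfloat16))  # -> nan
```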
The format conversion from bfloat16 to float32 is also automatically inserted by the XLA compiler. Because float32 can represent all exact values in bfloat16, the conversion simply pads the mantissa with 16 zero bits. Special values are preserved in the conversion.
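Because the widening conversion only appends zero mantissa bits, a bfloat16 value round-trips through float32 without change; a small sketch, again in JAX:

```python
import jax.numpy as jnp

b = jnp.asarray(3.140625, dtype=jnp.bfloat16)  # 3.140625 is exactly representable in bfloat16
f = b.astype(jnp.float32)                      # widening pads 16 zero mantissa bits
print(f)                                       # -> 3.140625, bit-exact
print(f.astype(jnp.bfloat16) == b)             # -> True: the round trip is lossless
```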
Checkpoints obtained from a model trained on Cloud TPUs can be deployed on other hardware platforms (for example, inference or fine-tuning on CPUs or GPUs) without extensive manual conversions.