Profile TensorFlow workloads
To profile a TensorFlow model on Cloud TPUs, you use TensorBoard and the TPU TensorBoard plug-in. TensorBoard is preinstalled on TPU VMs. For information on how to install the TPU TensorBoard plug-in and capture a performance profile, see Profile your model with Cloud TPU tools. For general Cloud TPU performance information, see the Cloud TPU performance guide.
For more information, see TensorBoard callbacks.
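For illustration only (not code from this guide), a Keras TensorBoard callback can capture a trace for a chosen range of training steps; the log directory below is a placeholder:

```python
import tensorflow as tf

# Minimal sketch: capture a trace of training steps 10-20 so it can be
# inspected with the TPU TensorBoard plug-in. The bucket path is hypothetical.
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir="gs://your-bucket/tensorboard-logs",  # placeholder log directory
    profile_batch=(10, 20),
)
# model.fit(train_dataset, epochs=1, callbacks=[tensorboard_cb])
```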
TensorFlow function performance notes
See the full list of TensorFlow operations available on Cloud TPU.
tf.matmul
- Transposing either of the operands, or the result, is effectively free.
- Note that tf.matmul supports fusing into its input and output. This means that activation functions or biases applied directly to the output of tf.matmul have low overhead.
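As a sketch of that fusion in practice (not code from this guide), a dense layer written as a matmul plus bias plus activation lets XLA fold the elementwise work into the matmul:

```python
import tensorflow as tf

# The bias add and ReLU applied directly to the matmul output are cheap,
# because XLA can fuse them into the matmul.
@tf.function(jit_compile=True)
def dense(x, w, b):
  return tf.nn.relu(tf.matmul(x, w) + b)
```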
tf.nn.conv_n_d, tf.nn.depthwise_conv2d, tf.nn.separable_conv2d
- For the activations, the batch and feature dimensions are padded to a
multiple of either 8 or 128.
- First, XLA tracks the most common size of the batch dimension for convolutions in the module. This helps distinguish between forward convolutions, activation gradient convolutions, and kernel gradient convolutions.
- If the most common batch size is greater than or equal to 64:
  - Batch is padded to a multiple of 128 and feature to a multiple of 8 for forward and backward convolutions.
  - Batch is padded to a multiple of 8 and feature to a multiple of 128 for gradient update convolutions.
- If the most common batch size is less than 64:
  - Batch is padded to a multiple of 8 and feature to a multiple of 128 for forward and backward convolutions.
  - Batch is padded to a multiple of 128 and feature to a multiple of 8 for gradient update convolutions.
- Transposing the activations right before sending them to a convolution is free if the transpose only swaps the input feature and batch dimensions.
- For the kernel, the input feature and output feature dimensions are padded
to a multiple of either 8 or 128. The exact determination is influenced by
the producers and other consumers of the kernel.
- Transposing a kernel right before sending it to a convolution is free if the transpose only swaps the input and output feature dimensions.
- For the result, the batch and feature dimensions are padded to a multiple of
either 8 or 128.
- Transposing the result of a convolution is free if the transpose only swaps the batch and output feature dimensions.
- Note that tf.nn.conv_n_d supports fusing into its result, the activations, and/or the kernel. This means that activation functions or biases applied directly to the output have low overhead.
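For illustration, a convolution followed immediately by a bias add and activation gives XLA the chance to fuse the elementwise ops into the convolution (a sketch, not code from this guide):

```python
import tensorflow as tf

# Bias and activation applied directly to the convolution output can be
# fused into it by XLA, so the extra ops add little overhead.
@tf.function(jit_compile=True)
def conv_block(x, kernel, bias):
  y = tf.nn.conv2d(x, kernel, strides=1, padding="SAME")
  return tf.nn.relu(y + bias)
```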
tf.nn.avg_pool, tf.nn.max_pool
- The padding rules apply: spatial dimensions are more major than batch and feature. Each of batch and feature may be padded to a multiple of either 8 or 128.
- Typically, the layout of a pool operation matches the convolutions that flow in or out of it.
- The gradient calculation for tf.nn.max_pool may be slower than its tf.nn.avg_pool equivalent. Consider switching from max-pooling to average-pooling when possible.
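If the model can tolerate the change, the swap is a one-line edit; the snippet below is a sketch assuming NHWC activations:

```python
import tensorflow as tf

x = tf.random.uniform([8, 32, 32, 128])  # NHWC activations

# Average pooling typically has a cheaper gradient than max pooling on TPU.
pooled = tf.nn.avg_pool2d(x, ksize=2, strides=2, padding="SAME")
```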
tf.concat, tf.slice, tf.strided_slice
- Avoid unnecessary slices and concatenations. Slices and concatenations in a
dimension that has been padded are considerably more expensive.
- Data movement is minimized if the slice dimension has no padding overhead.
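As a sketch of the difference (the shapes are made up), slicing a spatial dimension avoids touching the padded batch and feature dimensions:

```python
import tensorflow as tf

x = tf.random.uniform([8, 32, 32, 130])  # NHWC; the feature dim may pad 130 -> 256

# More expensive: slicing the padded feature dimension.
first_features = x[..., :128]

# Cheaper: slicing a spatial dimension, which carries no padding overhead.
top_half = x[:, :16, :, :]
```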
tf.transpose
- Transposing any of the operands of a tf.matmul, or its result, is free.
- Transposing the activations of a tf.conv_n_d is free if it swaps the batch and input feature dimensions.
- Transposing the kernel of a tf.conv_n_d is free if it swaps the input and output feature dimensions.
- Transposing the result of a tf.conv_n_d is free if it swaps the batch and output feature dimensions.
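For example, a transpose of a matmul operand folds into the matmul itself; where it applies, the built-in transpose flags express the same thing directly (sketch):

```python
import tensorflow as tf

a = tf.random.uniform([128, 256])
b = tf.random.uniform([128, 512])

# Both forms compile to the same matmul; the operand transpose is free.
y1 = tf.matmul(tf.transpose(a), b)
y2 = tf.matmul(a, b, transpose_a=True)
```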
tf.batch_to_space, tf.space_to_batch, tf.space_to_depth, tf.depth_to_space
- These are costly because they involve moving data from padded to unpadded dimensions and vice-versa.
tf.reshape
- Reshaping may be costly on Cloud TPU when moving around data in a padded dimension.
- It can be beneficial to reshape data to R1 on the host and reshape it back
to some higher dimension shape on the device if there is substantial
padding. This can make transfers between host and device more efficient.
- It can also help with peak memory utilization because the packed parameter can be unpacked on-demand.
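A sketch of that pattern, assuming a heavily padded parameter shape: flatten to rank 1 before the transfer (for example, on the host in the input pipeline) and restore the shape inside the compiled computation:

```python
import tensorflow as tf

params = tf.random.uniform([8, 3, 224, 224])

# On the host: pack into rank 1 before transferring to the device.
flat = tf.reshape(params, [-1])

# On the device (inside the compiled step): unpack on demand.
@tf.function(jit_compile=True)
def step(flat_params):
  return tf.reshape(flat_params, [8, 3, 224, 224])
```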
tf.random_uniform, tf.distributions.Bernoulli, tf.random_normal, tf.multinomial
- Pseudo random-number generation for uniform or Bernoulli distributions is very fast.
- Normal distributions are slightly more expensive than uniform or Bernoulli distributions.
- Pseudo random-number generation for Categorical and Multinomial distributions is considerably more expensive.
tf.reduce_all, tf.reduce_any, tf.reduce_logsumexp, tf.reduce_max, tf.reduce_min, tf.reduce_prod, tf.reduce_sum
- Multiple reductions with the same input and output shape can be performed in
parallel using operation fusion.
- Try to rewrite sequential chains of reductions into parallel ones, if possible.
- Reductions support fusing elementwise operations into their input but not their output. When possible, rewrite expressions to promote fusion. For example, rewrite:
  tf.multiply(tf.reduce_sum(x), y)
  into:
  tf.reduce_sum(tf.multiply(x, y))
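The same ideas in a short sketch (shapes and values chosen only for illustration): fold elementwise work into the reduction's input, and let independent reductions over the same input run side by side:

```python
import tensorflow as tf

x = tf.random.uniform([1024, 1024])
y = tf.constant(0.5)

# Elementwise ops fuse into a reduction's input, not its output.
fused = tf.reduce_sum(tf.multiply(x, y))

# Reductions with the same input and output shape can run in parallel.
col_sum = tf.reduce_sum(x, axis=0)
col_max = tf.reduce_max(x, axis=0)
```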
tf.nn.batch_normalization, tf.nn.fused_batch_norm, tf.layers.batch_normalization
The XLA compiler can efficiently lower TensorFlow's fused variants of batch normalization. Using them can be considerably more efficient than the alternative.
- Prefer tf.nn.fused_batch_norm over tf.nn.batch_normalization.
- For tf.layers.batch_normalization, set the "fused" argument to true.
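A hedged sketch of the fused path; note that in TF 2.x the tf.nn.fused_batch_norm and tf.layers APIs referenced here live under the tf.compat.v1 namespace:

```python
import tensorflow as tf

x = tf.random.uniform([8, 32, 32, 64])  # NHWC
scale = tf.ones([64])
offset = tf.zeros([64])

# Fused batch norm is a single kernel that XLA lowers efficiently on TPU.
# (tf.nn.fused_batch_norm is reached via tf.compat.v1 in TF 2.x.)
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, is_training=True)
```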