Cloud TPU v5e Inference Converter introduction

Introduction

Cloud TPU Inference Converter prepares and optimizes a TensorFlow 2 (TF2) model for TPU inference. The converter runs in a local or TPU VM shell. The TPU VM shell is recommended because it comes preinstalled with the command-line tools the converter needs. The converter takes an exported SavedModel and performs the following steps:

  1. TPU Conversion: It adds TPUPartitionedCall and other TPU ops to the model to make it servable on the TPU. By default, a model exported for inference doesn't have such ops and cannot be served on the TPU, even if it was trained on the TPU.
  2. Batching: It adds batching ops to the model to enable in-graph batching for better throughput.
  3. BFloat16 Conversion: It converts the model's data type from float32 to bfloat16 for better computational performance and lower High Bandwidth Memory (HBM) usage on the TPU.
  4. IO Shape Optimization: It optimizes the tensor shapes for data transferred between the CPU and TPU to improve bandwidth utilization.
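The bfloat16 conversion in step 3 trades precision for size. bfloat16 keeps float32's sign bit and 8-bit exponent but has only 7 explicit mantissa bits, so each value takes 2 bytes instead of 4. The pure-Python sketch below rounds a value to the nearest bfloat16 value, using round-to-nearest-even and ignoring NaN handling, to show how much precision is lost. It illustrates the number format only and is not the converter's implementation. The function names are hypothetical.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 pattern of x after rounding it to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); NaN is not handled."""
    bits = float32_bits(x)
    # bfloat16 is the upper 16 bits of a float32. Adding 0x7FFF plus the lowest
    # kept bit before truncating rounds to nearest, with ties going to even.
    lowest_kept_bit = (bits >> 16) & 1
    rounded = ((bits + 0x7FFF + lowest_kept_bit) >> 16) & 0xFFFF
    # Widen back to float32 by zero-filling the 16 dropped mantissa bits.
    return struct.unpack("<f", struct.pack("<I", rounded << 16))[0]


for value in (1.0, 3.14159265, 1000.1):
    print(f"{value} -> {to_bfloat16(value)}")
```

For example, pi is stored as 3.140625 because values between 2 and 4 are spaced 2^-6 apart in bfloat16. This is why the bfloat16 optimization can cost accuracy for some models.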

When exporting a model, users create function aliases for any functions they would like to run on the TPU. They then pass these aliases to the Converter, which places the corresponding functions on the TPU and optimizes them.

The Cloud TPU Inference Converter is available as a Docker image which can be executed in any environment with Docker installed.

Estimated time to complete the steps in this guide: about 20 to 30 minutes.

Prerequisites

  1. The model must be a TF2 model and exported in the SavedModel format.
  2. The model must have a function alias for the TPU function. For how to create one, see Add a function alias for the TPU function in the Usage Examples section. The following examples use tpu_func as the TPU function alias.
  3. Make sure your machine's CPU supports Advanced Vector Extensions (AVX) instructions, because the TensorFlow library (a dependency of the Cloud TPU Inference Converter) is compiled to use AVX instructions. Most modern x86 CPUs support AVX.
    1. You can run lscpu | grep avx to check whether the AVX instruction set is supported.

Before you begin

Before you begin setup, do the following:

  • Create a new project: In the Google Cloud console, on the project selector page, select or create a Cloud project.

  • Set up a TPU VM: Create a new TPU VM using Google Cloud console or gcloud, or use an existing TPU VM to run inference with the converted model on the TPU VM.

    • Make sure the TPU VM image is TensorFlow based. For example, --version=tpu-vm-tf-2.11.0.
    • The converted model will be loaded and served on this TPU VM.
  • Ensure you have the command line tools you need to use Cloud TPU Inference Converter. You can install the Google Cloud SDK and Docker locally or use a TPU VM which has this software installed by default. You use these tools to interact with the Converter image.

    Connect to the instance with SSH using the following command:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}
    

Environment Setup

Set up your environment from your TPU VM shell or from your local shell.

TPU VM Shell

  • In your TPU VM shell, run the following commands to allow non-root Docker usage:

    sudo usermod -a -G docker ${USER}
    newgrp docker
    
  • Initialize your Docker Credential helpers:

    gcloud auth configure-docker \
      us-docker.pkg.dev
    

Local Shell

In your local shell, set up the environment using the following steps:

  • Install the Cloud SDK, which includes the gcloud command-line tool.

  • Install Docker.

  • Allow non-root Docker usage:

    sudo usermod -a -G docker ${USER}
    newgrp docker
    
  • Log in to your environment:

    gcloud auth login
    
  • Initialize your Docker Credential helpers:

    gcloud auth configure-docker \
        us-docker.pkg.dev
    
  • Pull the Inference Converter Docker image:

      CONVERTER_IMAGE=us-docker.pkg.dev/cloud-tpu-images/inference/tpu-inference-converter-cli:2.13.0
      docker pull ${CONVERTER_IMAGE}
      

Converter Image

The converter image performs one-time model conversions. Set MODEL_PATH to your exported SavedModel directory and CONVERTED_MODEL_PATH to the output directory, then adjust the converter options to fit your needs. The Usage Examples section provides several common use cases.

docker run \
--mount type=bind,source=${MODEL_PATH},target=/tmp/input,readonly \
--mount type=bind,source=${CONVERTED_MODEL_PATH},target=/tmp/output \
${CONVERTER_IMAGE} \
--input_model_dir=/tmp/input \
--output_model_dir=/tmp/output \
--converter_options_string='
    tpu_functions {
      function_alias: "tpu_func"
    }
    batch_options {
      num_batch_threads: 2
      max_batch_size: 8
      batch_timeout_micros: 5000
      allowed_batch_sizes: 2
      allowed_batch_sizes: 4
      allowed_batch_sizes: 8
      max_enqueued_batches: 10
    }
'

Inference with the converted model in TPU VM

import logging

import tensorflow as tf

# Initialize the TPU
resolver = tf.distribute.cluster_resolver.TPUClusterResolver("local")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Load the converted model. Replace the path with the value of CONVERTED_MODEL_PATH.
model = tf.saved_model.load("/path/to/converted_model")

# Find the signature function for serving
serving_signature = 'serving_default'  # Change the serving signature if needed
serving_fn = model.signatures[serving_signature]

# Run inference. `inputs` must be a dict that maps each of the signature's
# input names to a tensor built from your request data.
results = serving_fn(**inputs)
logging.info("Serving results: %s", str(results))
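The inference snippet assumes that inputs already exists. For a quick smoke test before you wire in real request data, you can build placeholder inputs from the signature itself. The sketch below reads the loaded signature's structured_input_signature and creates all-zero tensors, replacing every unknown (None) dimension with a chosen batch size. The helper name zero_inputs_for is hypothetical. Zero inputs only check that the model runs; they say nothing about output quality.

```python
import tensorflow as tf


def zero_inputs_for(serving_fn, batch_size=1):
    """Build a dict of all-zero tensors matching a signature's keyword inputs.

    Every unknown (None) dimension is filled with batch_size. Inputs of unknown
    rank are not supported (as_list() raises ValueError for them).
    """
    _, input_specs = serving_fn.structured_input_signature
    inputs = {}
    for name, spec in input_specs.items():
        shape = [batch_size if dim is None else dim for dim in spec.shape.as_list()]
        inputs[name] = tf.zeros(shape, dtype=spec.dtype)
    return inputs
```

For example, results = serving_fn(**zero_inputs_for(serving_fn, batch_size=8)) runs one zero-filled batch of 8 through the serving signature.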

Usage Examples

Add a function alias for the TPU function

  1. Find or create a function in your model that wraps everything you want to run on the TPU. If it isn't already decorated with @tf.function, add the decorator.
  2. When saving the model, provide SaveOptions as in the following code to give model.tpu_func the alias func_on_tpu.
  3. Pass this function alias (func_on_tpu here) to the converter.

import tensorflow as tf

class ToyModel(tf.keras.Model):
  @tf.function(
      input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
  def tpu_func(self, x):
    return x * 1.0

model = ToyModel()
model_dir = "/tmp/toy_model"  # Example export directory; use it as MODEL_PATH.
save_options = tf.saved_model.SaveOptions(function_aliases={
    'func_on_tpu': model.tpu_func,
})
tf.saved_model.save(model, model_dir, options=save_options)

Convert a model with multiple TPU functions

You can put multiple functions on the TPU. Simply create multiple function aliases and pass them in converter_options_string to the converter.

tpu_functions {
  function_alias: "tpu_func_1"
}
tpu_functions {
  function_alias: "tpu_func_2"
}
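Each of those aliases must exist in the exported SavedModel. The following sketch uses a hypothetical TwoTowerModel with two independent functions. It registers both aliases in one tf.saved_model.SaveOptions call and exposes each function through its own signature, so a single request calls only one TPU function. The model, the signature names serve_1 and serve_2, and the export path are illustrative.

```python
import tensorflow as tf


class TwoTowerModel(tf.Module):
    """Hypothetical model with two independent candidate TPU functions."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
    def tpu_func_1(self, x):
        return x * 2.0

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
    def tpu_func_2(self, x):
        return x + 1.0


def export_with_aliases(model_dir: str) -> None:
    model = TwoTowerModel()
    # Alias names must match the function_alias entries in converter_options_string.
    save_options = tf.saved_model.SaveOptions(function_aliases={
        "tpu_func_1": model.tpu_func_1,
        "tpu_func_2": model.tpu_func_2,
    })
    # One signature per TPU function, so each request calls a single TPU function.
    signatures = {"serve_1": model.tpu_func_1, "serve_2": model.tpu_func_2}
    tf.saved_model.save(model, model_dir, signatures=signatures, options=save_options)
```

Calling export_with_aliases(model_dir) writes a SavedModel whose aliases match the two tpu_functions entries above.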

Quantization

Quantization is a technique that reduces the precision of the numbers used to represent a model's parameters. A quantized model has smaller memory usage and storage size and can achieve higher inference throughput, at the cost of a small drop in accuracy.

TensorFlow's new Post-Training Quantization feature for TPU is derived from the similar existing feature in TensorFlow Lite, which targets mobile and edge devices. To learn more about quantization in general, see the TensorFlow Lite documentation.

Quantization concepts

This section defines concepts specifically related to quantization with the Inference Converter.

Concepts related to other TPU configurations (for example, slices, hosts, chips, and TensorCores) are described in the TPU System Architecture page.

  • Post-training quantization (PTQ): PTQ is a technique that reduces the size and computational complexity of a neural network model without significantly affecting its accuracy. PTQ works by converting the floating-point weights and activations of a trained model to lower-precision integers, such as 8-bit or 16-bit integers. This can significantly reduce model size and inference latency while incurring only a small loss in accuracy.

  • Calibration: The calibration step for quantization is the process of collecting statistics on the range of values that the weights and activations of a neural network model take. This information is used to determine the quantization parameters for the model, which are the values that will be used to convert the floating-point weights and activations to integers.

  • Representative Dataset: A representative dataset for quantization is a small dataset that represents the actual input data for the model. It is used during the calibration step of quantization to collect statistics on the range of values that the weights and activations of the model will take. The representative dataset should satisfy the following properties:

    • It should properly represent the actual inputs to the model during inference. This means that it should cover the range of values that the model is likely to see in the real world.
    • It should collectively flow through each branch of conditionals (such as tf.cond), if there are any. Calibration only collects statistics for ops that actually execute, so a branch that no sample reaches has no value ranges from which to derive quantization parameters.
    • It should be large enough to produce reliable statistics and reduce error. As a rule of thumb, use more than 200 representative samples.

    The representative dataset can be a subset of the training dataset, or it can be a separate dataset that is specifically designed to be representative of the real-world inputs to the model. The choice of which dataset to use depends on the specific application.

  • Static Range Quantization (SRQ): SRQ determines the range of values for the weights and activations of a neural network model once, during the calibration step. This means that the same range of values is used for all inputs to the model. This can be less accurate than dynamic range quantization, especially for models with a wide range of input values. However, static range quantization requires less computation at run time than dynamic range quantization.

  • Dynamic Range Quantization (DRQ): DRQ determines the range of values for the weights and activations of a neural network model for each input. This allows the model to adapt to the range of values of the input data, which can improve accuracy. However, dynamic range quantization requires more computation at run time than static range quantization.

    Feature                  Static range quantization                     Dynamic range quantization
    -----------------------  -------------------------------------------   -------------------------------------------
    Range of values          Determined once, during calibration           Determined for each input
    Accuracy                 Can be less accurate, especially for models   Can be more accurate, especially for models
                             with a wide range of input values             with a wide range of input values
    Complexity               Simpler                                       More complex
    Computation at run time  Less computation                              More computation
  • Weight-only Quantization: Weight-only quantization is a type of quantization that only quantizes the weights of a neural network model, while leaving the activations in floating point. This can be a good option for models that are sensitive to accuracy, as it can help to preserve the accuracy of the model.
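The pure-Python sketch below makes the difference between static and dynamic range quantization concrete. It uses a generic asymmetric int8 scheme: a scale and zero point are derived from an observed min/max range, values are rounded to int8, and dequantization maps them back. Static range quantization fixes the parameters once from calibration data, while dynamic range quantization recomputes them from each input. This is a textbook illustration with hypothetical helper names. TensorFlow's quantizer may use different schemes, for example symmetric or per-channel quantization.

```python
from typing import List, Tuple

QMIN, QMAX = -128, 127  # int8 range


def compute_qparams(values: List[float]) -> Tuple[float, int]:
    """Derive (scale, zero_point) from the observed range, widened to include 0.0."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # avoid a zero scale for all-zero data
    zero_point = round(QMIN - lo / scale)
    return scale, max(QMIN, min(QMAX, zero_point))


def quantize(values: List[float], scale: float, zero_point: int) -> List[int]:
    # Values outside the range the parameters were derived from are clipped.
    return [max(QMIN, min(QMAX, round(v / scale) + zero_point)) for v in values]


def dequantize(qvalues: List[int], scale: float, zero_point: int) -> List[float]:
    return [(q - zero_point) * scale for q in qvalues]


# Static range quantization: parameters are computed once from calibration samples.
calibration_samples = [[-1.0, 0.5, 2.0], [0.0, 1.5, -0.5]]
static_params = compute_qparams([v for s in calibration_samples for v in s])

# An input far outside the calibrated range [-1.0, 2.0] gets clipped.
static_result = dequantize(quantize([5.0], *static_params), *static_params)

# Dynamic range quantization: parameters are recomputed from the input itself.
dynamic_params = compute_qparams([5.0])
dynamic_result = dequantize(quantize([5.0], *dynamic_params), *dynamic_params)
print("static:", static_result, "dynamic:", dynamic_result)
```

The clipped static result shows why a representative dataset must cover the range of real inputs. The dynamic variant avoids the clipping but has to compute new parameters for every input at run time.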

How to use quantization

To apply quantization, set QuantizationOptions in the converter options. In converter_options_string, it goes under external_feature_configs, as the Examples below show. Notable options are:

  • tags: Collection of tags identifying the MetaGraphDef within the SavedModel to quantize. No need to specify if you have only one MetaGraphDef.
  • signature_keys: Sequence of keys identifying SignatureDef containing inputs and outputs. If not specified, ["serving_default"] is used.
  • quantization_method: Quantization method to apply. If not specified, STATIC_RANGE quantization will be applied.
  • op_set: Should be kept as XLA. It is currently the default, so you don't need to specify it.
  • representative_datasets: Specify the dataset used for calibrating the quantization parameters.

Building the representative dataset

A representative dataset is an iterable of samples, where each sample is a map of {input_key: input_value}. For example:

import tensorflow as tf
# 256 samples, each mapping the input key "x" to a 3x3 float32 tensor.
representative_dataset = [{"x": tf.random.uniform(shape=(3, 3))}
                          for _ in range(256)]

The representative datasets should be saved as TFRecord files using the TfRecordRepresentativeDatasetSaver class currently available in the tf-nightly pip package. For example:

# Assumed tf-nightly installed.
import tensorflow as tf
representative_dataset = [{"x": tf.random.uniform(shape=(3, 3))}
                          for _ in range(256)]
tf.quantization.experimental.TfRecordRepresentativeDatasetSaver(
    path_map={'serving_default': '/tmp/representative_dataset_path'}
).save({'serving_default': representative_dataset})
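Random tensors are fine for trying out the pipeline, but calibration is only as good as the samples, which should resemble real inputs. The following minimal sketch assumes your real inputs are available as a tf.data.Dataset of individual examples. It draws the first N examples through a generator. The helper name representative_samples and the input key are hypothetical.

```python
import tensorflow as tf


def representative_samples(dataset: tf.data.Dataset, input_key: str, num_samples: int = 256):
    """Yield up to num_samples {input_key: example} maps from an existing dataset.

    Shuffle the dataset first if its order is not representative, so that the
    samples cover the range of values (and conditional branches) seen in production.
    """
    for example in dataset.take(num_samples):
        yield {input_key: example}
```

A generator can be iterated only once. Wrap it in list(...) if you need to pass the same samples to the saver more than once, or if your TensorFlow version's saver expects a sequence.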

Examples

The following example quantizes the model with the signature key serving_default and the function alias tpu_func. The options string is enclosed in double quotes, with the inner quotes escaped, so that the shell expands ${TF_RECORD_FILE}. Inside single quotes, the variable would be passed through literally. TF_RECORD_FILE must also be a path that the container can read.

docker run \
  --mount type=bind,source=${MODEL_PATH},target=/tmp/input,readonly \
  --mount type=bind,source=${CONVERTED_MODEL_PATH},target=/tmp/output \
  ${CONVERTER_IMAGE} \
  --input_model_dir=/tmp/input \
  --output_model_dir=/tmp/output \
  --converter_options_string="
    tpu_functions {
      function_alias: \"tpu_func\"
    }
    external_feature_configs {
      quantization_options {
        signature_keys: \"serving_default\"
        representative_datasets: {
          key: \"serving_default\"
          value: {
            tfrecord_file_path: \"${TF_RECORD_FILE}\"
          }
        }
      }
    }"

Add batching

The Converter can be used to add batching to a model. For a description of the batching options that can be tuned, see Definition of batching options.

By default, the Converter batches all TPU functions in the model. It can also batch user-provided signatures and functions, which can further improve performance. Any TPU function, user-provided function, or signature that is batched must meet the batching op's strict shape requirements.

The Converter can also update existing batching options. The following is an example of how to add batching to a model. For more information on batching, see Batching deep dive.

batch_options {
  num_batch_threads: 2
  max_batch_size: 8
  batch_timeout_micros: 5000
  allowed_batch_sizes: 2
  allowed_batch_sizes: 4
  allowed_batch_sizes: 8
  max_enqueued_batches: 10
}
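The BatchOptions reference states that allowed_batch_sizes entries must increase monotonically and that the last entry must equal max_batch_size. You can check these constraints, and the resulting padding, before running the Converter. The pure-Python sketch below uses hypothetical helper names. It validates the constraints and computes which allowed size a batch of a given number of requests would be padded up to. It treats "increase monotonically" as strictly increasing and does not model large batch splitting.

```python
import bisect
from typing import List


def validate_allowed_batch_sizes(max_batch_size: int, allowed: List[int]) -> None:
    """Raise ValueError if allowed_batch_sizes breaks the documented constraints."""
    if not allowed:
        return  # An empty list means no padding to fixed sizes.
    if any(later <= earlier for earlier, later in zip(allowed, allowed[1:])):
        raise ValueError("allowed_batch_sizes must be strictly increasing")
    if allowed[-1] != max_batch_size:
        raise ValueError("the last allowed batch size must equal max_batch_size")


def padded_batch_size(num_requests: int, allowed: List[int]) -> int:
    """Return the smallest allowed batch size that can hold num_requests."""
    index = bisect.bisect_left(allowed, num_requests)
    if index == len(allowed):
        raise ValueError("num_requests exceeds max_batch_size")
    return allowed[index]


allowed = [2, 4, 8]  # Matches the batch_options example above.
validate_allowed_batch_sizes(max_batch_size=8, allowed=allowed)
for n in (1, 3, 5, 8):
    size = padded_batch_size(n, allowed)
    print(f"{n} request(s) -> batch of {size} ({size - n} padded)")
```

With these options, a group of 5 requests is padded to 8, so 3 of the 8 slots do no useful work. An intermediate size such as 6 would reduce that padding at the cost of one more distinct batch shape.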

Disable bfloat16 and IO shape optimizations

BFloat16 and IO Shape Optimizations are enabled by default. If they don't work well with your model, they can be disabled.

# Disable both optimizations
disable_default_optimizations: true

# Or disable them individually
io_shape_optimization: DISABLED
bfloat16_optimization: DISABLED

Conversion Report

You can find the conversion report in the log after running the Inference Converter. The following is an example.

-------- Conversion Report --------
TPU cost of the model: 96.67% (2034/2104)
CPU cost of the model:  3.33% (70/2104)

Cost breakdown
================================
%         Cost    Name
--------------------------------
3.33      70      [CPU cost]
48.34     1017    tpu_func_1
48.34     1017    tpu_func_2
--------------------------------

This report estimates the computational cost of the output model on the CPU and TPU and breaks the TPU cost down by function. The breakdown should reflect the TPU functions you selected in the converter options.
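If you compare several converter configurations, it can help to extract these numbers from the log automatically. The sketch below assumes the report keeps the exact layout shown above, with a percentage, an integer cost, and a name on each breakdown row. The function names are hypothetical, and the parsing would need updating if the log format changes.

```python
import re
from typing import Dict

# A breakdown row: percentage, integer cost, then the function name (or "[CPU cost]").
ROW = re.compile(r"^\s*(\d+(?:\.\d+)?)\s+(\d+)\s+(\S.*?)\s*$")


def parse_cost_breakdown(report: str) -> Dict[str, int]:
    """Map each name in the Cost breakdown table to its integer cost."""
    costs = {}
    for line in report.splitlines():
        match = ROW.match(line)
        if match:
            costs[match.group(3)] = int(match.group(2))
    return costs


def tpu_cost_percent(costs: Dict[str, int]) -> float:
    """Percentage of the total cost attributed to TPU functions."""
    total = sum(costs.values())
    return 100.0 * (total - costs.get("[CPU cost]", 0)) / total
```

For the example report, the TPU functions account for 2034 of 2104 cost units, which matches the 96.67% on the report's first line.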

If you want to better utilize the TPU, you may want to experiment with the model structure and adjust the converter options.

FAQs

Which function(s) should I place on the TPU?

It is best to put as much of your model on the TPU as possible, because the vast majority of ops execute faster on the TPU.

If your model does not contain any TPU-incompatible ops, strings, or sparse tensors, putting the entire model on the TPU is usually the best strategy. To do this, find or create a function that wraps the entire model, create a function alias for it, and pass that alias to the Converter.

If your model contains parts that cannot run on the TPU (for example, TPU-incompatible ops, strings, or sparse tensors), the choice of TPU functions depends on where the incompatible part is.

  • If it's at the beginning or the end of the model, you can refactor the model to keep it on the CPU. Examples are string pre- and post-processing stages. For a typical way to refactor the model, see "How do I move a part of the model to CPU?"
  • If it's in the middle of the model, it's better to split the model into three parts, contain all the TPU-incompatible ops in the middle part, and run that part on the CPU.
  • If it is a sparse tensor, consider calling tf.sparse.to_dense on the CPU and passing the resulting dense tensor to the TPU portion of the model.
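The following minimal sketch shows the sparse-tensor pattern. It assumes a hypothetical model whose serving input is a sparse [batch, 4] float tensor. The outer function densifies the input with tf.sparse.to_dense and stays on the CPU. Only the dense computation lives in tpu_func, the function you would alias and pass to the Converter. tf.ensure_shape pins the dense tensor to the static [None, 4] shape that tpu_func's input_signature expects.

```python
import tensorflow as tf


class SparseInputModel(tf.Module):
    """Hypothetical model: sparse inputs are densified before the TPU function."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def tpu_func(self, dense_features):
        # Dense-only computation: the candidate for the TPU function alias.
        return tf.reduce_sum(dense_features, axis=1)

    @tf.function(input_signature=[tf.SparseTensorSpec(shape=[None, 4], dtype=tf.float32)])
    def serve(self, sparse_features):
        # Sparse handling stays in this CPU-side caller.
        dense = tf.sparse.to_dense(sparse_features)
        dense = tf.ensure_shape(dense, [None, 4])
        return self.tpu_func(dense)


model = SparseInputModel()
features = tf.sparse.SparseTensor(
    indices=[[0, 1], [1, 3]], values=[1.0, 2.0], dense_shape=[2, 4])
print(model.serve(features))
```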

Another factor to consider is the HBM usage. Embedding tables can use a lot of HBM. If they grow beyond the hardware limitation of the TPU, they have to be put on the CPU, along with the lookup ops.

Whenever possible, only one TPU function should exist under one signature. If the structure of your model requires calling multiple TPU functions per incoming inference request, you should be aware of the added latency of sending tensors between CPU and TPU.

A good way to evaluate the selection of TPU functions is to check the Conversion Report. It shows the percentage of computation that was placed on the TPU, and a breakdown of the cost of each TPU function.

How do I move a part of the model to CPU?

If your model contains parts that cannot be served on the TPU, you need to refactor the model to move them to the CPU. Here is a toy example. The model is a language model with a preprocessing stage. The code for layer definitions and functions is omitted for simplicity.

class LanguageModel(tf.keras.Model):
  @tf.function
  def model_func(self, input_string):
    word_ids = self.preprocess(input_string)
    return self.bert_layer(word_ids)

This model cannot be served directly on the TPU for two reasons. First, the parameter is a string. Second, the preprocess function may contain many string ops. Neither is TPU-compatible.

To refactor this model, create another function called tpu_func to host the computationally intensive bert_layer. Then create a function alias for tpu_func and pass it to the Converter. This way, everything inside tpu_func runs on the TPU, and everything left in model_func runs on the CPU.

class LanguageModel(tf.keras.Model):
  @tf.function
  def tpu_func(self, word_ids):
    return self.bert_layer(word_ids)

  @tf.function
  def model_func(self, input_string):
    word_ids = self.preprocess(input_string)
    return self.tpu_func(word_ids)
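You can try the same split end to end with a runnable stand-in. In the sketch below, a hypothetical comma-separated-number parser replaces the real preprocessing, and a single matrix multiplication replaces bert_layer. The string ops (tf.strings.split and tf.strings.to_number) stay in model_func on the CPU, while tpu_func only sees a dense float32 tensor.

```python
import tensorflow as tf


class ToyLanguageModel(tf.Module):
    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 2]))  # Stand-in for bert_layer.

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
    def tpu_func(self, features):
        # Numeric-only computation: safe to alias and place on the TPU.
        return tf.matmul(features, self.kernel)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def model_func(self, input_string):
        # String preprocessing stays here, on the CPU.
        tokens = tf.strings.split(input_string, sep=",")  # RaggedTensor of strings
        numbers = tf.ragged.map_flat_values(tf.strings.to_number, tokens)
        features = tf.ensure_shape(numbers.to_tensor(shape=[None, 3]), [None, 3])
        return self.tpu_func(features)


model = ToyLanguageModel()
print(model.model_func(tf.constant(["1,2,3", "0.5,0.5,0"])))
```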

What should I do if the model has TPU-incompatible ops, strings or sparse tensors?

Most standard TensorFlow ops are supported on the TPU, but a few ops, as well as sparse tensors and strings, are not. The Converter doesn't check for TPU-incompatible ops, so a model containing such ops can pass conversion. However, running it for inference then fails with errors like the following:

'tf.StringToNumber' op isn't compilable for TPU device.

If your model has TPU-incompatible ops, put them outside the TPU function. Strings are also an unsupported data type on the TPU, so string-typed variables shouldn't be placed in the TPU function, and the TPU function's parameters and return values shouldn't be string-typed either. Similarly, avoid placing sparse tensors in the TPU function, including in its parameters and return values.

It's usually not hard to refactor out the incompatible part of the model and move it to the CPU. For an example, see How do I move a part of the model to CPU?

How to support custom ops in the model?

If custom ops are used in your model, the Converter may not recognize them and fail to convert the model. This is because the op library of the custom op, which contains the complete definition of the op, isn't linked to the Converter.

Because the converter code is not open source, the converter cannot currently be built with custom ops.

What should I do if I have a TensorFlow 1 model?

The Converter does not support TensorFlow 1 models. TensorFlow 1 models should be migrated to TensorFlow 2.

Do I need to enable the MLIR bridge when running my model?

Most converted models can be run with either the newer TF2XLA MLIR bridge or the original TF2XLA bridge.

How do I convert a model that has already been exported without a function alias?

If a model was exported without a function alias, the easiest fix is to export it again with a function alias. If re-exporting is not an option, you can still convert the model by providing a concrete_function_name. However, identifying the correct concrete_function_name requires some detective work.

Function aliases are a mapping from a user defined string to a concrete function name. They make it easier to refer to a specific function in the model. The Converter accepts both function aliases and raw concrete function names.

Concrete function names can be found by examining the saved_model.pb.
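One way to do that detective work is to parse saved_model.pb with the SavedModel protocol buffer classes that ship with TensorFlow and list the function names in each MetaGraphDef's function library. The sketch below, with a hypothetical helper name, keeps only names that start with __inference_, the prefix used by names such as __inference_serve_24. You still need to decide which of them is the function you want on the TPU.

```python
import os
from typing import List

from tensorflow.core.protobuf import saved_model_pb2


def list_concrete_function_names(model_dir: str) -> List[str]:
    """Return the sorted __inference_* function names stored in saved_model.pb."""
    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(model_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())
    names = set()
    for meta_graph in saved_model.meta_graphs:
        for function_def in meta_graph.graph_def.library.function:
            name = function_def.signature.name
            if name.startswith("__inference_"):
                names.add(name)
    return sorted(names)
```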

The following example shows how to put a concrete function called __inference_serve_24 on the TPU.

docker run \
--mount type=bind,source=${MODEL_PATH},target=/tmp/input,readonly \
--mount type=bind,source=${CONVERTED_MODEL_PATH},target=/tmp/output \
${CONVERTER_IMAGE} \
--input_model_dir=/tmp/input \
--output_model_dir=/tmp/output \
--converter_options_string='
    tpu_functions {
      concrete_function_name: "__inference_serve_24"
    }'

How do I resolve a compile time constant constraint error?

For both training and inference, XLA requires the inputs to certain ops to have a known shape at TPU compile time. This means that when XLA compiles the TPU portion of the program, the inputs to these ops must have a statically known shape.

There are two ways to resolve this issue.

  • The best option is to update the op's inputs to have a statically known shape by the time XLA compiles the TPU program. This compilation happens right before the TPU portion of the model is run. This means that the shape should be statically known by the time the TpuFunction is about to run.
  • Another option is to modify the TpuFunction to no longer include the problematic op.
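The sketch below illustrates the first option. It runs with jit_compile=True, so XLA rather than the regular TensorFlow executor compiles the function and applies the same kind of compile-time constant requirement. The sketch derives a reshape target from tf.shape(x). Because XLA compiles for a concrete input shape, that value is known at compile time. The commented-out variant computes the target from tensor values instead, which XLA cannot treat as a compile-time constant. The function name is hypothetical.

```python
import tensorflow as tf


@tf.function(jit_compile=True)
def flatten_trailing_dims(x):
    # tf.shape(x) is known when XLA compiles this function for x's concrete shape,
    # so the reshape target is a compile-time constant.
    return tf.reshape(x, [tf.shape(x)[0], -1])


# A target computed from tensor *values* is only known at run time, for example:
#   tf.reshape(x, [tf.cast(tf.reduce_max(x), tf.int32), -1])
# Inside an XLA-compiled function, this kind of shape argument typically fails
# with a compile-time constant error.

print(flatten_trailing_dims(tf.ones([2, 3, 4])).shape)
```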

Why am I getting a batching shape error?

Batching has strict shape requirements that allow incoming requests to be batched along their 0th dimension (aka the batching dimension). These shape requirements come from the TensorFlow batching op and cannot be relaxed.

Failure to meet these requirements will result in errors like:

  1. Batching input tensors must have at least one dimension.
  2. Dimensions of inputs should match.
  3. Batching input tensors supplied in a given op invocation must have equal 0th-dimension size.
  4. Batched output tensor's 0th dimension does not equal the sum of the 0th dimension sizes of the input tensors.

To meet these requirements, consider providing a different function or signature to batch. It may also be necessary to modify existing functions to meet these requirements.

If a function is being batched, make sure its @tf.function's input_signature's shapes all have None in the 0th dimension. If a signature is being batched, make sure that all its inputs have -1 in the 0th dimension.
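You can check the 0th-dimension requirement for signatures on the exported model before conversion. The sketch below parses saved_model.pb and reports every input of a signature that either lacks -1 (unknown) in dimension 0 or has an unknown or zero rank. The helper name is hypothetical. For functions, the equivalent check is that each tf.TensorSpec in the @tf.function input_signature uses None as its first dimension.

```python
import os
from typing import List

from tensorflow.core.protobuf import saved_model_pb2


def unbatchable_inputs(model_dir: str, signature_key: str = "serving_default") -> List[str]:
    """List a signature's inputs that do not have -1 in dimension 0."""
    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(model_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())
    problems = []
    for meta_graph in saved_model.meta_graphs:
        if signature_key not in meta_graph.signature_def:
            continue
        for name, tensor_info in meta_graph.signature_def[signature_key].inputs.items():
            shape = tensor_info.tensor_shape
            # Batching needs at least one dimension, and dimension 0 must be unknown (-1).
            if shape.unknown_rank or not shape.dim or shape.dim[0].size != -1:
                problems.append(name)
    return sorted(problems)
```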

For a complete explanation on why these errors are happening and how to resolve them, see Batching Deep Dive.

Known Issues

TPU function cannot indirectly call another TPU function

While the Converter can handle most function-calling scenarios across the CPU-TPU boundary, there is one rare edge case in which it fails: when a TPU function indirectly calls another TPU function.

This is because the Converter changes the direct caller of a TPU function so that it calls a TPU call stub instead of the TPU function itself. The call stub contains ops that can only run on the CPU. When a TPU function calls any function that eventually calls the direct caller, those CPU ops could be pulled onto the TPU, which generates missing kernel errors. Note that this case is different from a TPU function directly calling another TPU function. In that case, the Converter doesn't modify either function to call the call stub, so it works.

The Converter detects this scenario. If you see the following error, your model has hit this edge case:

Unable to place both "__inference_tpu_func_2_46" and "__inference_tpu_func_4_68"
on the TPU because "__inference_tpu_func_2_46" indirectly calls
"__inference_tpu_func_4_68". This behavior is unsupported because it can cause
invalid graphs to be generated.

The general solution is to refactor the model to avoid this function-calling pattern. If that is difficult, contact the Google support team for help.

Reference

Converter Options in Protobuf format

message ConverterOptions {
  // TPU conversion options.
  repeated TpuFunction tpu_functions = 1;

  // The state of an optimization.
  enum State {
    // When state is set to default, the optimization will perform its
    // default behavior. For some optimizations this is disabled and for others
    // it is enabled. To check a specific optimization, read the optimization's
    // description.
    DEFAULT = 0;
    // Enabled.
    ENABLED = 1;
    // Disabled.
    DISABLED = 2;
  }

  // Batch options to apply to the TPU Subgraph.
  //
  // At the moment, only one batch option is supported. This field will be
  // expanded to support batching on a per function and/or per signature basis.
  //
  //
  // If not specified, no batching will be done.
  repeated BatchOptions batch_options = 100;

  // Global flag to disable all optimizations that are enabled by default.
  // When enabled, all optimizations that run by default are disabled. If a
  // default optimization is explicitly enabled, this flag will have no effect
  // on that optimization.
  //
  // This flag defaults to false.
  bool disable_default_optimizations = 202;

  // If enabled, apply an optimization that reshapes the tensors going into
  // and out of the TPU. This reshape operation improves performance by reducing
  // the transfer time to and from the TPU.
  //
  // This optimization is incompatible with input_shape_opt, which is disabled
  // by default. If input_shape_opt is enabled, this option should be
  // disabled.
  //
  // This optimization defaults to enabled.
  State io_shape_optimization = 200;

  // If enabled, apply an optimization that updates float variables and float
  // ops on the TPU to bfloat16. This optimization improves performance and
  // throughput by reducing HBM usage and taking advantage of TPU support for
  // bfloat16.
  //
  // This optimization may cause a loss of accuracy for some models. If an
  // unacceptable loss of accuracy is detected, disable this optimization.
  //
  // This optimization defaults to enabled.
  State bfloat16_optimization = 201;

  BFloat16OptimizationOptions bfloat16_optimization_options = 203;

  // The settings for XLA sharding. If set, XLA sharding is enabled.
  XlaShardingOptions xla_sharding_options = 204;
}

message TpuFunction {
  // The function(s) that should be placed on the TPU. Only provide a given
  // function once. Duplicates will result in errors. For example, if
  // you provide a specific function using function_alias don't also provide the
  // same function via concrete_function_name or jit_compile_functions.
  oneof name {
    // The name of the function alias associated with the function that
    // should be placed on the TPU. Function aliases are created during model
    // export using the tf.saved_model.SaveOptions.
    //
    // This is a recommended way to specify which function should be placed
    // on the TPU.
    string function_alias = 1;

    // The name of the concrete function that should be placed on the TPU. This
    // is the name of the function as it is found in the GraphDef and the
    // FunctionDefLibrary.
    //
    // This is NOT the recommended way to specify which function should be
    // placed on the TPU because concrete function names change every time a
    // model is exported.
    string concrete_function_name = 3;

    // The name of the signature to be placed on the TPU. The user must make
    // sure there is no TPU-incompatible op under the entire signature.
    string signature_name = 5;

    // When jit_compile_functions is set to True, all jit compiled functions
    // are placed on the TPU.
    //
    // To use this option, decorate the relevant function(s) with
    // @tf.function(jit_compile=True), before exporting. Then set this flag to
    // True. The converter will find all functions that were tagged with
    // jit_compile=True and place them on the TPU.
    //
    // When using this option, all other settings for the TpuFunction
    // will apply to all functions tagged with
    // jit_compile=True.
    //
    // This option will place all jit_compile=True functions on the TPU.
    // If only some jit_compile=True functions should be placed on the TPU,
    // use function_alias or concrete_function_name.
    bool jit_compile_functions = 4;
  }

}

message BatchOptions {
  // Number of scheduling threads for processing batches of work. Determines
  // the number of batches processed in parallel. This should be roughly in line
  // with the number of TPU cores available.
  int32 num_batch_threads = 1;

  // The maximum allowed batch size.
  int32 max_batch_size = 2;

  // Maximum number of microseconds to wait before outputting an incomplete
  // batch.
  int32 batch_timeout_micros = 3;

  // Optional list of allowed batch sizes. If left empty,
  // does nothing. Otherwise, supplies a list of batch sizes, causing the op
  // to pad batches up to one of those sizes. The entries must increase
  // monotonically, and the final entry must equal max_batch_size.
  repeated int32 allowed_batch_sizes = 4;

  // Maximum number of batches enqueued for processing before requests are
  // failed fast.
  int32 max_enqueued_batches = 5;

  // If set, disables large batch splitting which is an efficiency improvement
  // on batching to reduce padding inefficiency.
  bool disable_large_batch_splitting = 6;

  // Experimental features of batching. Everything inside is subject to change.
  message Experimental {
    // The component to be batched.
    // 1. Unset if it's for all TPU subgraphs.
    // 2. Set function_alias or concrete_function_name if it's for a function.
    // 3. Set signature_name if it's for a signature.
    oneof batch_component {
      // The function alias associated with the function. Function alias is
      // created during model export using the tf.saved_model.SaveOptions, and is
      // the recommended way to specify functions.
      string function_alias = 1;

      // The concrete name of the function. This is the name of the function as
      // it is found in the GraphDef and the FunctionDefLibrary. This is NOT the
      // recommended way to specify functions, because concrete function names
      // change every time a model is exported.
      string concrete_function_name = 2;

      // The name of the signature.
      string signature_name = 3;
    }
  }

  Experimental experimental = 7;
}

message BFloat16OptimizationOptions {
  // Indicates where the BFloat16 optimization should be applied.
  enum Scope {
    // The scope currently defaults to TPU.
    DEFAULT = 0;
    // Apply the bfloat16 optimization to TPU computation.
    TPU = 1;
    // Apply the bfloat16 optimization to the entire model including CPU
    // computations.
    ALL = 2;
  }

  // This field indicates where the bfloat16 optimization should be applied.
  //
  // The scope defaults to TPU.
  Scope scope = 1;

  // If set, the normal safety checks are skipped. For example, if the model
  // already contains bfloat16 ops, the bfloat16 optimization will error because
  // pre-existing bfloat16 ops can cause issues with the optimization. By
  // setting this flag, the bfloat16 optimization will skip the check.
  //
  // This is an advanced feature and not recommended for almost all models.
  //
  // This flag is off by default.
  bool skip_safety_checks = 2;

  // Ops that should not be converted to bfloat16.
  // Inputs into these ops will be cast to float32, and outputs from these ops
  // will be cast back to bfloat16.
  repeated string filterlist = 3;
}
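The numeric reason for filterlist can be illustrated with a pure-Python sketch. It emulates float32 to bfloat16 conversion by rounding the raw float32 bits to nearest-even (NaN handling is omitted), then compares a running sum computed entirely in bfloat16 with the same sum computed in float32 and cast back to bfloat16 once, which mirrors what the comment above says happens to a filterlisted op. The helper names are hypothetical, and the sketch models only the numerics, not the converter.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value (ties to even). NaN is not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the upper 16 bits of a float32. Adding this bias before
    # truncating the lower 16 bits rounds to nearest, with ties to even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def sum_in_bfloat16(values):
    """Accumulate in bfloat16: every partial sum is rounded to bfloat16."""
    acc = 0.0
    for v in values:
        acc = to_bfloat16(acc + v)
    return acc


def sum_filterlisted(values):
    """Emulate a filterlisted op: bfloat16 inputs are cast to float32, the op
    runs in float32, and only the final output is cast back to bfloat16."""
    acc = 0.0
    for v in values:
        acc = to_float32(acc + v)
    return to_bfloat16(acc)


# 1.0 followed by 1000 small increments, all stored as bfloat16 inputs.
inputs = [to_bfloat16(1.0)] + [to_bfloat16(0.001)] * 1000
print(sum_in_bfloat16(inputs))   # each increment is lost to bfloat16 rounding
print(sum_filterlisted(inputs))  # float32 accumulation keeps the increments
```

In this run the pure-bfloat16 sum stays at 1.0, while the float32 path reaches about 1.9995 and rounds to 2.0 when cast back to bfloat16.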

message XlaShardingOptions {
  // num_cores_per_replica for TPUReplicateMetadata.
  //
  // This is the number of cores you wish to split your model into using XLA
  // SPMD.
  int32 num_cores_per_replica = 1;

  // (optional) device_assignment for TPUReplicateMetadata.
  //
  // This is in a flattened [x, y, z, core] format (for example, core 1 of the
  // chip located in 2,3,0 will be stored as [2,3,0,1]).
  //
  // If this is not specified, then the device assignments will utilize the same
  // topology as specified in the topology attribute.
  repeated int32 device_assignment = 2;

  // A serialized string of tensorflow.tpu.TopologyProto objects, used for
  // the topology attribute in TPUReplicateMetadata.
  //
  // You must specify the mesh_shape and device_coordinates attributes in
  // the topology object.
  //
  // This option is required for num_cores_per_replica > 1 cases due to the
  // ambiguity of num_cores_per_replica. For example, pf_1x2x1 with megacore
  // and df_1x1 both have num_cores_per_replica = 2, but the topology is
  // (1,2,1,1) for pf and (1,1,1,2) for df.
  // - For pf_1x2x1, mesh_shape and device_coordinates look like:
  //   mesh_shape = [1,2,1,1]
  //   device_coordinates=flatten([0,0,0,0], [0,1,0,0])
  // - For df_1x1, mesh_shape and device_coordinates look like:
  //   mesh_shape = [1,1,1,2]
  //   device_coordinates=flatten([0,0,0,0], [0,0,0,1])
  // - For df_2x2, mesh_shape and device_coordinates look like:
  //   mesh_shape = [2,2,1,2]
  //   device_coordinates=flatten(
  //    [0,0,0,0],[0,0,0,1],[0,1,0,0],[0,1,0,1],
  //    [1,0,0,0],[1,0,0,1],[1,1,0,0],[1,1,0,1])
  bytes topology = 3;
}
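The mesh_shape and device_coordinates values in the XlaShardingOptions comment can be generated rather than typed by hand. The pure-Python sketch below uses hypothetical helper names: it enumerates [x, y, z, core] coordinates with x varying slowest and core fastest (the order used in the three examples), flattens them, and runs a consistency check (one coordinate per core in the mesh, all within bounds) that holds for those examples. Building and serializing an actual tensorflow.tpu.TopologyProto for the topology field is not shown.

```python
import itertools
from math import prod


def mesh_coordinates(mesh_shape):
    """Enumerate every [x, y, z, core] coordinate of a mesh, with x varying
    slowest and core varying fastest."""
    return [list(c) for c in itertools.product(*(range(n) for n in mesh_shape))]


def flatten(coordinates):
    """Flatten [[x, y, z, core], ...] into [x, y, z, core, x, y, z, core, ...]."""
    flat = []
    for coord in coordinates:
        if len(coord) != 4:
            raise ValueError(f"expected [x, y, z, core], got {coord}")
        flat.extend(coord)
    return flat


def check_topology(mesh_shape, flat_coordinates):
    """Sanity checks: one coordinate per core in the mesh, all within bounds."""
    if len(flat_coordinates) % 4:
        raise ValueError("flattened coordinates must come in groups of 4")
    coords = [flat_coordinates[i:i + 4] for i in range(0, len(flat_coordinates), 4)]
    if len(coords) != prod(mesh_shape):
        raise ValueError("coordinate count does not match mesh_shape")
    for coord in coords:
        if any(not 0 <= c < n for c, n in zip(coord, mesh_shape)):
            raise ValueError(f"{coord} is outside mesh {mesh_shape}")


df_2x2_mesh = [2, 2, 1, 2]
df_2x2_coords = flatten(mesh_coordinates(df_2x2_mesh))
check_topology(df_2x2_mesh, df_2x2_coords)
print(df_2x2_coords)
```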

Batching Deep Dive

Batching is used to improve the throughput and TPU utilization. It allows multiple requests to be processed at the same time. During training, batching can be done using tf.data. During inference, it is typically done by adding an op in the graph that batches incoming requests. The op waits until it has enough requests or a timeout is reached before it generates a large batch from the individual requests. See Definition of batching options for more information about the different batching options that can be tuned, including batch sizes and timeouts.
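The "wait until there are enough requests or a timeout is reached" behavior can be modeled with a small, single-threaded simulation. This is a simplified sketch, not the batching op's implementation: it assumes requests arrive at known timestamps, that no single request exceeds max_batch_size (large batch splitting is not modeled), and it ignores num_batch_threads and max_enqueued_batches. The Request type and form_batches function are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Request:
    arrival_us: int  # arrival time in microseconds
    size: int        # number of examples (0th-dimension size)


def form_batches(requests, max_batch_size, batch_timeout_micros):
    """Group requests into batches. A batch is closed when adding the next
    request would exceed max_batch_size, or when batch_timeout_micros has
    elapsed since the first request in the batch arrived."""
    batches, current, opened_at = [], [], None
    for req in sorted(requests, key=lambda r: r.arrival_us):
        timed_out = (opened_at is not None
                     and req.arrival_us - opened_at >= batch_timeout_micros)
        too_big = sum(r.size for r in current) + req.size > max_batch_size
        if current and (timed_out or too_big):
            batches.append(current)
            current, opened_at = [], None
        if opened_at is None:
            opened_at = req.arrival_us
        current.append(req)
    if current:
        batches.append(current)  # a real server would flush this on timeout
    return batches


reqs = [Request(0, 1), Request(1000, 2), Request(2000, 4), Request(9000, 1)]
for batch in form_batches(reqs, max_batch_size=8, batch_timeout_micros=5000):
    print([r.size for r in batch])
```

With max_batch_size=8 and a 5000 microsecond timeout, the first three requests share one batch and the late request at 9000 microseconds starts a new one.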

In-graph batching

By default, the Converter inserts the batching op directly before the TPU computation. It wraps the user-provided TPU function(s) and any preexisting TPU computation in the model with batching op(s). It is possible to override this default behavior by telling the Converter which functions and/or signatures should be batched.

The following example shows how to add the default batching.

batch_options {
  num_batch_threads: 2
  max_batch_size: 8
  batch_timeout_micros: 5000
  allowed_batch_sizes: 2
  allowed_batch_sizes: 4
  allowed_batch_sizes: 8
  max_enqueued_batches: 10
}

Signature batching

Signature batching batches the entire model, starting at the signature's inputs and going to the signature's outputs. Unlike the Converter's default batching behavior, signature batching batches both the TPU computation and the CPU computation. This can give a 10% to 20% performance gain during inference on some models.

Like all batching, signature batching has strict shape requirements. To help ensure these shape requirements are met, signature inputs should have at least one dimension. The first dimension is the batch size and should have a size of -1. For example, (-1, 4), (-1) or (-1, 128, 4, 10) are all valid input shapes. If this is not possible, consider using the default batching behavior or function batching.
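A quick way to screen candidate input shapes is a small helper like the hypothetical one below. It accepts either convention used in this document: -1 for the batch dimension of a signature input, or None for the batch dimension of a @tf.function input_signature. It treats a shape as batchable if it is not a scalar and its 0th dimension is unknown; the cross-request rules described under Batching shape requirements still apply.

```python
def is_batchable_shape(shape):
    """Return True if `shape` is not a scalar and has an unknown (-1 or None)
    0th dimension, so the batching op can concatenate along it."""
    if len(shape) < 1:
        return False  # scalars cannot be concatenated along a 0th dimension
    return shape[0] in (-1, None)


# Note: the shape written as (-1) in the text is the 1-D tuple (-1,) in Python.
for shape in [(-1, 4), (-1,), (-1, 128, 4, 10), (), (8, 4), [None, 10]]:
    print(shape, is_batchable_shape(shape))
```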

To use signature batching, provide the signature name(s) through the signature_name field inside the experimental section of BatchOptions.

batch_options {
  num_batch_threads: 2
  max_batch_size: 8
  batch_timeout_micros: 5000
  allowed_batch_sizes: 2
  allowed_batch_sizes: 4
  allowed_batch_sizes: 8
  max_enqueued_batches: 10
  experimental {
    signature_name: "serving_default"
  }
}

Function batching

Function batching can be used to tell the Converter which function(s) should be batched. By default, the Converter batches all TPU functions. Function batching overrides this default behavior.

Function batching can also be used to batch CPU computation. Many models see a performance improvement when their CPU computation is batched. The best way to batch CPU computation is signature batching; however, it may not work for some models. In those cases, function batching can be used to batch part of the CPU computation in addition to the TPU computation. Note that the batching op cannot run on the TPU, so any batching function that is provided must be called on the CPU.

Function batching can also be used to satisfy the strict shape requirements imposed by the batching op. In cases when the TPU function(s) don't meet the batching op's shape requirements, function batching can be used to tell the Converter to batch different function(s).

To use this, generate a function_alias for the function that should be batched. You can do this by finding or creating a function in your model that wraps everything you want batched. Make sure this function meets the strict shape requirements imposed by the batching op. Add @tf.function if it doesn't have one already. It is important to provide the input_signature to the @tf.function. The 0th dimension should be None because it is the batch dimension, so it cannot have a fixed size. For example, [None, 4], [None] or [None, 128, 4, 10] are all valid input shapes. When saving the model, provide SaveOptions like those shown below to give model.batch_func the alias "batch_func". Then you can pass this function alias to the Converter.

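import tensorflow as tf


class ToyModel(tf.keras.Model):
  # The 0th dimension is None because it is the batch dimension.
  @tf.function(input_signature=[tf.TensorSpec(shape=[None, 10],
                                              dtype=tf.float32)])
  def batch_func(self, x):
    return x * 1.0

  ...

model_dir = "/tmp/toy_model"  # hypothetical export directory
model = ToyModel()
# Register model.batch_func under the function alias "batch_func".
save_options = tf.saved_model.SaveOptions(function_aliases={
    'batch_func': model.batch_func,
})
tf.saved_model.save(model, model_dir, options=save_options)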

Next, pass the function_alias(s) using the BatchOptions.

batch_options {
  num_batch_threads: 2
  max_batch_size: 8
  batch_timeout_micros: 5000
  allowed_batch_sizes: 2
  allowed_batch_sizes: 4
  allowed_batch_sizes: 8
  max_enqueued_batches: 10
  experimental {
    function_alias: "batch_func"
  }
}

Definition of batching options

  • num_batch_threads: (integer) Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. This should be roughly in line with the number of TPU cores available.
  • max_batch_size: (integer) Maximum allowed batch size. Can be larger than the largest value in allowed_batch_sizes to utilize large batch splitting.
  • batch_timeout_micros: (integer) Maximum number of microseconds to wait before outputting an incomplete batch.
  • allowed_batch_sizes: (list of integers) If the list is not empty, batches are padded up to the next largest size in the list. The list must be monotonically increasing, and the final element must be less than or equal to max_batch_size.
  • max_enqueued_batches: (integer) Maximum number of batches enqueued for processing before new requests fail fast.
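The allowed_batch_sizes rules can be expressed in a few lines of Python. The sketch below (hypothetical helper names) validates an option set against the constraints listed above and computes the padded size for a batch, using the allowed_batch_sizes values from the batch_options examples in this section. It treats "monotonically increasing" as strictly increasing and does not model large batch splitting.

```python
import bisect


def validate_batch_options(max_batch_size, allowed_batch_sizes):
    """Check the constraints described for allowed_batch_sizes."""
    if any(a >= b for a, b in zip(allowed_batch_sizes, allowed_batch_sizes[1:])):
        raise ValueError("allowed_batch_sizes must be monotonically increasing")
    if allowed_batch_sizes and allowed_batch_sizes[-1] > max_batch_size:
        raise ValueError("last allowed_batch_sizes entry exceeds max_batch_size")


def padded_batch_size(num_examples, allowed_batch_sizes):
    """Return the smallest allowed size that fits num_examples.
    With an empty list, no padding is applied."""
    if not allowed_batch_sizes:
        return num_examples
    i = bisect.bisect_left(allowed_batch_sizes, num_examples)
    if i == len(allowed_batch_sizes):
        raise ValueError(f"{num_examples} examples exceed the largest allowed size")
    return allowed_batch_sizes[i]


allowed = [2, 4, 8]
validate_batch_options(max_batch_size=8, allowed_batch_sizes=allowed)
for n in (1, 2, 3, 5, 8):
    print(n, "->", padded_batch_size(n, allowed))
```

For example, a batch of 3 examples is padded to 4, and a batch of 5 is padded to 8.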

Updating existing batching options

You can add or update batching options by running the Docker image with the --converter_options_string flag, specifying batch_options and setting disable_default_optimizations to true. The batch options will be applied to every TPU function or pre-existing batching op.

batch_options {
  num_batch_threads: 2
  max_batch_size: 8
  batch_timeout_micros: 5000
  allowed_batch_sizes: 2
  allowed_batch_sizes: 4
  allowed_batch_sizes: 8
  max_enqueued_batches: 10
}
disable_default_optimizations: true
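If you generate converter options from a script, a small Python helper can assemble the protobuf text-format string passed to --converter_options_string. The function below is a hypothetical sketch: its name and parameters are illustrative, it emits only the fields shown in this section, it supports only function_alias and signature_name for the experimental batch component, and it rejects setting both because they belong to the same oneof batch_component. It does not escape quotes in names.

```python
def build_converter_options(
    num_batch_threads=2,
    max_batch_size=8,
    batch_timeout_micros=5000,
    allowed_batch_sizes=(2, 4, 8),
    max_enqueued_batches=10,
    function_alias=None,
    signature_name=None,
    disable_default_optimizations=False,
):
    """Return a text-format string containing one batch_options block."""
    if function_alias and signature_name:
        raise ValueError("set at most one field of oneof batch_component")
    lines = [
        "batch_options {",
        f"  num_batch_threads: {num_batch_threads}",
        f"  max_batch_size: {max_batch_size}",
        f"  batch_timeout_micros: {batch_timeout_micros}",
    ]
    lines += [f"  allowed_batch_sizes: {size}" for size in allowed_batch_sizes]
    lines.append(f"  max_enqueued_batches: {max_enqueued_batches}")
    if function_alias or signature_name:
        field, value = (("function_alias", function_alias) if function_alias
                        else ("signature_name", signature_name))
        lines += ["  experimental {", f'    {field}: "{value}"', "  }"]
    lines.append("}")
    if disable_default_optimizations:
        lines.append("disable_default_optimizations: true")
    return "\n".join(lines)


print(build_converter_options(disable_default_optimizations=True))
```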

Batching shape requirements

Batches are created by concatenating input tensors across requests along their batch (0th) dimension. The output tensors are split along their 0th dimension. In order to perform these operations, the batching op has strict shape requirements for its inputs and outputs.

Walkthrough

To understand these requirements, it is helpful to first understand how batching is performed. In the example below, we are batching a simple tf.matmul op.

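import tensorflow as tf


def my_func(A, B):
    # Batched matrix multiply: (batch, n, k) x (batch, k, m) -> (batch, n, m).
    return tf.matmul(A, B)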

The first inference request produces the inputs A and B with the shapes (1, 3, 2) and (1, 2, 4) respectively. The second inference request produces the inputs A and B with the shapes (2, 3, 2) and (2, 2, 4).

[Figure: inference request 1]

Next, the batching timeout is reached. The model supports a batch size of 3, so inference requests #1 and #2 are batched together without any padding. The batched tensors are formed by concatenating requests #1 and #2 along the batch (0th) dimension. Since #1's A has a shape of (1, 3, 2) and #2's A has a shape of (2, 3, 2), concatenating them along the batch (0th) dimension produces a shape of (3, 3, 2).

[Figure: batched request]

The tf.matmul is executed and it produces an output with the shape (3, 3, 4).

[Figure: batched matmul request]

The output of the tf.matmul is batched, so it needs to be split back into separate requests. The batching op does this by splitting along the batch (0th) dimension of each output tensor. It decides how to split the 0th dimension based on the shapes of the original inputs. Since request #1's shapes have a 0th dimension of 1, its output has a 0th dimension of 1, for a shape of (1, 3, 4). Since request #2's shapes have a 0th dimension of 2, its output has a 0th dimension of 2, for a shape of (2, 3, 4).

[Figure: inference request results]
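The walkthrough can be reproduced at the level of tensor shapes with the pure-Python sketch below. The helper names are hypothetical, and it tracks shapes only: it concatenates each input along the 0th dimension, applies the shape rule for a rank-3 batched matmul with equal batch sizes, and splits the output by each request's original 0th-dimension size.

```python
def concat_batch(shapes):
    """Concatenate shapes along the 0th dimension; other dims must match."""
    rest = shapes[0][1:]
    if any(s[1:] != rest for s in shapes):
        raise ValueError("Dimensions of inputs should match")
    return (sum(s[0] for s in shapes),) + rest


def batched_matmul_shape(a, b):
    """Shape of tf.matmul for rank-3 inputs (batch, n, k) x (batch, k, m)."""
    if a[0] != b[0] or a[2] != b[1]:
        raise ValueError(f"incompatible shapes {a} and {b}")
    return (a[0], a[1], b[2])


def split_batch(shape, sizes):
    """Split a batched shape back into one shape per request."""
    if shape[0] != sum(sizes):
        raise ValueError("0th dimension does not equal the sum of request sizes")
    return [(n,) + shape[1:] for n in sizes]


requests = [
    {"A": (1, 3, 2), "B": (1, 2, 4)},  # inference request #1
    {"A": (2, 3, 2), "B": (2, 2, 4)},  # inference request #2
]
batched_a = concat_batch([r["A"] for r in requests])   # (3, 3, 2)
batched_b = concat_batch([r["B"] for r in requests])   # (3, 2, 4)
out = batched_matmul_shape(batched_a, batched_b)        # (3, 3, 4)
print(split_batch(out, [r["A"][0] for r in requests]))  # [(1, 3, 4), (2, 3, 4)]
```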

Shape Requirements

In order to perform the input concatenating and output splitting described above, the batching op has the following shape requirements:

  1. Inputs to batching cannot be scalars. In order to concatenate along the 0th dimension, the tensors have to have at least one dimension.

    In the walkthrough above, neither A nor B is a scalar.

    Failure to meet this requirement will cause an error like: Batching input tensors must have at least one dimension. A simple fix for this error is to make the scalar a vector.

  2. Across different inference requests (for example, different Session run invocations), input tensors with the same name must have the same size for each dimension except the 0th dimension. This allows inputs to be cleanly concatenated along their 0th dimension.

    In the walkthrough above, request #1's A has a shape of (1, 3, 2). This means that any future request must produce a shape with the pattern (X, 3, 2). Request #2 meets this requirement with (2, 3, 2). Similarly, request #1's B has a shape of (1, 2, 4) so all future requests must produce a shape with the pattern (X, 2, 4).

    Failure to meet this requirement will cause an error like: Dimensions of inputs should match.

  3. For a given inference request, all inputs must have the same 0th dimension size. If different input tensors to the batching op have different 0th dimensions, the batching op does not know how to split the output tensors.

    In the walkthrough above, request #1's tensors all have a 0th dimension size of 1. This lets the batching op know that its output should have a 0th dimension size of 1. Similarly, request #2's tensors have a 0th dimension size of 2, so its output will have a 0th dimension size of 2. When the batching op splits the final shape of (3, 3, 4), it produces (1, 3, 4) for request #1 and (2, 3, 4) for request #2.

    Failure to meet this requirement will result in errors like: Batching input tensors supplied in a given op invocation must have equal 0th-dimension size.

  4. The 0th dimension size of each output tensor's shape must be the sum of the 0th dimension sizes of all the input tensors (plus any padding introduced by the batching op to reach the next largest allowed_batch_sizes value). This allows the batching op to split the output tensors along their 0th dimension based on the 0th dimension of the input tensors.

    In the walkthrough above, the input tensors have a 0th dimension of 1 from request #1 and 2 from request #2. Therefore, each output tensor must have a 0th dimension of 3 because 1+2=3. The output tensor (3, 3, 4) meets this requirement. If 3 had not been a valid batch size but 4 was, the batching op would have had to pad the 0th dimension of the inputs from 3 to 4. In this case, each output tensor would have to have a 0th dimension size of 4.

    Failure to meet this requirement will result in an error like: Batched output tensor's 0th dimension does not equal the sum of the 0th dimension sizes of the input tensors.
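The four requirements can be checked on shapes before a model is exported. The following pure-Python sketch is illustrative only: the function names and the input format (one dict per request, mapping input names to shapes) are assumptions, and it raises errors that reuse the messages quoted above without running the batching op itself.

```python
def check_batching_inputs(requests):
    """Check requirements 1-3 for a list of requests, each a dict that maps an
    input name to its shape. Returns each request's 0th-dimension size."""
    trailing_dims = {}  # input name -> shape without the 0th dimension
    sizes = []
    for request in requests:
        for name, shape in request.items():
            if len(shape) == 0:  # Requirement 1: inputs cannot be scalars.
                raise ValueError(
                    f"{name}: Batching input tensors must have at least one dimension")
            # Requirement 2: every dimension except the 0th matches across requests.
            if trailing_dims.setdefault(name, shape[1:]) != shape[1:]:
                raise ValueError(f"{name}: Dimensions of inputs should match")
        # Requirement 3: all inputs of one request share the same 0th dimension.
        batch_dims = {shape[0] for shape in request.values()}
        if len(batch_dims) != 1:
            raise ValueError(
                "Batching input tensors supplied in a given op invocation must "
                "have equal 0th-dimension size")
        sizes.append(batch_dims.pop())
    return sizes


def check_batched_output(output_shape, sizes, padded_size=None):
    """Requirement 4: the output's 0th dim equals the summed (or padded) size."""
    expected = padded_size if padded_size is not None else sum(sizes)
    if output_shape[0] != expected:
        raise ValueError(
            "Batched output tensor's 0th dimension does not equal the sum of "
            "the 0th dimension sizes of the input tensors")


sizes = check_batching_inputs([
    {"A": (1, 3, 2), "B": (1, 2, 4)},  # request #1 from the walkthrough
    {"A": (2, 3, 2), "B": (2, 2, 4)},  # request #2 from the walkthrough
])
check_batched_output((3, 3, 4), sizes)                 # 1 + 2 == 3
check_batched_output((4, 3, 4), sizes, padded_size=4)  # padded from 3 to 4
```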

Resolving Shape Requirement Errors

To meet these requirements, consider providing a different function or signature to batch. It may also be necessary to modify existing functions to meet these requirements.

If a function is being batched, make sure that every shape in its @tf.function input_signature has None in the 0th dimension (the batch dimension). If a signature is being batched, make sure that all its inputs have -1 in the 0th dimension.

The BatchFunction op does not support SparseTensors as inputs or outputs. Internally, each sparse tensor is represented as three separate tensors (indices, values, and dense_shape) that can have different 0th dimension sizes.
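To see why these components cannot be split per request, consider their shapes. The pure-Python sketch below uses a hypothetical helper that computes only the shapes of the indices, values, and dense_shape tensors for a sparse tensor with a given set of nonzero positions. Their 0th dimensions follow the number of nonzero values and the rank, not the batch size.

```python
def sparse_component_shapes(dense_shape, nonzero_positions):
    """Return the shapes of the indices, values, and dense_shape tensors that
    represent a sparse tensor with the given nonzero positions."""
    rank = len(dense_shape)
    for pos in nonzero_positions:
        if len(pos) != rank or any(not 0 <= p < d for p, d in zip(pos, dense_shape)):
            raise ValueError(f"{pos} is outside dense shape {dense_shape}")
    nnz = len(nonzero_positions)
    return {
        "indices": (nnz, rank),  # one coordinate tuple per nonzero value
        "values": (nnz,),        # one entry per nonzero value
        "dense_shape": (rank,),  # one entry per dimension
    }


# A batch of 3 rows (dense shape (3, 5)) holding 4 nonzero values.
shapes = sparse_component_shapes((3, 5), [(0, 1), (0, 4), (2, 0), (2, 3)])
# The 0th dimensions are 4, 4, and 2; none of them equals the batch size 3.
print({name: shape[0] for name, shape in shapes.items()})
```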