Developers & Practitioners

Using TFX inference with Dataflow for large scale ML inference patterns

April 30, 2021
Reza Rokni

Senior Staff Developer Advocate, Dataflow

In part I of this blog series, we discussed best practices and patterns for efficiently deploying a machine learning model for inference with Google Cloud Dataflow. Among other techniques, it showed how to batch inputs efficiently and how to use shared.py to make efficient use of a model.

In this post, we walk through the use of the RunInference API from tfx-bsl, a utility transform from TensorFlow Extended (TFX), which saves us from manually implementing the patterns described in part I. You can use RunInference to simplify your pipelines and reduce technical debt when building production inference pipelines in batch or stream mode.

The following four patterns are covered:

  • Using RunInference to make ML prediction calls.
  • Post-processing RunInference results. Making predictions is often the first step in a multistep business process. Here we process the results into a form that can be used downstream.
  • Attaching a key. Along with the data that is passed to the model, there is often a need for an identifier, for example an IoT device ID or a customer identifier, that is used later in the process even if the model itself does not use it. We show how this can be accomplished.
  • Inference with multiple models in the same pipeline. Often you may need to run multiple models within the same pipeline, be it in parallel or as a sequence of predict - process - predict calls. We walk through a simple example.

Creating a simple model

In order to illustrate these patterns, we’ll use a simple toy model that will let us concentrate on the data engineering needed for the input and output of the pipeline. This model will be trained to approximate multiplication by the number 5.

Note that the following code snippets can be run as cells in a notebook environment.

Step 1 - Set up libraries and imports

%pip install tfx_bsl==0.29.0 --quiet 


Step 2 - Create the example data 

In this step we create a small dataset that includes a range of values from 0 to 99 and labels that correspond to each value multiplied by 5.
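The data-creation code is not reproduced here. A minimal sketch, assuming the data is held in plain Python lists (the post later notes that this first model is fed from a simple list), with `x` and `y` named to match the `model.fit(x, y, ...)` call in Step 3:

```python
# Inputs: the integers 0 to 99, kept as a plain list.
x = list(range(0, 100))
# Labels: each input multiplied by 5.
y = [value * 5 for value in x]

print(x[:5])  # [0, 1, 2, 3, 4]
print(y[:5])  # [0, 5, 10, 15, 20]
```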


Step 3 - Create a simple model, compile, and fit it 


Let’s teach the model about multiplication by 5.

model.fit(x, y, epochs=2000)
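The model definition is not reproduced here. A Keras model with a single `Dense` unit computes `w * x + b`, and because the target function is linear, one such unit can represent it exactly. As an illustration only, not the Keras training the post uses, the hypothetical helper `fit_line` computes the closed-form least-squares weights. Assuming a mean squared error loss, these are the weights that minimize the loss on this data, so they are the optimum that training a single linear unit is driving toward:

```python
def fit_line(xs, ys):
    """Closed-form ordinary least squares for y = w * x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance(x, y) / variance(x).
    cov_xy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(xs, ys))
    var_x = sum((xi - mean_x) ** 2 for xi in xs)
    w = cov_xy / var_x
    # The fitted line passes through the point of means.
    b = mean_y - w * mean_x
    return w, b


x = list(range(0, 100))
y = [value * 5 for value in x]
w, b = fit_line(x, y)
print(w, b)  # 5.0 0.0
```

Because the optimum is exactly `w = 5, b = 0`, any residual error in the trained model comes from optimization stopping short of it, not from the model being too simple.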

Next, check how well the model performs using some test data.


From the results below it looks like this simple model has learned its 5 times table close enough for our needs!
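The evaluation code and its output are not reproduced here. One simple way to make "close enough" concrete is to measure the largest absolute error against the exact product on a few test inputs. In this sketch, `max_abs_error`, the test inputs, and the slightly-off `approx_model` are all hypothetical; `approx_model` stands in for calling the trained model's `predict`, and its weights are illustrative, not the ones the post's model learned:

```python
def max_abs_error(predict, test_values, factor=5):
    """Largest absolute gap between predict(v) and the exact v * factor."""
    return max(abs(predict(v) - v * factor) for v in test_values)


def approx_model(v):
    # Hypothetical stand-in for a trained model: a line with slightly-off weights.
    return 4.999 * v + 0.01


test_values = [20, 40, 60, 90]  # hypothetical test inputs
error = max_abs_error(approx_model, test_values)
print(f"max abs error: {error:.3f}")
```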


Step 4 - Convert the input to tf.Example

In the model we just built, we used a simple list to generate the data and pass it to the model. In this next step, we make the model more robust by using tf.Example objects in the model training.

tf.Example is a serializable mapping from feature names to feature values (lists of bytes, floats, or 64-bit integers). Because the model reads only the features it needs, it can keep working when new features are added to the base examples. Using tf.Example also makes the data portable across models in an efficient, serialized format.

To use tf.Example here, we first need to create a helper class, ExampleProcessor, that serializes the data points.
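The original `ExampleProcessor` code is not reproduced here. The sketch below only illustrates the structure involved: it builds one record shaped like a `tf.train.Example` (a map from feature names to typed value lists) per data point. To keep it runnable without TensorFlow, it serializes to JSON instead of protobuf; a real implementation would build `tf.train.Example`, `tf.train.Features`, `tf.train.Feature`, and `tf.train.FloatList` messages and call `SerializeToString()`. The class body and the feature names `x` and `y` are assumptions:

```python
import json


class ExampleProcessor:
    """Hypothetical sketch of an ExampleProcessor helper.

    Builds records shaped like a tf.train.Example (feature name -> typed
    list of values) but serializes them as JSON so it runs without
    TensorFlow.
    """

    def create_example(self, features):
        # Every feature holds a list of floats, mirroring a FloatList.
        return {
            "features": {
                "feature": {
                    name: {"float_list": {"value": [float(v) for v in values]}}
                    for name, values in features.items()
                }
            }
        }

    def serialize(self, example):
        # Stand-in for Example.SerializeToString().
        return json.dumps(example, sort_keys=True).encode("utf-8")


processor = ExampleProcessor()
# One serialized record per data point; "x" and "y" are illustrative names.
records = [
    processor.serialize(processor.create_example({"x": [v], "y": [v * 5]}))
    for v in range(0, 100)
]
print(len(records))  # 100
```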


Using the ExampleProcessor class, the in-memory list can now be written to disk.


With the new examples stored in TFRecord files on disk, we can use the Dataset API to prepare the data so it is ready for consumption by the model.
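In TensorFlow, this step would typically read the files with `tf.data.TFRecordDataset` and parse each record with `tf.io.parse_single_example` against a feature spec, a dict of `tf.io.FixedLenFeature` entries. The plain-Python sketch below, which uses a hypothetical JSON stand-in format rather than real TFRecords, shows why a feature spec makes the model tolerant of added features: the parser extracts only the named features and ignores the rest. `FEATURE_SPEC` and `parse_example` are hypothetical:

```python
import json

# Hypothetical feature spec: name -> default value (None means "required").
# In TensorFlow this role is played by a dict of tf.io.FixedLenFeature entries.
FEATURE_SPEC = {"x": None}


def parse_example(serialized, feature_spec):
    """Extract only the features named in the spec; ignore everything else."""
    feature_map = json.loads(serialized)["features"]["feature"]
    parsed = {}
    for name, default in feature_spec.items():
        if name in feature_map:
            parsed[name] = feature_map[name]["float_list"]["value"][0]
        elif default is not None:
            parsed[name] = default
        else:
            raise KeyError(f"required feature {name!r} is missing")
    return parsed


# A record carrying an extra, unknown feature still parses cleanly.
record = json.dumps({
    "features": {"feature": {
        "x": {"float_list": {"value": [7.0]}},
        "new_feature": {"float_list": {"value": [1.0]}},
    }}
}).encode("utf-8")
print(parse_example(record, FEATURE_SPEC))  # {'x': 7.0}
```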


With the feature spec in place, we can train the model as before. 


Note that these steps would be done automatically for us if we had built the model using a TFX pipeline, rather than hand-crafting the model as we did here. 

Step 5 - Save the model

Now that we have a model, we need to save it for use with the RunInference transform. RunInference takes a TensorFlow SavedModel (the saved_model.pb file and its accompanying files) as part of its configuration. The SavedModel must be stored in a location that the RunInference transform can access. In a notebook this can be the local file system; however, to run the pipeline on Dataflow, the files need to be accessible by all the workers, so here we use a Cloud Storage bucket.

Note that the gs:// scheme is supported directly by the tf.keras.models.save_model API.

tf.keras.models.save_model(model, save_model_dir_multiply)

During development it's useful to be able to inspect the contents of the saved model file. For this, we use the saved_model_cli that comes with TensorFlow. You can run this command from a cell:

 !saved_model_cli show --dir {save_model_dir_multiply} --all

Abbreviated output from the saved model file is shown below. Note the signature def 'serving_default', which accepts a tensor of float type. We will change this to accept another type in the next section.


RunInference passes a serialized tf.Example to the model rather than the tensor of float type that the current signature expects. To accomplish this, we have one more step to prepare the model: creating a specific signature.

Signatures are a powerful feature as they enable us to control how calling programs interact with the model. From the TensorFlow documentation:

"The optional signatures argument controls which methods in obj will be available to programs which consume SavedModels, for example, serving APIs. Python functions may be decorated with @tf.function(input_signature=...) and passed as signatures directly, or lazily with a call to get_concrete_function on the method decorated with @tf.function."

In our case, the following code will create a signature that accepts a tf.string data type with a name of examples. This signature is then saved with the model, which replaces the previous saved model.
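The signature code is not reproduced here. In TensorFlow this is usually a function decorated with `@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')])` that parses the batch and calls the model, passed to `tf.keras.models.save_model` through its `signatures` argument. The plain-Python stand-in below shows only the contract such a signature enforces: a single input named `examples` holding serialized records, with parsing done inside the signature. The function name `serve_examples`, the output key `outputs`, the JSON record format, and the `5.0 * x` forward pass are all hypothetical:

```python
import json
from typing import Dict, List


def serve_examples(examples: List[bytes]) -> Dict[str, List[float]]:
    """Stand-in for a signature whose only input, `examples`, is a batch of
    serialized records (tf.string in TensorFlow) instead of raw floats."""
    outputs = []
    for serialized in examples:
        if not isinstance(serialized, bytes):
            # A float-typed input no longer matches the signature.
            raise TypeError(f"expected serialized bytes, got {type(serialized).__name__}")
        # Parsing happens inside the signature, so callers only send bytes.
        feature_map = json.loads(serialized)["features"]["feature"]
        x = feature_map["x"]["float_list"]["value"][0]
        outputs.append(5.0 * x)  # hypothetical stand-in for the model's forward pass
    return {"outputs": outputs}


def make_record(v: float) -> bytes:
    # Builds one record in the same hypothetical JSON stand-in format.
    return json.dumps({"features": {"feature": {"x": {"float_list": {"value": [v]}}}}}).encode("utf-8")


print(serve_examples([make_record(2.0), make_record(3.0)]))  # {'outputs': [10.0, 15.0]}
```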


If you run the saved_model_cli command again, you will see that the input signature has changed to DT_STRING.

Pattern 1: RunInference for predictions

Step 1 - Use RunInference within the pipeline

Now that the model is ready, the RunInference transform can be plugged into an Apache Beam pipeline. The pipeline below uses the TFXIO TFExampleRecord, which it converts into a Beam source transform via RawRecordBeamSource(). The saved model location and signature are passed to the RunInference API as a SavedModelSpec configuration object.

Note: You can perform two types of inference using RunInference:

  • In-process inference from a SavedModel instance. Used when the saved_model_spec field is set in inference_spec_type.
  • Remote inference by using a service endpoint. Used when the ai_platform_prediction_model_spec field is set in inference_spec_type.

Below is a snippet of the output. The values here are a little difficult to interpret as they are in their raw unprocessed format. In the next section the raw results are post-processed.


Pattern 2: Post-processing RunInference results

The RunInference API returns a PredictionLog object, which contains the serialized input and the output from the call to the model. Having access to both the input and output enables you to create a simple tuple during post-processing for use downstream in the pipeline. Also note that RunInference transparently takes into account whether the model supports batching and, when it does, performs batch inference for performance.

The PredictionProcessor beam.DoFn takes the output of RunInference and produces formatted text with the questions and answers as output. In a production system, the output would more typically be a Tuple[input, output], or simply the output, depending on the use case.
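A simplified sketch of the behavior described above, not RunInference's actual implementation: inputs are grouped into batches, the model is called once per batch, each output is kept next to the input that produced it (the role a PredictionLog plays), and a post-processing step formats each pair as text. `PredictionRecord`, `run_batched_inference`, `format_prediction`, `fake_model`, and the batch size of 4 are all hypothetical:

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Iterator, List


@dataclass
class PredictionRecord:
    """Simplified stand-in for a PredictionLog: keeps input and output together."""
    input_value: float
    output_value: float


def _predict(batch: List[float],
             predict_batch: Callable[[List[float]], List[float]]) -> Iterator[PredictionRecord]:
    # One model call per batch; zip pairs each output with its input.
    outputs = predict_batch(batch)
    for value, output in zip(batch, outputs):
        yield PredictionRecord(value, output)


def run_batched_inference(values: Iterable[float],
                          predict_batch: Callable[[List[float]], List[float]],
                          batch_size: int = 4) -> Iterator[PredictionRecord]:
    """Group inputs into batches of up to batch_size and run the model per batch."""
    batch: List[float] = []
    for value in values:
        batch.append(value)
        if len(batch) == batch_size:
            yield from _predict(batch, predict_batch)
            batch = []
    if batch:  # flush the final, possibly smaller, batch
        yield from _predict(batch, predict_batch)


def format_prediction(record: PredictionRecord) -> str:
    """Post-processing: turn an (input, output) pair into readable text."""
    return f"{record.input_value} * 5 = {record.output_value}"


calls = []  # records the size of every model call


def fake_model(batch: List[float]) -> List[float]:
    # Hypothetical stand-in for the saved model's forward pass.
    calls.append(len(batch))
    return [5.0 * v for v in batch]


lines = [format_prediction(r)
         for r in run_batched_inference([1.0, 2.0, 3.0, 4.0, 5.0], fake_model)]
print(lines[0])  # 1.0 * 5 = 5.0
print(calls)     # [4, 1]
```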


Now the output contains both the original input and the model's output values. 


Pattern 3: Attaching a key

One useful pattern is the ability to pass information, often a unique identifier, with the input to the model and have access to this identifier from the output. For example, in an IoT use case you could associate a device ID with the input data being passed into the model. Often this type of key is not useful for the model itself and thus should not be passed into the first layer.

RunInference takes care of this for us by accepting a Tuple[key, value] and outputting a Tuple[key, PredictionLog].
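A minimal sketch of this key-passing contract, assuming plain Python lists and a hypothetical `predict_batch` callable in place of the model: only the values reach the model, and each key is re-attached to its (input, output) pair afterward. `keyed_inference` and the device IDs are illustrative:

```python
from typing import Any, Callable, List, Tuple


def keyed_inference(
    keyed_inputs: List[Tuple[Any, float]],
    predict_batch: Callable[[List[float]], List[float]],
) -> List[Tuple[Any, Tuple[float, float]]]:
    """Mirror Tuple[key, value] -> Tuple[key, (input, output)]."""
    keys = [key for key, _ in keyed_inputs]
    values = [value for _, value in keyed_inputs]
    # Only the values are sent to the model; the keys never reach it.
    outputs = predict_batch(values)
    # Re-attach each key to the input and output that belong to it.
    return [(key, (value, output)) for key, value, output in zip(keys, values, outputs)]


readings = [(b"device-1", 2.0), (b"device-2", 7.0)]  # hypothetical device IDs
result = keyed_inference(readings, lambda batch: [5.0 * v for v in batch])
print(result)  # [(b'device-1', (2.0, 10.0)), (b'device-2', (7.0, 35.0))]
```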

Step 1 - Create a source with attached key

Since we need a key with the data that we are sending in for prediction, in this step we create a table in BigQuery with two columns: one holds the key and the other holds the test value.


Step 2 - Modify post processor and pipeline

In this step we:

  • Modify the pipeline to read from the new BigQuery source table
  • Add a map transform, which converts a table row into a Tuple[bytes, Example]
  • Modify the post inference processor to output results along with the key
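A sketch of the map and post-processing steps, assuming the table's columns are named `key` and `value` (hypothetical names) and that rows arrive as Python dicts, which is how Beam's BigQuery source commonly presents them. As in a plain-Python illustration, the example is a dict shaped like a `tf.train.Example` rather than a real proto, and `row_to_keyed_example` and `format_keyed_result` are hypothetical:

```python
from typing import Any, Dict, Tuple


def row_to_keyed_example(row: Dict[str, Any]) -> Tuple[bytes, Dict[str, Any]]:
    """Map a table row to (key, example).

    Assumes hypothetical columns "key" and "value". The example is a plain
    dict shaped like a tf.train.Example, not a real proto.
    """
    key = str(row["key"]).encode("utf-8")
    example = {
        "features": {"feature": {"x": {"float_list": {"value": [float(row["value"])]}}}}
    }
    return key, example


def format_keyed_result(key: bytes, input_value: float, output_value: float) -> str:
    """Post-processor that carries the key through to its output."""
    return f"key = {key.decode('utf-8')}, input = {input_value}, output = {output_value}"


key, example = row_to_keyed_example({"key": "device-42", "value": 3})
print(key)  # b'device-42'
print(format_keyed_result(key, 3.0, 15.0))  # key = device-42, input = 3.0, output = 15.0
```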


Pattern 4: Inference with multiple models in the same pipeline

In part I of the series, the "join results from multiple models" pattern covered the various branching techniques in Apache Beam that make it possible to run data through multiple models. 

Those techniques are applicable to the RunInference API, which can easily be used by multiple branches within a pipeline, with the same or different models. This is similar in function to cascade ensembling, although here the data flows through multiple models in a single Apache Beam DAG.

Inference with multiple models in parallel

In this example, the same data is run through two different models: the one that we’ve been using to multiply by 5 and a new model, which will learn to multiply by 10.


Now that we have two models, we apply them to our source data. 
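A plain-Python sketch of the parallel pattern: the same keyed inputs go through each model, and the outputs are joined back together by key. In a Beam pipeline, each model would be its own branch applied to the same PCollection, and a transform such as `CoGroupByKey` is one way to join the branches. `run_in_parallel` and the two model functions are hypothetical stand-ins for the multiply-by-5 and multiply-by-10 models:

```python
from typing import Callable, Dict, List, Tuple

Model = Callable[[List[float]], List[float]]


def run_in_parallel(keyed_inputs: List[Tuple[str, float]],
                    models: Dict[str, Model]) -> Dict[str, Dict[str, float]]:
    """Send the same keyed inputs through every model (one branch per
    model) and join the branch outputs back together by key."""
    values = [value for _, value in keyed_inputs]
    joined: Dict[str, Dict[str, float]] = {key: {"input": value} for key, value in keyed_inputs}
    for model_name, predict_batch in models.items():
        for (key, _), output in zip(keyed_inputs, predict_batch(values)):
            joined[key][model_name] = output
    return joined


def times_5(batch: List[float]) -> List[float]:
    return [5.0 * v for v in batch]  # stand-in for the multiply-by-5 model


def times_10(batch: List[float]) -> List[float]:
    return [10.0 * v for v in batch]  # stand-in for the multiply-by-10 model


result = run_in_parallel([("a", 1.0), ("b", 4.0)], {"times_5": times_5, "times_10": times_10})
print(result["b"])  # {'input': 4.0, 'times_5': 20.0, 'times_10': 40.0}
```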



Inference with multiple models in sequence 

In a sequential pattern, data is sent to one or more models in sequence, with the output from each model chaining to the next model.


Here are the steps: 

  1. Read the data from BigQuery
  2. Map the data
  3. RunInference with multiply by 5 model
  4. Process the results
  5. RunInference with multiply by 10 model
  6. Process the results
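The six steps can be sketched in plain Python as a predict, process, predict chain, where the processed output of the first model becomes the input of the second. The model functions stand in for the two trained models, and the rounding in the processing step is a hypothetical placeholder for whatever reshaping a real pipeline does between models (for example, wrapping values back into examples); `run_in_sequence` is also hypothetical:

```python
from typing import Callable, List

Model = Callable[[List[float]], List[float]]


def run_in_sequence(values: List[float], first: Model, second: Model) -> List[dict]:
    """predict -> process -> predict: the processed output of the first
    model becomes the input of the second model."""
    first_outputs = first(values)
    # Processing step between the models; here it only rounds the values.
    processed = [round(output, 2) for output in first_outputs]
    second_outputs = second(processed)
    # Final processing step: keep the whole chain for each input.
    return [
        {"input": v, "first": p, "second": s}
        for v, p, s in zip(values, processed, second_outputs)
    ]


def times_5(batch: List[float]) -> List[float]:
    return [5.0 * v for v in batch]  # stand-in for the multiply-by-5 model


def times_10(batch: List[float]) -> List[float]:
    return [10.0 * v for v in batch]  # stand-in for the multiply-by-10 model


print(run_in_sequence([3.0], times_5, times_10))
# [{'input': 3.0, 'first': 15.0, 'second': 150.0}]
```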


Running the pipeline on Dataflow

Until now, the pipeline has been run locally using the direct runner, which is used implicitly when a pipeline runs with the default configuration. The same examples can be run using the production Dataflow runner by passing in configuration parameters, including --runner. Details and an example can be found here.
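A sketch of the configuration involved, assuming placeholder project, region, and bucket names. `--runner`, `--project`, `--region`, `--temp_location`, and `--requirements_file` are standard Apache Beam pipeline options; listing tfx-bsl in a requirements file is one way, not the only way, to make it available on the workers:

```python
# Placeholder values; replace them with your own project, region, and bucket.
PROJECT_ID = "your-project-id"
REGION = "us-central1"
BUCKET = "your-bucket"

# Apache Beam pipeline options for running on the Dataflow service.
dataflow_flags = [
    "--runner=DataflowRunner",
    f"--project={PROJECT_ID}",
    f"--region={REGION}",
    f"--temp_location=gs://{BUCKET}/tmp",
    # Workers also need tfx_bsl; a requirements file listing it is one option.
    "--requirements_file=requirements.txt",
]

# With Apache Beam installed, these flags would typically be passed as:
#   from apache_beam.options.pipeline_options import PipelineOptions
#   options = PipelineOptions(dataflow_flags)
#   with beam.Pipeline(options=options) as p: ...
print(" ".join(dataflow_flags))
```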

Here is an example of the multimodel pipeline graph running on the Dataflow service:


With the Dataflow runner you also get access to pipeline monitoring as well as metrics that have been output from the RunInference transform. The following table shows some of these metrics from a much larger list available from the library.



In this blog, part II of our series, we explored the use of the tfx-bsl RunInference transform in some common scenarios, from standard inference to post-processing and the use of the RunInference API in multiple locations in the pipeline.

To learn more, review the Dataflow and TFX documentation. You can also try out TFX with Google Cloud AI Platform Pipelines.


None of this would be possible without the hard work of many folks across the Dataflow, TFX, and TF teams. From the TFX and TF teams, we would especially like to thank Konstantinos Katsiapis, Zohar Yahav, Vilobh Meshram, Jiayi Zhao, Zhitao Li, and Robert Crowe. From the Dataflow team, I would like to thank Ahmet Altay for his support and input throughout.
