Using TFX inference with Dataflow for large scale ML inference patterns
Reza Rokni
Senior Developer Advocate, Dataflow
In part I of this blog series we discussed best practices and patterns for efficiently deploying a machine learning model for inference with Google Cloud Dataflow. Among other techniques, we showed how to batch inputs efficiently and how to use shared.py to make efficient use of a model.
In this post, we walk through the use of the RunInference API from tfx-bsl, a utility transform from TensorFlow Extended (TFX), which abstracts us away from manually implementing the patterns described in part I. You can use RunInference to simplify your pipelines and reduce technical debt when building production inference pipelines in batch or stream mode.
The following four patterns are covered:
- Using RunInference to make ML prediction calls.
- Post-processing RunInference results. Making predictions is often only the first part of a multistep business process. Here we will process the results into a form that can be used downstream.
- Attaching a key. Along with the data that is passed to the model, there is often a need for an identifier — for example, an IOT device ID or a customer identifier — that is used later in the process even if it’s not used by the model itself. We show how this can be accomplished.
- Inference with multiple models in the same pipeline. Often you may need to run multiple models within the same pipeline, be it in parallel or as a sequence of predict - process - predict calls. We walk through a simple example.
Creating a simple model
In order to illustrate these patterns, we’ll use a simple toy model that will let us concentrate on the data engineering needed for the input and output of the pipeline. This model will be trained to approximate multiplication by the number 5.
Please note the following code snippets can be run as cells within a notebook environment.
Step 1 - Set up libraries and imports
%pip install tfx_bsl==0.29.0 --quiet
Step 2 - Create the example data
In this step we create a small dataset that includes a range of values from 0 to 99 and labels that correspond to each value multiplied by 5.
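As a minimal sketch of this step, the dataset can be built from plain Python lists. The variable names `values` and `labels` are illustrative assumptions, not names taken from the original notebook.

```python
# Inputs 0..99 and, for each input, a label equal to the input multiplied by 5.
values = [float(i) for i in range(100)]
labels = [v * 5 for v in values]

# Peek at the first few pairs to sanity-check the data.
print(list(zip(values[:3], labels[:3])))
```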
Step 3 - Create a simple model, compile, and fit it
Let’s teach the model about multiplication by 5.
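A model of this kind might look like the sketch below: a single dense unit fit on the 0 to 99 inputs. The layer layout, optimizer, learning rate, and epoch count are assumptions chosen to keep the example short, not necessarily the settings used for the results in this post.

```python
import numpy as np
import tensorflow as tf

# Training data: inputs 0..99 with labels equal to input * 5.
x = np.arange(100, dtype=np.float32).reshape(-1, 1)
y = x * 5.0

# One dense unit is enough to represent y = w * x + b.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,), name="x"),
    tf.keras.layers.Dense(1),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),  # assumed setting
    loss="mean_absolute_error",
)
model.fit(x, y, epochs=200, verbose=0)

# Ask the model for 20 * 5; the answer should be approximately 100.
prediction = model.predict(np.array([[20.0]], dtype=np.float32), verbose=0)
print(prediction)
```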
From the results below it looks like this simple model has learned its 5 times table close enough for our needs!
Step 4 - Convert the input to tf.example
In the model we just built, we made use of a simple list to generate the data and pass it to the model. In this next step we make the model more robust by using tf.example objects in the model training.
tf.example is a serializable dictionary (or mapping) from names to tensors, which ensures the model can still function even when new features are added to the base examples. Making use of tf.example also brings with it the benefit of having the data be portable across models in an efficient, serialized format.
To use tf.example for this example, we first need to create a helper class, ExampleProcessor, that is used to serialize the data points.
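A helper of this kind could look like the following sketch. The method names and the feature names "x" (input) and "y" (label) are assumptions made for illustration; only the class name comes from the text above.

```python
import tensorflow as tf


class ExampleProcessor:
    """Builds tf.train.Example protos from float data points."""

    def create_feature(self, element: float) -> tf.train.Feature:
        # Wrap a single float in a Feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))

    def create_example(self, feature: float) -> tf.train.Example:
        # Example holding only the input value, as used at prediction time.
        return tf.train.Example(features=tf.train.Features(
            feature={"x": self.create_feature(feature)}))

    def create_example_with_label(self, feature: float,
                                  label: float) -> tf.train.Example:
        # Example holding both the input and its label, as used for training.
        return tf.train.Example(features=tf.train.Features(feature={
            "x": self.create_feature(feature),
            "y": self.create_feature(label),
        }))


# Serialize one data point and read it back.
example = ExampleProcessor().create_example_with_label(2.0, 10.0)
serialized = example.SerializeToString()
restored = tf.train.Example.FromString(serialized)
```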
Using the ExampleProcessor class, the in-memory list can now be moved to disk.
With the new examples stored in TFRecord files on disk, we can use the Dataset API to prepare the data so it is ready for consumption by the model.
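The two steps above, writing the examples to a TFRecord file and reading them back with the Dataset API, can be sketched as follows. The file name, the feature names "x" and "y", and the batch size are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf


def make_example(x: float, y: float) -> tf.train.Example:
    """Builds an Example with an input feature "x" and a label feature "y"."""
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write the 0..99 data points to a TFRecord file in a temporary directory.
tfrecord_path = os.path.join(tempfile.mkdtemp(), "five_times_table.tfrecord")
with tf.io.TFRecordWriter(tfrecord_path) as writer:
    for i in range(100):
        writer.write(make_example(float(i), float(i * 5)).SerializeToString())

# The feature spec tells the parser the name, shape, and type of each feature.
feature_spec = {
    "x": tf.io.FixedLenFeature([1], tf.float32),
    "y": tf.io.FixedLenFeature([1], tf.float32),
}


def parse(record):
    parsed = tf.io.parse_single_example(record, feature_spec)
    return parsed["x"], parsed["y"]


# Read the records back as (input, label) batches ready for model.fit.
dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse).batch(10)
first_x, first_y = next(iter(dataset))
```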
With the feature spec in place, we can train the model as before.
Note that these steps would be done automatically for us if we had built the model using a TFX pipeline, rather than hand-crafting the model as we did here.
Step 5 - Save the model
Now that we have a model, we need to save it for use with the RunInference transform. RunInference accepts TensorFlow saved model pb files as part of its configuration. The saved model file must be stored in a location that can be accessed by the RunInference transform. In a notebook this can be the local file system; however, to run the pipeline on Dataflow, the file will need to be accessible by all the workers, so here we use a GCP bucket.
Note that the gs:// scheme is directly supported by the tf.keras.models.save_model API.
During development it's useful to be able to inspect the contents of the saved model file. For this, we use the saved_model_cli that comes with TensorFlow. You can run this command from a cell:
Abbreviated output from the saved model file is shown below. Note the signature def 'serving_default', which accepts a tensor of float type. We will change this to accept another type in the next section.
RunInference will pass a serialized tf.example to the model rather than a tensor of float type as seen in the current signature. To accomplish this we have one more step to prepare the model: creation of a specific signature.
Signatures are a powerful feature as they enable us to control how calling programs interact with the model. From the TensorFlow documentation:
"The optional signatures argument controls which methods in obj will be available to programs which consume SavedModels, for example, serving APIs. Python functions may be decorated with @tf.function(input_signature=...) and passed as signatures directly, or lazily with a call to get_concrete_function on the method decorated with @tf.function."
In our case, the following code will create a signature that accepts a tf.string data type with a name of examples. This signature is then saved with the model, which replaces the previous saved model.
If you run the saved_model_cli command again, you will see that the input signature has changed to DT_STRING.
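To illustrate how such a signature can be defined, here is a minimal sketch. To stay short and self-contained it uses a small tf.Module with a fixed weight of 5 as a stand-in for the trained Keras model. The class name, the feature name "x", and the output name "output_0" are assumptions.

```python
import tempfile

import tensorflow as tf


class MultiplyByFive(tf.Module):
    """Stand-in for the trained model: multiplies the feature "x" by 5."""

    def __init__(self):
        super().__init__()
        self.weight = tf.Variable(5.0)

    # The signature accepts a batch of serialized examples as tf.string.
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")])
    def serve_tf_examples(self, serialized):
        features = tf.io.parse_example(
            serialized, {"x": tf.io.FixedLenFeature([1], tf.float32)})
        return {"output_0": features["x"] * self.weight}


# Save the module with the string-accepting function as serving_default.
module = MultiplyByFive()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module, export_dir,
    signatures={"serving_default": module.serve_tf_examples})

# Reload and call the signature with one serialized example (x = 3).
serving_fn = tf.saved_model.load(export_dir).signatures["serving_default"]
example = tf.train.Example(features=tf.train.Features(feature={
    "x": tf.train.Feature(float_list=tf.train.FloatList(value=[3.0]))}))
result = serving_fn(examples=tf.constant([example.SerializeToString()]))
```

With a Keras model, the body of the decorated function would call the model on the parsed features instead of multiplying by a fixed weight.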
Pattern 1: RunInference for predictions
Step 1 - Use RunInference within the pipeline
Now that the model is ready, the RunInference transform can be plugged into an Apache Beam pipeline. The pipeline below uses TFXIO TFExampleRecord, which it converts to a transform via RawRecordBeamSource(). The saved model location and signature are passed to the RunInference API as a SavedModelSpec configuration object.
Note: You can perform two types of inference using RunInference:
- In-process inference from a SavedModel instance. Used when the saved_model_spec field is set in inference_spec_type.
- Remote inference by using a service endpoint. Used when the ai_platform_prediction_model_spec field is set in inference_spec_type.
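A pipeline along these lines might be sketched as below. The module paths follow the tfx_bsl.public packages of the release pinned earlier and may differ in other versions; the function name and its two parameters are hypothetical. The imports are placed inside the function only so the sketch can be loaded in an environment where Apache Beam and tfx_bsl are not installed.

```python
def run_inference_pipeline(tfrecord_pattern: str, saved_model_dir: str) -> None:
    """Reads serialized examples, runs them through the model, prints results."""
    # Local imports: these packages are only needed when the pipeline runs.
    import apache_beam as beam
    from tfx_bsl.public import tfxio
    from tfx_bsl.public.beam import RunInference
    from tfx_bsl.public.proto import model_spec_pb2

    # Setting saved_model_spec selects in-process inference from a SavedModel.
    inference_spec = model_spec_pb2.InferenceSpecType(
        saved_model_spec=model_spec_pb2.SavedModelSpec(
            model_path=saved_model_dir))

    # TFXIO source that yields the raw serialized records from TFRecord files.
    source = tfxio.TFExampleRecord(file_pattern=tfrecord_pattern)

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "ReadExamples" >> source.RawRecordBeamSource()
            | "RunInference" >> RunInference(inference_spec)
            | "PrintResults" >> beam.Map(print)
        )
```

Calling the function with the TFRecord file pattern and the saved model directory runs the pipeline on the direct runner.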
Below is a snippet of the output. The values here are a little difficult to interpret as they are in their raw unprocessed format. In the next section the raw results are post-processed.
Pattern 2: Post-processing RunInference results
The RunInference API returns a PredictionLog object, which contains the serialized input and the output from the call to the model. Having access to both the input and output enables you to create a simple tuple during post-processing for use downstream in the pipeline. Also note that RunInference transparently takes advantage of the model's ability to handle batched input, batching inference calls for performance.
The PredictionProcessor beam.DoFn takes the output of RunInference and produces formatted text with the questions and answers as output. Of course, in a production system the output would more normally be a Tuple[input, output], or simply the output, depending on the use case.
Now the output contains both the original input and the model's output values.
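The core of such a post-processing step can be sketched as a plain function that could be called from a beam.DoFn or beam.Map. The function name is hypothetical, and the output key "output_0" and feature name "x" are assumptions that depend on how the model signature was defined; the input key "examples" matches the signature described earlier.

```python
import tensorflow as tf


def format_prediction(prediction_log, output_key="output_0"):
    """Returns (input_value, output_value) extracted from a PredictionLog."""
    predict_log = prediction_log.predict_log

    # The request holds the serialized tf.Example that was sent to the model.
    serialized_input = predict_log.request.inputs["examples"].string_val[0]
    example = tf.train.Example.FromString(serialized_input)
    input_value = example.features.feature["x"].float_list.value[0]

    # The response holds the model output as a TensorProto.
    output_tensor = predict_log.response.outputs[output_key]
    output_value = float(tf.make_ndarray(output_tensor).ravel()[0])
    return input_value, output_value
```

Returning a tuple rather than formatted text keeps the result easy to consume by later transforms.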
Pattern 3: Attaching a key
One useful pattern is the ability to pass information, often a unique identifier, with the input to the model and have access to this identifier from the output. For example, in an IOT use case you could associate a device id with the input data being passed into the model. Often this type of key is not useful for the model itself and thus should not be passed into the first layer.
RunInference takes care of this for us by accepting a Tuple[key, value] and outputting a Tuple[key, PredictionLog].
Step 1 - Create a source with attached key
Since we need a key with the data that we are sending in for prediction, in this step we create a table in BigQuery, which has two columns: One holds the key and the second holds the test value.
Step 2 - Modify post processor and pipeline
In this step we:
- Modify the pipeline to read from the new BigQuery source table
- Add a map transform, which converts a table row into a Tuple[bytes, Example]
- Modify the post inference processor to output results along with the key
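The map step in the list above can be sketched as a function that turns one table row into a keyed example. The column names "id" and "value", the feature name "x", and the function name are assumptions for illustration; the actual table schema may differ.

```python
import tensorflow as tf


def row_to_keyed_example(row):
    """Converts a row dict into a (key bytes, tf.train.Example) tuple."""
    # The key travels alongside the example and is never fed to the model.
    key = str(row["id"]).encode("utf-8")
    example = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(row["value"])]))}))
    return key, example


key, example = row_to_keyed_example({"id": "device-7", "value": 4})
```

In the pipeline this function would be applied with beam.Map before RunInference, and the post-inference processor would unpack each element into its key and PredictionLog.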
Pattern 4: Inference with multiple models in the same pipeline
In part I of the series, the "join results from multiple models" pattern covered the various branching techniques in Apache Beam that make it possible to run data through multiple models.
Those techniques are applicable to RunInference API, which can easily be used by multiple branches within a pipeline, with the same or different models. This is similar in function to cascade ensembling, although here the data flows through multiple models in a single Apache Beam DAG.
Inference with multiple models in parallel
In this example, the same data is run through two different models: the one that we’ve been using to multiply by 5 and a new model, which will learn to multiply by 10.
Now that we have two models, we apply them to our source data.
Inference with multiple models in sequence
In a sequential pattern, data is sent to one or more models in sequence, with the output from each model chaining to the next model.
Here are the steps:
- Read the data from BigQuery
- Map the data
- RunInference with multiply by 5 model
- Process the results
- RunInference with multiply by 10 model
- Process the results
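The "process the results" step between the two models has to turn the first model's output back into an input for the second. A minimal sketch of that conversion follows; the function name, the output key "output_0", and the feature name "x" are assumptions that depend on the model signatures.

```python
import tensorflow as tf


def prediction_to_example(prediction_log, output_key="output_0"):
    """Wraps the first model's output as a tf.train.Example for the next model."""
    outputs = prediction_log.predict_log.response.outputs
    value = float(tf.make_ndarray(outputs[output_key]).ravel()[0])
    return tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[value]))}))
```

Applied with beam.Map between the two RunInference steps, this feeds the multiply-by-5 result into the multiply-by-10 model.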
Running the pipeline on Dataflow
Until now the pipeline has been run locally, using the direct runner, which is implicitly used when running a pipeline with the default configuration. The same examples can be run using the production Dataflow runner by passing in configuration parameters including --runner. Details and an example can be found here.
Here is an example of the multimodel pipeline graph running on the Dataflow service:
With the Dataflow runner you also get access to pipeline monitoring as well as metrics that have been output from the RunInference transform. The following table shows some of these metrics from a much larger list available from the library.
Conclusion
In this blog, part II of our series, we explored the use of the tfx-bsl RunInference within some common scenarios, from standard inference, to post processing and the use of RunInference API in multiple locations in the pipeline.
To learn more, review the Dataflow and TFX documentation. You can also try out TFX with Google Cloud AI Platform Pipelines.
Acknowledgements
None of this would be possible without the hard work of many folks across the Dataflow, TFX, and TF teams. From the TFX and TF teams we would especially like to thank Konstantinos Katsiapis, Zohar Yahav, Vilobh Meshram, Jiayi Zhao, Zhitao Li, and Robert Crowe. From the Dataflow team I would like to thank Ahmet Altay for his support and input throughout.