Run in Google Colab | View source on GitHub
This example demonstrates the use of the RunInference transform with the pre-trained ST5 text encoder model from TensorFlow Hub. The transform runs locally using the Interactive Runner.
Download and install the dependencies
pip install apache_beam[gcp,interactive]==2.41.0
pip install tensorflow==2.10.0
pip install tensorflow_text==2.10.0
pip install keras==2.10.0
pip install tfx_bsl==1.10.0
pip install pillow==8.4.0
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text
from tensorflow import keras
import apache_beam as beam
import apache_beam.runners.interactive.interactive_beam as ib
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
from tfx_bsl.public.beam.run_inference import CreateModelHandler
from tfx_bsl.public.proto import model_spec_pb2
Authenticate with Google Cloud
This notebook saves the model to Google Cloud Storage. To use your Google Cloud account, authenticate from within the notebook.
from google.colab import auth
auth.authenticate_user()
Create a Keras model from a TensorFlow Hub module
Replace GCS_BUCKET with the name of your bucket. Your model will be saved in MODEL_EXPORT_DIR.
GCS_BUCKET = '<GCS Bucket>'
MODEL_EXPORT_DIR = f'gs://{GCS_BUCKET}/st5-base/1'
inp = tf.keras.layers.Input(shape=[], dtype=tf.string, name='input')
hub_url = "https://tfhub.dev/google/sentence-t5/st5-base/1"
imported = hub.KerasLayer(hub_url)
outp = imported(inp)
model = tf.keras.Model(inp, outp)
# The ST5 model returns a 768-dimensional vector for an English text input.
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None,)] 0
keras_layer (KerasLayer) [(None, 768)] 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
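For an English input sentence, the ST5 encoder returns a 768-dimensional embedding, and such embeddings are usually compared with cosine similarity. A minimal plain-Python sketch of that comparison (the short vectors below are made-up stand-ins for real 768-dimensional embeddings):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Toy 4-dimensional stand-ins for real 768-dimensional ST5 embeddings.
v1 = [0.1, 0.3, -0.2, 0.8]
v2 = [0.1, 0.2, -0.1, 0.9]
print(round(cosine_similarity(v1, v2), 3))
```

Values close to 1.0 indicate semantically similar sentences; near-orthogonal embeddings score close to 0.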
Save the model
Save the model with a tf.function serving signature so that RunInference can feed it serialized tf.train.Example records.
RAW_DATA_PREDICT_SPEC = {
    'input': tf.io.FixedLenFeature([], tf.string),
}

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def call(serialized_examples):
    features = tf.io.parse_example(serialized_examples, RAW_DATA_PREDICT_SPEC)
    return model(features)
tf.saved_model.save(model, MODEL_EXPORT_DIR, signatures={'serving_default': call})
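Before pointing RunInference at the export, it can help to confirm that a model saved this way really accepts serialized tf.train.Example protos through its serving signature. The sketch below mirrors the same pattern with a toy tf.Module saved to a local temp directory; it is a stand-in, not the ST5 export above:

```python
import tempfile
import tensorflow as tf

class ToyEncoder(tf.Module):
    """Stand-in encoder: parses serialized Examples, emits a fixed-size vector."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def __call__(self, serialized_examples):
        features = tf.io.parse_example(
            serialized_examples, {'input': tf.io.FixedLenFeature([], tf.string)})
        # Fake "embedding": hash each string into one of 8 buckets, one-hot encode.
        buckets = tf.strings.to_hash_bucket_fast(features['input'], 8)
        return tf.one_hot(buckets, depth=8)

export_dir = tempfile.mkdtemp()
module = ToyEncoder()
tf.saved_model.save(
    module, export_dir,
    signatures={'serving_default': module.__call__.get_concrete_function()})

# Reload and push one serialized Example through the serving signature.
loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures['serving_default']
example = tf.train.Example(features=tf.train.Features(
    feature={'input': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'hello world']))}))
result = serving_fn(tf.constant([example.SerializeToString()]))
shapes = {name: tuple(t.shape) for name, t in result.items()}
print(shapes)  # one output tensor of shape (1, 8)
```

The ST5 export follows the same contract, except its output is a (batch, 768) embedding.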
Create and test the RunInference pipeline locally
Use tfx_bsl's CreateModelHandler function to build a RunInference model handler for TensorFlow SavedModels.
# Creates a tf.train.Example to feed to the model handler.
class ExampleProcessor:
    def create_example(self, feature: str):
        return tf.train.Example(
            features=tf.train.Features(
                feature={'input': self.create_feature(feature)}))

    def create_feature(self, element: str):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[element.encode()]))
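To see that an Example created this way matches RAW_DATA_PREDICT_SPEC, a quick round-trip can serialize a question and parse it back. The processor class is restated so the snippet runs standalone, and the spec dict mirrors the one used in the serving signature:

```python
import tensorflow as tf

class ExampleProcessor:
    def create_example(self, feature: str):
        return tf.train.Example(
            features=tf.train.Features(
                feature={'input': self.create_feature(feature)}))

    def create_feature(self, element: str):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[element.encode()]))

question = 'what is the official slogan for the 2018 winter olympics?'
serialized = ExampleProcessor().create_example(question).SerializeToString()

# Parsing with the same spec used by the serving signature recovers the text.
spec = {'input': tf.io.FixedLenFeature([], tf.string)}
parsed = tf.io.parse_single_example(serialized, spec)
text = parsed['input'].numpy().decode()
print(text)
```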
saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=MODEL_EXPORT_DIR)
inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)
model_handler = CreateModelHandler(inference_spec_type)
questions = [
'what is the official slogan for the 2018 winter olympics?',
]
pipeline = beam.Pipeline(InteractiveRunner())
inference = (
    pipeline
    | 'CreateSentences' >> beam.Create(questions)
    | 'Convert input to Tensor' >> beam.Map(lambda x: ExampleProcessor().create_example(x))
    | 'RunInference with T5' >> RunInference(model_handler))
ib.show(inference)
WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
2022-12-06 09:30:47.084208: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
2022-12-06 09:30:54.471173: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x4a3e0a00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-12-06 09:30:54.471244: I tensorflow/compiler/xla/service/service.cc:181] StreamExecutor device (0): Host, Default Version
2022-12-06 09:30:54.537285: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-12-06 09:31:00.441479: I tensorflow/compiler/jit/xla_compilation_cache.cc:476] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.