保存 TensorFlow 模型以用于 AI Explanations

本页面介绍了如何保存 TensorFlow 模型(不论您使用的是 TensorFlow 2.x 还是 TensorFlow 1.15),以便将其与 AI Explanations 配合使用。

TensorFlow 2

如果您使用的是 TensorFlow 2.x,请使用 tf.saved_model.save 保存您的模型。

优化已保存的 TensorFlow 模型的常见选择是为用户提供签名。您可以在保存模型时指定输入签名。 如果您只有一个输入签名,AI Explanations 会自动为您的说明请求使用的默认服务函数(遵循默认 tf.saved_model.save 行为)。详细了解如何在 TensorFlow 中指定服务签名

多个输入签名

如果您的模型具有多个输入签名,AI Explanations 无法在确定从模型中检索预测时会自动使用哪个签名定义。因此,您必须指定希望 AI Explanations 使用的签名定义。保存模型时,请使用唯一的键 xai-model 指定服务默认函数的签名:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_model': my_signature_default_fn # Required for AI Explanations
    })

在这种情况下,AI Explanations 将使用您在 xai_model 键中提供的模型函数签名与模型进行交互,并生成说明。请为该键使用确切的字符串 xai_model。 如需更多背景信息,请参阅本“签名防御”概览

预处理函数

如果使用预处理函数,则必须在保存模型时指定预处理函数和模型函数的签名。使用 xai_preprocess 键指定预处理函数:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_preprocess': preprocess_fn, # Required for AI Explanations
    'xai_model': model_fn # Required for AI Explanations
    })

在这种情况下,AI Explanations 会使用您的预处理函数和模型函数处理您的解释请求。确保预处理函数的输出与模型函数所需的输入相符。

试用完整的 TensorFlow 2 示例笔记本:

TensorFlow 1.15

如果您使用的是 TensorFlow 1.15,请勿使用 tf.saved_model.save。使用 TensorFlow 1 时,AI Explanations 不支持此函数。请改为将 tf.estimator.export_savedmodel 与相应的 tf.estimator.export.ServingInputReceiver 结合使用

使用 Keras 构建的模型

如果您在 Keras 中构建和训练模型,则必须将模型转换为 TensorFlow Estimator,然后将其导出到 SavedModel。本部分重点介绍如何保存模型。如需了解完整的有效示例,请参阅示例笔记本。

在构建、编译、训练和评估 Keras 模型后,您必须执行以下操作:

  • 使用 tf.keras.estimator.model_to_estimator 将 Keras 模型转换为 TensorFlow Estimator
  • 使用 tf.estimator.export.build_raw_serving_input_receiver_fn 提供传送输入函数
  • 使用 tf.estimator.export_saved_model 将模型导出为 SavedModel。
# Build, compile, train, and evaluate your Keras model
model = tf.keras.Sequential(...)
model.compile(...)
model.fit(...)
model.predict(...)

## Convert your Keras model to an Estimator
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir='export')

## Define a serving input function appropriate for your model
def serving_input_receiver_fn():
  ...
  return tf.estimator.export.ServingInputReceiver(...)

## Export the SavedModel to Cloud Storage using your serving input function
export_path = keras_estimator.export_saved_model(
    'gs://' + 'YOUR_BUCKET_NAME', serving_input_receiver_fn).decode('utf-8')

print("Model exported to: ", export_path)

后续步骤