本页面介绍了如何保存 TensorFlow 模型(不论您使用的是 TensorFlow 2.x 还是 TensorFlow 1.15),以便将其与 AI Explanations 配合使用。
TensorFlow 2
如果您使用的是 TensorFlow 2.x,请使用 tf.saved_model.save
保存您的模型。
优化已保存的 TensorFlow 模型的常见选择是为用户提供签名。您可以在保存模型时指定输入签名。
如果您只有一个输入签名,AI Explanations 会自动为您的说明请求使用的默认服务函数(遵循默认 tf.saved_model.save
行为)。详细了解如何在 TensorFlow 中指定服务签名。
多个输入签名
如果您的模型具有多个输入签名,AI Explanations 无法在确定从模型中检索预测时会自动使用哪个签名定义。因此,您必须指定希望 AI Explanations 使用的签名定义。保存模型时,请使用唯一的键 xai-model
指定服务默认函数的签名:
tf.saved_model.save(m, model_dir, signatures={
'serving_default': serving_fn,
'xai_model': my_signature_default_fn # Required for AI Explanations
})
在这种情况下,AI Explanations 将使用您在 xai_model
键中提供的模型函数签名与模型进行交互,并生成说明。请为该键使用确切的字符串 xai_model
。 如需更多背景信息,请参阅本“签名防御”概览。
预处理函数
如果使用预处理函数,则必须在保存模型时指定预处理函数和模型函数的签名。使用 xai_preprocess
键指定预处理函数:
tf.saved_model.save(m, model_dir, signatures={
'serving_default': serving_fn,
'xai_preprocess': preprocess_fn, # Required for AI Explanations
'xai_model': model_fn # Required for AI Explanations
})
在这种情况下,AI Explanations 会使用您的预处理函数和模型函数处理您的解释请求。确保预处理函数的输出与模型函数所需的输入相符。
试用完整的 TensorFlow 2 示例笔记本:
TensorFlow 1.15
如果您使用的是 TensorFlow 1.15,请勿使用 tf.saved_model.save
。使用 TensorFlow 1 时,AI Explanations 不支持此函数。请改为将 tf.estimator.export_savedmodel
与相应的 tf.estimator.export.ServingInputReceiver
结合使用
使用 Keras 构建的模型
如果您在 Keras 中构建和训练模型,则必须将模型转换为 TensorFlow Estimator,然后将其导出到 SavedModel。本部分重点介绍如何保存模型。如需了解完整的有效示例,请参阅示例笔记本。
在构建、编译、训练和评估 Keras 模型后,您必须执行以下操作:
- 使用
tf.keras.estimator.model_to_estimator
将 Keras 模型转换为 TensorFlow Estimator - 使用
tf.estimator.export.build_raw_serving_input_receiver_fn
提供传送输入函数 - 使用
tf.estimator.export_saved_model
将模型导出为 SavedModel。
# Build, compile, train, and evaluate your Keras model
model = tf.keras.Sequential(...)
model.compile(...)
model.fit(...)
model.predict(...)
## Convert your Keras model to an Estimator
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir='export')
## Define a serving input function appropriate for your model
def serving_input_receiver_fn():
...
return tf.estimator.export.ServingInputReceiver(...)
## Export the SavedModel to Cloud Storage using your serving input function
export_path = keras_estimator.export_saved_model(
'gs://' + 'YOUR_BUCKET_NAME', serving_input_receiver_fn).decode('utf-8')
print("Model exported to: ", export_path)
后续步骤
- 了解如何使用 Explainable AI SDK。
- 如需直观显示说明,您可以使用 What-If 工具。如需了解详情,请参阅示例笔记本。