当您不需要立即获取预测结果,或者当您有大量实例需要预测时,您可以使用批量预测服务。 本页面介绍如何开始 AI Platform Prediction 批量预测作业。 AI Platform Prediction 仅支持从 TensorFlow 模型进行批量预测。
准备工作
要请求预测,您首先必须完成以下操作:
创建模型资源和版本资源或将 TensorFlow SavedModel 放置在一个您的项目可以访问的 Cloud Storage 位置中。
- 如果您选择使用版本资源进行批量预测,则必须创建使用
mls1-c1-m2
机器类型的版本。
- 如果您选择使用版本资源进行批量预测,则必须创建使用
为以下文件设置项目有权访问的 Cloud Storage 位置:
输入数据文件。可以设置多个位置,但您的项目必须具有每个位置的读取权限。
输出文件。您只能指定一个输出路径,并且您的项目必须具有向该路径写入数据的权限。
验证输入文件是否具有进行批量预测的正确格式。
配置批量预测作业
要开始批量预测作业,您需要收集一些配置数据。这与您直接调用 API 时使用的 PredictionInput 对象中包含的数据相同:
- 数据格式
用于输入文件的输入格式类型。给定作业的所有输入文件必须使用相同的数据格式。请设置为以下值之一:
- JSON
您的输入文件为纯文本,每行都有一个实例。这是预测概念页面上介绍的格式。
- TF_RECORD
您的输入文件使用 TensorFlow TFRecord 格式。
- TF_RECORD_GZIP
您的输入文件为经过 GZIP 压缩的 TFRecord 文件。
- 输入路径
输入数据文件的 URI,必须位于 Cloud Storage 位置中。您可以将其指定为:
特定文件的路径:
'gs://path/to/my/input/file.json'
。指向某个目录的路径,带有一个星号通配符,用于指示该目录中的所有文件:
'gs://path/to/my/input/*'
.指向部分文件名的路径,末尾带有一个星号通配符,用于指示以提供的序列开头的所有文件:
'gs://path/to/my/input/file*'
。
您可以组合多个 URI。在 Python 中,您可以列出这些 URI。如果您使用 Google Cloud CLI 或直接调用 API,则可以列出多个 URI(以逗号分隔,但相互之间没有空格)。以下是
--input-paths
标志的正确格式:--input-paths gs://a/directory/of/files/*,gs://a/single/specific/file.json,gs://a/file/template/data*
- 输出路径
指向您希望预测服务将结果保存到的 Cloud Storage 位置的路径。您的项目必须具有向此位置写入数据的权限。
- 模型名称和版本名称
模型的名称,以及您希望从中获取预测结果的版本(可选)。如果未指定版本,则使用模型的默认版本。对于批量预测,版本必须使用
mls1-c1-m2
机器类型。如果您提供了模型 URI(参阅以下部分),请忽略这些字段。
- 模型 URI
您可以从尚未部署到 AI Platform Prediction 的模型获取预测结果,只需指定要使用的 SavedModel 的 URI 即可。SavedModel 必须存储在 Cloud Storage 中。
总之,您有三种方案可供选择,以指定用于批量预测的模型。这三种方案分别如下:
使用模型名称本身,并使用模型的默认版本。
使用模型和版本名称,以使用特定模型版本。
使用模型 URI,以使用存储在 Cloud Storage 中、但未部署到 AI Platform Prediction 的 SavedModel。
- 区域
要在其中运行作业的 Google Compute Engine 区域。 为获得最佳性能,您应该在同一区域运行预测作业并存储输入和输出数据,尤其是对于非常大的数据集而言。AI Platform Prediction 批量预测可在以下区域使用:
- us-central1 - us-east1 - europe-west1 - asia-east1
如需全面了解提供 AI Platform Prediction 服务(包括模型训练和在线预测)的区域,请参阅区域指南。
- 作业名称
作业的名称,必须符合以下条件:
- 仅包含混合大小写(区分大小写)字母、数字和下划线。
- 以字母开头。
- 最多包含 128 个字符。
- 在项目曾使用过的所有训练和批量预测作业名称中唯一。包括您在项目中创建的所有作业,无论其成功与否或状态如何。
- 批次大小(可选)
每个批次的记录数量。在调用您的模型之前,该服务会在内存中缓冲
batch_size
条记录。如果未指定,则默认值为 64。- 标签(可选)
您可以为作业添加标签,以便在查看或监控资源时整理作业,并对这些数据进行分类。例如,您可以按团队分类作业(通过添加
engineering
或research
标签),也可以按开发阶段(prod
或test
)对作业进行分类。如需向预测作业添加标签,请提供KEY=VALUE
对的列表。- 工作器数量上限(可选)
作业的处理集群中使用的预测节点的最大数量。这是为批量预测的自动扩缩功能设置上限的方法。如果未指定值,则默认为 10。无论指定的值为何,扩缩都受预测节点配额的限制。
- 运行时版本(可选)
作业使用的 AI Platform Prediction 版本。使用此选项时,您可以指定要用于未部署到 AI Platform Prediction 的模型的运行时版本。对于已部署的模型版本,您应始终省略此值,这表示服务使用的版本与部署模型版本时指定的版本相同。
- 签名名称(可选)
如果已保存的模型具有多个签名,请使用此选项指定自定义 TensorFlow 签名名称,这样您可以选择 TensorFlow SavedModel 中定义的备用输入/输出映射。 如需查看如何使用签名的指南以及如何指定自定义模型的输出的指南,请参阅 SavedModel 的 TensorFlow 文档。签名名称默认为 DEFAULT_SERVING_SIGNATURE_DEF_KEY,其值为
serving_default
。
以下示例定义用于保存配置数据的变量。
gcloud
使用 gcloud 命令行工具开始作业时,无需创建变量。但是,在本例中,创建变量可以使作业提交命令更容易输入和读取。
DATA_FORMAT="text" # JSON data format
INPUT_PATHS='gs://path/to/your/input/data/*'
OUTPUT_PATH='gs://your/desired/output/location'
MODEL_NAME='census'
VERSION_NAME='v1'
REGION='us-east1'
now=$(date +"%Y%m%d_%H%M%S")
JOB_NAME="census_batch_predict_$now"
MAX_WORKER_COUNT="20"
BATCH_SIZE="32"
LABELS="team=engineering,phase=test,owner=sara"
Python
当使用 Python 版 Google API 客户端库时,您可以使用 Python 字典表示作业和 PredictionInput 资源。
按照 AI Platform Prediction REST API 使用的语法为项目名称和模型或版本名称设置格式:
- project_name -> 'projects/project_name'
- model_name -> 'projects/project_name/models/model_name'
- version_name -> 'projects/project_name/models/model_name/versions/version_name'
为作业资源创建一个字典,并在其中填充以下两项内容:
一个名为
'jobId'
的键,其值为您要使用的作业名称。一个名为
'predictionInput'
的键,包含另一个存储 PredictionInput 所有所需成员以及您要使用的任何可选成员的字典对象。
以下示例展示了一个函数,该函数将配置信息作为输入变量并返回预测请求正文。 除了基本功能之外,该示例还会根据项目名称、模型名称和当前时间生成唯一的作业标识符。
import time import re def make_batch_job_body(project_name, input_paths, output_path, model_name, region, data_format='JSON', version_name=None, max_worker_count=None, runtime_version=None): project_id = 'projects/{}'.format(project_name) model_id = '{}/models/{}'.format(project_id, model_name) if version_name: version_id = '{}/versions/{}'.format(model_id, version_name) # Make a jobName of the format "model_name_batch_predict_YYYYMMDD_HHMMSS" timestamp = time.strftime('%Y%m%d_%H%M%S', time.gmtime()) # Make sure the project name is formatted correctly to work as the basis # of a valid job name. clean_project_name = re.sub(r'\W+', '_', project_name) job_id = '{}_{}_{}'.format(clean_project_name, model_name, timestamp) # Start building the request dictionary with required information. body = {'jobId': job_id, 'predictionInput': { 'dataFormat': data_format, 'inputPaths': input_paths, 'outputPath': output_path, 'region': region}} # Use the version if present, the model (its default version) if not. if version_name: body['predictionInput']['versionName'] = version_id else: body['predictionInput']['modelName'] = model_id # Only include a maximum number of workers or a runtime version if specified. # Otherwise let the service use its defaults. if max_worker_count: body['predictionInput']['maxWorkerCount'] = max_worker_count if runtime_version: body['predictionInput']['runtimeVersion'] = runtime_version return body
提交批量预测作业
提交作业是简单调用 projects.jobs.create 或其命令行工具对等项 gcloud ai-platform jobs submit prediction。
gcloud
以下示例使用上一节中定义的变量来开始批量预测。
gcloud ai-platform jobs submit prediction $JOB_NAME \
--model $MODEL_NAME \
--input-paths $INPUT_PATHS \
--output-path $OUTPUT_PATH \
--region $REGION \
--data-format $DATA_FORMAT
Python
使用 Python 版 Google API 客户端库开始批量预测作业遵循与其他客户端 SDK 过程类似的模式:
准备要用于调用的请求主体(如上一部分中所示)。
通过调用 ml.projects.jobs.create 来创建请求。
对请求调用 execute 以获得响应,确保检查 HTTP 是否存在错误。
将响应用作字典以从作业资源获取值。
您可以使用 Python 版 Google API 客户端库来调用 AI Platform Training 和 Prediction API,而无需手动构建 HTTP 请求。在运行以下代码示例之前,必须先设置身份验证。
import googleapiclient.discovery as discovery
project_id = 'projects/{}'.format(project_name)
ml = discovery.build('ml', 'v1')
request = ml.projects().jobs().create(parent=project_id,
body=batch_predict_body)
try:
response = request.execute()
print('Job requested.')
# The state returned will almost always be QUEUED.
print('state : {}'.format(response['state']))
except errors.HttpError as err:
# Something went wrong, print out some information.
print('There was an error getting the prediction results.' +
'Check the details:')
print(err._get_reason())
监控批量预测作业
批量预测作业可能需要很长时间才能完成。您可以监控 使用 Google Cloud 控制台监控作业进度:
转到 Google Cloud 控制台中的 AI Platform Prediction 作业页面:
在作业 ID 列表中点击作业名称。此时将打开作业详细信息页面。
作业名称和当前状态将显示在页面顶部。
若要了解更多详情,可点击查看日志查看 Cloud Logging 中的作业条目。
跟踪批量预测作业的进度还有其他方法。这些方法与监控训练作业遵循相同的模式。您可以在页面上找到有关如何监控训练作业的详细信息。 您可能需要对说明稍作调整以使用预测作业,但机制是相同的。
获取预测结果
该服务会将预测结果写入您指定的 Cloud Storage 位置。 输出文件有两种类型,其中可能包含所需的结果:
名为
prediction.errors_stats-NNNNN-of-NNNNN
的文件包含关于作业运行期间遇到的任何问题的信息。名为
prediction.results-NNNNN-of-NNNNN
的 JSON 行文件包含由模型输出定义的预测结果本身。
文件名包括索引号(上面显示的每个“N”代表一位数),用于指示总共应找到的文件数。例如,具有六个结果文件的作业包括从 prediction.results-00000-of-00006
到 prediction.results-00005-of-00006
的文件。
每个预测文件的每一行都是一个 JSON 对象,表示单个预测结果。您可以使用所选的文本编辑器打开预测文件。如需在命令行中快速查看,可以使用 gcloud storage cat
:
gcloud storage cat $OUTPUT_PATH/prediction.results-NNNNN-of-NNNNN|less
请记住,预测结果通常不会以与输入实例相同的顺序输出,即使您仅使用单个输入文件也是如此。您可以通过匹配实例键来查找实例的预测结果。
后续步骤
- 使用在线预测。
- 获取有关预测过程的详细信息。
- 排查您请求在线预测时出现的问题。
- 了解如何使用标签来整理作业。