Vertex AI 提供在线预测和批量预测这两个选项,用来通过经过训练的预测模型来预测未来的值。
在线预测是同步请求。如果您要发出请求以响应应用输入,或者在其他需要及时推理的情况下,可以使用在线预测。
批量预测请求是异步请求。如果您不需要立即响应并且希望使用单个请求处理累积的数据,就可使用批量预测。
本页面介绍如何使用在线预测来预测未来的值。如需了解如何使用批量预测来预测值,请参阅获取预测模型的批量预测结果。
您必须先将模型部署到端点,然后才能使用模型进行预测。端点是一组物理资源。
您可以请求说明而不是预测。说明的局部特征重要性值可以表示每个特征对预测结果的影响程度。如需查看概念性概览,请参阅用于预测的特征归因。
如需了解在线预测的价格,请参阅表格工作流的价格。
准备工作
在发出在线预测请求之前,您必须先训练模型。
创建或选择端点
使用函数 aiplatform.Endpoint.create()
创建端点。如果您已有端点,请使用 aiplatform.Endpoint()
函数进行选择。
以下代码提供了一个示例:
# Import required modules
from google.cloud import aiplatform
from google.cloud.aiplatform import models
PROJECT_ID = "PROJECT_ID"
REGION = "REGION"
# Initialize the Vertex SDK for Python for your project.
aiplatform.init(project=PROJECT_ID, location=REGION)
endpoint = aiplatform.Endpoint.create(display_name='ENDPOINT_NAME')
替换以下内容:
- PROJECT_ID:您的项目 ID。
- REGION:您在其中使用 Vertex AI 的区域。
- ENDPOINT_NAME:端点的显示名称。
选择经过训练的模型
使用 aiplatform.Model()
函数选择经过训练的模型:
# Create reference to the model trained ahead of time.
model_obj = models.Model("TRAINED_MODEL_PATH")
替换以下内容:
- TRAINED_MODEL_PATH:例如
projects/PROJECT_ID/locations/REGION/models/[TRAINED_MODEL_ID]
将模型部署到端点
使用 deploy()
函数将模型部署到端点。以下代码提供了一个示例:
deployed_model = endpoint.deploy(
model_obj,
machine_type='MACHINE_TYPE',
traffic_percentage=100,
min_replica_count='MIN_REPLICA_COUNT',
max_replica_count='MAX_REPLICA_COUNT',
sync=True,
deployed_model_display_name='DEPLOYED_MODEL_NAME',
)
替换以下内容:
- MACHINE_TYPE:例如
n1-standard-8
。详细了解机器类型。 - MIN_REPLICA_COUNT:此部署的最小节点数。节点数可根据预测负载的需要而增加或减少,直至达到节点数上限并且绝不会少于此节点数。此值必须大于或等于 1。如果未设置
min_replica_count
变量,则该值默认为1
。 - MAX_REPLICA_COUNT:此部署的节点数上限。节点数可根据预测负载的需要而增加或减少,直至达到此节点数并且绝不会少于节点数下限。如果您未设置
max_replica_count
变量,则节点数上限将设置为min_replica_count
的值。 - DEPLOYED_MODEL_NAME:
DeployedModel
的名称。您还可以将Model
的显示名称用于DeployedModel
。
模型部署可能需要大约十分钟时间。
获取在线预测结果
如需获取预测结果,请使用 predict()
函数并提供一个或多个输入实例。以下代码展示了一个示例:
predictions = endpoint.predict(instances=[{...}, {...}])
每个输入实例都是一个 Python 字典,它具有训练模型所用的架构。它必须包含与时间列对应的预测时可用键值对,以及包含目标预测列的历史值的预测时不可用键值对。Vertex AI 要求每个输入实例属于单个时序。实例中键值对的顺序并不重要。
输入实例存在以下限制条件:
- 预测时可用键值对必须都具有相同的数据点数量。
- 预测不可用键值对必须都具有相同的数据点数量。
- 预测时可用键值对的数据点数量必须至少与预测时不可用键值对的数据点数量相同。
如需详细了解预测中使用的列类型,请参阅特征类型和预测时的可用性。
以下代码演示了一组(两个)输入实例。Category
列包含属性数据。Timestamp
列包含预测时可用的数据。三个数据点是上下文数据,两个数据点是范围数据。Sales
列包含预测时不可用的数据。三个数据点都是上下文数据。如需了解如何在预测中使用上下文和范围,请参阅预测范围、上下文窗口和预测窗口。
instances=[
{
# Attribute
"Category": "Electronics",
# Available at forecast: three days of context, two days of horizon
"Timestamp": ['2023-08-03', '2023-08-04', '2023-08-05', '2023-08-06', '2023-08-07'],
# Unavailable at forecast: three days of context
"Sales": [490.50, 325.25, 647.00],
},
{
# Attribute
"Category": "Food",
# Available at forecast: three days of context, two days of horizon
"Timestamp": ['2023-08-03', '2023-08-04', '2023-08-05', '2023-08-06', '2023-08-07'],
# Unavailable at forecast: three days of context
"Sales": [190.50, 395.25, 47.00],
}
])
对于每个实例,Vertex AI 会返回 Sales
的两个预测结果作为响应,分别对应两个范围时间戳(“2023-08-06”和“2023-08-07”)。
为获得最佳性能,每个输入实例中上下文数据点的数量和范围数据点的数量必须与模型训练时使用的上下文和范围长度一致。如果不一致,Vertex AI 会填充或截断实例,使之与模型的大小相匹配。
如果输入实例中的上下文数据点数量小于或大于用于模型训练的上下文数据点数量,请确保此数据点数量在所有预测时可用键值对和所有预测时不可用键值对之间保持一致。
例如,假设一个模型用四天的上下文数据和两天的范围数据进行训练。只需使用三天的上下文数据即可发出预测请求。在这种情况下,预测时不可用键值对包含三个值。预测时可用键值对必须包含五个值。
在线预测的输出
Vertex AI 在 value
字段中提供在线预测输出:
{
'value': [...]
}
预测响应的长度取决于模型训练中使用的范围以及输入实例的范围。预测响应的长度是这两个值中的最小值。
请参考以下示例:
- 用于训练模型的
context
=15
,horizon
=50
。输入实例的context
=15
,horizon
=20
。预测响应的长度为20
。 - 用于训练模型的
context
=15
,horizon
=50
。输入实例的context
=15
,horizon
=100
。预测响应的长度为50
。
TFT 模型的在线预测输出
对于使用 Temporal Fusion Transformer (TFT) 训练的模型,除了 value
字段中的预测之外,Vertex AI 还提供 TFT 可解释性 tft_feature_importance
:
{
"tft_feature_importance": {
"attribute_weights": [...],
"attribute_columns": [...],
"context_columns": [...],
"context_weights": [...],
"horizon_weights": [...],
"horizon_columns": [...]
},
"value": [...]
}
attribute_columns
:预测特征(时间不变)。attribute_weights
:与每个attribute_columns
关联的权重。context_columns
:预测特征,其上下文窗口值用作 TFT 长/短期记忆(LSTM) 编码器的输入。context_weights
:与预测实例的每个context_columns
关联的特征重要性权重。horizon_columns
:预测特征,其预测范围值用作 TFT 长短期记忆 (LSTM) 解码器的输入。horizon_weights
:与预测实例的每个horizon_columns
关联的特征重要性权重。
针对分位数损失进行了优化的模型的在线预测输出
对于针对分位数损失进行了优化的模型,Vertex AI 提供以下在线预测输出:
{
"value": [...],
"quantile_values": [...],
"quantile_predictions": [...]
}
-
value
:如果分位数集包含中位数,则value
是中位数的预测值。否则,value
是集合中最小分位数的预测值。例如,如果分位数集为[0.1, 0.5, 0.9]
,则value
是分位数0.5
的预测值。如果分位数集为[0.1, 0.9]
,则value
是分位数0.1
的预测值。 quantile_values
:分位数的值,在模型训练期间设置此值。-
quantile_predictions
:与 quantile_values 关联的预测值。
例如,假设一个目标列为销售价值的模型。分位数值定义为 [0.1, 0.5, 0.9]
。Vertex AI 会返回以下分位数预测值:[4484, 5615, 6853]
。在这里,分位数集包括中位数,因此 value
是分位数 0.5
的预测值 (5615
)。可对分位数预测值进行如下解释:
- P(销售值 < 4484)= 10%
- P(销售值 < 5615)= 50%
- P(销售值 < 6853)= 90%
使用概率推理的模型的在线预测输出
如果模型使用概率推理,则 value
字段包含优化目标的最小化器。例如,如果优化目标为 minimize-rmse
,则 value
字段包含平均值。如果优化目标为 minimize-mae
,则 value
字段包含中位数值。
如果模型将概率推理与分位数结合使用,则除了优化目标的最小化器之外,Vertex AI 还提供分位数值和预测。分位数值是在模型训练期间设置的。分位数预测是与分位数值关联的预测值。
获取在线说明
如需获取说明,请使用 explain()
函数并提供一个或多个输入实例。以下代码展示了一个示例:
explanations = endpoint.explain(instances=[{...}, {...}])
对于在线预测和在线说明,输入实例的格式是相同的。如需了解详情,请参阅获取在线预测结果。
如需从概念上大致了解特征归因,请参阅用于预测的特征归因。
在线说明的输出
以下代码演示了如何输出说明结果:
# Import required modules
import json
from google.protobuf import json_format
def explanation_to_dict(explanation):
"""Converts the explanation proto to a human-friendly json."""
return json.loads(json_format.MessageToJson(explanation._pb))
for response in explanations.explanations:
print(explanation_to_dict(response))
说明结果的格式如下:
{
"attributions": [
{
"baselineOutputValue": 1.4194682836532593,
"instanceOutputValue": 2.152980089187622,
"featureAttributions": {
...
"store_id": [
0.007947325706481934
],
...
"dept_id": [
5.960464477539062e-08
],
"item_id": [
0.1100526452064514
],
"date": [
0.8525647521018982
],
...
"sales": [
0.0
]
},
"outputIndex": [
2
],
"approximationError": 0.01433318599207033,
"outputName": "value"
},
...
]
}
attributions
元素的数量取决于模型训练中使用的范围以及输入实例的范围。元素数量是这两个值中的最小值。
attributions
元素中的 featureAttributions
字段包含一个值,对应输入数据集中的每一列。Vertex AI 会为所有类型的特征生成说明:属性、预测时可用和预测时不可用。如需详细了解 attributions
元素的字段,请参阅归因。
删除端点
使用 undeploy_all()
和 delete()
函数删除端点。以下代码展示了一个示例:
endpoint.undeploy_all()
endpoint.delete()
后续步骤
- 了解在线预测的价格。