本页面介绍如何使用 Google Cloud 控制台或 Vertex AI API 从表格分类或回归模型获取在线(实时)预测结果和说明。
在线预测是同步请求,与之相对的是批量预测,即异步请求。如果您要发出请求以响应应用输入,或者在其他需要及时推断的情况下,可以使用在线预测。
您必须先将模型部署到端点,然后才能使用该模型执行在线预测。部署模型会将物理资源与模型相关联,以便以低延迟方式执行在线预测。
涵盖的主题如下:
准备工作
您必须先训练模型,然后才能获取在线预测结果。
将模型部署到端点
您可以将多个模型部署到一个端点,也可以将一个模型部署到多个端点。如需详细了解部署模型的方法和使用场景,请参阅部署模型简介。
请使用以下方法之一部署模型:
Google Cloud 控制台
在 Google Cloud 控制台的 Vertex AI 部分中,转到模型页面。
点击要部署的模型名称以打开其详情页面。
选择部署和测试标签页。
如果模型已部署到任何端点,部署模型 (Deploy your model) 部分中会列出这些端点。
点击部署到端点。
在定义端点页面中,按如下方式配置:
您可以选择将模型部署到新端点或现有端点。
- 如需将模型部署到新的端点,请选择 创建新端点并为新端点提供名称。
- 如需将模型部署到现有端点,请选择 添加到现有端点,然后从下拉列表中选择端点。
- 您可以将多个模型添加到一个端点,也可以将一个模型添加到多个端点。了解详情。
点击继续。
在模型设置页面中,按如下方式配置:
-
如果您要将模型部署到新端点,请接受 100 的流量拆分值。如果您要将模型部署到已部署有一个或多个模型的现有端点,则必须为要部署的模型和已部署模型更新流量拆分百分比,以使所有百分比的总和为 100%。
-
输入要为模型提供的计算节点数下限。
这是此模型始终可用的节点数。 您需要为使用的节点(无论是处理预测负载还是备用 [最少] 节点)付费,即使没有预测流量也是如此。请参阅价格页面。
-
选择机器类型。
较大的机器资源会提高预测性能并增加费用。
-
了解如何更改预测日志记录的默认设置。
-
点击继续
-
在模型监控页面中,点击继续。
在监控目标页面中,按如下方式配置:
- 输入训练数据的位置。
- 输入目标列的名称。
点击部署,将模型部署到端点。
API
使用 Vertex AI API 部署模型时,请完成以下步骤:
- 根据需要创建端点。
- 获取端点 ID。
- 将模型部署到端点。
创建端点
如果要将模型部署到现有端点,您可以跳过此步骤。
gcloud
以下示例使用 gcloud ai endpoints create
命令:
gcloud ai endpoints create \
--region=LOCATION \
--display-name=ENDPOINT_NAME
替换以下内容:
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
ENDPOINT_NAME:端点的显示名称。
Google Cloud CLI 工具可能需要几秒钟才能创建端点。
REST
在使用任何请求数据之前,请先进行以下替换:
- LOCATION_ID:您的区域。
- PROJECT_ID:您的项目 ID。
- ENDPOINT_NAME:端点的显示名称。
HTTP 方法和网址:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints
请求 JSON 正文:
{ "display_name": "ENDPOINT_NAME" }
如需发送您的请求,请展开以下选项之一:
您应该收到类似以下内容的 JSON 响应:
{ "name": "projects/PROJECT_NUMBER/locations/LOCATION_ID/endpoints/ENDPOINT_ID/operations/OPERATION_ID", "metadata": { "@type": "type.googleapis.com/google.cloud.aiplatform.v1.CreateEndpointOperationMetadata", "genericMetadata": { "createTime": "2020-11-05T17:45:42.812656Z", "updateTime": "2020-11-05T17:45:42.812656Z" } } }
"done": true
。
Java
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Node.js
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Python
如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档。
检索端点 ID
您需要端点 ID 才能部署模型。
gcloud
以下示例使用 gcloud ai endpoints list
命令:
gcloud ai endpoints list \
--region=LOCATION \
--filter=display_name=ENDPOINT_NAME
替换以下内容:
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
ENDPOINT_NAME:端点的显示名称。
请注意
ENDPOINT_ID
列中显示的数字。请在以下步骤中使用此 ID。
REST
在使用任何请求数据之前,请先进行以下替换:
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
- PROJECT_ID:您的项目 ID。
- ENDPOINT_NAME:端点的显示名称。
HTTP 方法和网址:
GET https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints?filter=display_name=ENDPOINT_NAME
如需发送您的请求,请展开以下选项之一:
您应该收到类似以下内容的 JSON 响应:
{ "endpoints": [ { "name": "projects/PROJECT_NUMBER/locations/LOCATION_ID/endpoints/ENDPOINT_ID", "displayName": "ENDPOINT_NAME", "etag": "AMEw9yPz5pf4PwBHbRWOGh0PcAxUdjbdX2Jm3QO_amguy3DbZGP5Oi_YUKRywIE-BtLx", "createTime": "2020-04-17T18:31:11.585169Z", "updateTime": "2020-04-17T18:35:08.568959Z" } ] }
部署模型
在下面选择您的语言或环境对应的标签页:
gcloud
以下示例使用 gcloud ai endpoints deploy-model
命令。
以下示例将 Model
部署到 Endpoint
,但不使用 GPU 来加快预测服务速度,而且未在多个 DeployedModel
资源之间拆分流量:
在使用下面的命令数据之前,请先进行以下替换:
- ENDPOINT_ID:端点的 ID。
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
- MODEL_ID:要部署的模型的 ID。
-
DEPLOYED_MODEL_NAME:
DeployedModel
的名称。您还可以将Model
的显示名用于DeployedModel
。 -
MACHINE_TYPE:可选。用于此部署的每个节点的机器资源。其默认设置为
n1-standard-2
。详细了解机器类型。 -
MIN_REPLICA_COUNT:此部署的最小节点数。节点数可根据预测负载的需要而增加或减少,直至达到节点数上限并且绝不会少于此节点数。
此值必须大于或等于 1。如果省略
--min-replica-count
标志,则该值默认为 1。 -
MAX_REPLICA_COUNT:此部署的节点数上限。节点数可根据预测负载的需要而增加或减少,直至达到此节点数并且绝不会少于节点数下限。如果您省略
--max-replica-count
标志,则节点数上限将设置为--min-replica-count
的值。
执行 gcloud ai endpoints deploy-model 命令:
Linux、macOS 或 Cloud Shell
gcloud ai endpoints deploy-model ENDPOINT_ID\ --region=LOCATION_ID \ --model=MODEL_ID \ --display-name=DEPLOYED_MODEL_NAME \ --machine-type=MACHINE_TYPE \ --min-replica-count=MIN_REPLICA_COUNT \ --max-replica-count=MAX_REPLICA_COUNT \ --traffic-split=0=100
Windows (PowerShell)
gcloud ai endpoints deploy-model ENDPOINT_ID` --region=LOCATION_ID ` --model=MODEL_ID ` --display-name=DEPLOYED_MODEL_NAME ` --machine-type=MACHINE_TYPE ` --min-replica-count=MIN_REPLICA_COUNT ` --max-replica-count=MAX_REPLICA_COUNT ` --traffic-split=0=100
Windows (cmd.exe)
gcloud ai endpoints deploy-model ENDPOINT_ID^ --region=LOCATION_ID ^ --model=MODEL_ID ^ --display-name=DEPLOYED_MODEL_NAME ^ --machine-type=MACHINE_TYPE ^ --min-replica-count=MIN_REPLICA_COUNT ^ --max-replica-count=MAX_REPLICA_COUNT ^ --traffic-split=0=100
拆分流量
上述示例中的 --traffic-split=0=100
标志会将 Endpoint
接收的 100% 预测流量发送到新 DeployedModel
(使用临时 ID 0
表示)。如果您的 Endpoint
已有其他 DeployedModel
资源,那么您可以在新 DeployedModel
和旧资源之间拆分流量。例如,如需将 20% 的流量发送到新 DeployedModel
,将 80% 发送到旧版本,请运行以下命令。
在使用下面的命令数据之前,请先进行以下替换:
- OLD_DEPLOYED_MODEL_ID:现有
DeployedModel
的 ID。
执行 gcloud ai endpoints deploy-model 命令:
Linux、macOS 或 Cloud Shell
gcloud ai endpoints deploy-model ENDPOINT_ID\ --region=LOCATION_ID \ --model=MODEL_ID \ --display-name=DEPLOYED_MODEL_NAME \ --machine-type=MACHINE_TYPE \ --min-replica-count=MIN_REPLICA_COUNT \ --max-replica-count=MAX_REPLICA_COUNT \ --traffic-split=0=20,OLD_DEPLOYED_MODEL_ID=80
Windows (PowerShell)
gcloud ai endpoints deploy-model ENDPOINT_ID` --region=LOCATION_ID ` --model=MODEL_ID ` --display-name=DEPLOYED_MODEL_NAME \ --machine-type=MACHINE_TYPE ` --min-replica-count=MIN_REPLICA_COUNT ` --max-replica-count=MAX_REPLICA_COUNT ` --traffic-split=0=20,OLD_DEPLOYED_MODEL_ID=80
Windows (cmd.exe)
gcloud ai endpoints deploy-model ENDPOINT_ID^ --region=LOCATION_ID ^ --model=MODEL_ID ^ --display-name=DEPLOYED_MODEL_NAME \ --machine-type=MACHINE_TYPE ^ --min-replica-count=MIN_REPLICA_COUNT ^ --max-replica-count=MAX_REPLICA_COUNT ^ --traffic-split=0=20,OLD_DEPLOYED_MODEL_ID=80
REST
您可以使用 endpoints.predict 方法请求在线预测。
部署此模型。
在使用任何请求数据之前,请先进行以下替换:
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
- PROJECT_ID:您的项目 ID。
- ENDPOINT_ID:端点的 ID。
- MODEL_ID:要部署的模型的 ID。
-
DEPLOYED_MODEL_NAME:
DeployedModel
的名称。您还可以将Model
的显示名用于DeployedModel
。 -
MACHINE_TYPE:可选。用于此部署的每个节点的机器资源。其默认设置为
n1-standard-2
。详细了解机器类型。 - ACCELERATOR_TYPE:要挂接到机器的加速器类型。如果未指定 ACCELERATOR_COUNT 或为零,则可选。建议不要用于使用非 GPU 映像的 AutoML 模型或自定义训练模型。了解详情。
- ACCELERATOR_COUNT:每个副本要使用的加速器数量。可选。对于使用非 GPU 映像的 AutoML 模型或自定义模型,应该为零或未指定。
- MIN_REPLICA_COUNT:此部署的最小节点数。 节点数可根据预测负载的需要而增加或减少,直至达到节点数上限并且绝不会少于此节点数。此值必须大于或等于 1。
- MAX_REPLICA_COUNT:此部署的节点数上限。 节点数可根据预测负载的需要而增加或减少,直至达到此节点数并且绝不会少于节点数下限。
- TRAFFIC_SPLIT_THIS_MODEL:流向此端点的要路由到使用此操作部署的模型的预测流量百分比。默认值为 100。所有流量百分比之和必须为 100。详细了解流量拆分。
- DEPLOYED_MODEL_ID_N:可选。如果将其他模型部署到此端点,您必须更新其流量拆分百分比,以便所有百分比之和等于 100。
- TRAFFIC_SPLIT_MODEL_N:已部署模型 ID 密钥的流量拆分百分比值。
- PROJECT_NUMBER:自动生成的项目编号
HTTP 方法和网址:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:deployModel
请求 JSON 正文:
{ "deployedModel": { "model": "projects/PROJECT/locations/us-central1/models/MODEL_ID", "displayName": "DEPLOYED_MODEL_NAME", "dedicatedResources": { "machineSpec": { "machineType": "MACHINE_TYPE", "acceleratorType": "ACCELERATOR_TYPE", "acceleratorCount": "ACCELERATOR_COUNT" }, "minReplicaCount": MIN_REPLICA_COUNT, "maxReplicaCount": MAX_REPLICA_COUNT }, }, "trafficSplit": { "0": TRAFFIC_SPLIT_THIS_MODEL, "DEPLOYED_MODEL_ID_1": TRAFFIC_SPLIT_MODEL_1, "DEPLOYED_MODEL_ID_2": TRAFFIC_SPLIT_MODEL_2 }, }
如需发送您的请求,请展开以下选项之一:
您应该收到类似以下内容的 JSON 响应:
{ "name": "projects/PROJECT_ID/locations/LOCATION/endpoints/ENDPOINT_ID/operations/OPERATION_ID", "metadata": { "@type": "type.googleapis.com/google.cloud.aiplatform.v1.DeployModelOperationMetadata", "genericMetadata": { "createTime": "2020-10-19T17:53:16.502088Z", "updateTime": "2020-10-19T17:53:16.502088Z" } } }
Java
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Python
如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档。
Node.js
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
了解如何更改预测日志记录的默认设置。
获取操作状态
某些请求会启动需要一些时间才能完成的长时间运行的操作。这些请求会返回操作名称,您可以使用该名称查看操作状态或取消操作。Vertex AI 提供辅助方法来调用长时间运行的操作。如需了解详情,请参阅使用长时间运行的操作。
使用已部署的模型获取在线预测
如需进行在线预测,请向模型提交一个或多个测试项进行分析,模型会返回基于模型目标的结果。使用 Google Cloud 控制台或 Vertex AI API 请求在线预测。
Google Cloud 控制台
在 Google Cloud 控制台的 Vertex AI 部分中,转到模型页面。
从模型列表中,点击要向其请求预测的模型的名称。
选择部署和测试标签页。
在测试模型部分下,添加测试项以请求预测。系统会为您填充基准预测数据,您也可以输入自己的预测数据并点击预测。
预测完成后,Vertex AI 会在控制台中返回结果。
API:分类
gcloud
-
创建名为
request.json
且包含以下内容的文件:{ "instances": [ { PREDICTION_DATA_ROW } ] }
替换以下内容:
-
PREDICTION_DATA_ROW:一个 JSON 对象,使用键作为特征名称,值作为相应的特征值。例如,对于包含数字、字符串数组和类别的数据集,数据行可能类似于以下示例请求:
"length":3.6, "material":"cotton", "tag_array": ["abc","def"]
必须为训练中包含的每个特征提供一个值。用于预测的数据格式必须与用于训练的格式匹配。如需了解详情,请参阅预测的数据格式。
-
-
运行以下命令:
gcloud ai endpoints predict ENDPOINT_ID \ --region=LOCATION_ID \ --json-request=request.json
替换以下内容:
- ENDPOINT_ID:端点的 ID。
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
REST
您可以使用 endpoints.predict 方法请求在线预测。
在使用任何请求数据之前,请先进行以下替换:
-
LOCATION_ID:端点所在的区域。例如
us-central1
。 - PROJECT_ID:您的项目 ID。
- ENDPOINT_ID:端点的 ID。
-
PREDICTION_DATA_ROW:一个 JSON 对象,使用键作为特征名称,值作为相应的特征值。例如,对于包含数字、字符串数组和类别的数据集,数据行可能类似于以下示例请求:
"length":3.6, "material":"cotton", "tag_array": ["abc","def"]
必须为训练中包含的每个特征提供一个值。用于预测的数据格式必须与用于训练的格式匹配。如需了解详情,请参阅预测的数据格式。
- DEPLOYED_MODEL_ID:由
predict
方法输出。用于生成预测结果的模型的 ID。
HTTP 方法和网址:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict
请求 JSON 正文:
{ "instances": [ { PREDICTION_DATA_ROW } ] }
如需发送请求,请选择以下方式之一:
curl
将请求正文保存在名为 request.json
的文件中,然后执行以下命令:
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict"
PowerShell
将请求正文保存在名为 request.json
的文件中,然后执行以下命令:
$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }
Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict" | Select-Object -Expand Content
您应该收到类似以下内容的 JSON 响应:
{ "predictions": [ { "scores": [ 0.96771615743637085, 0.032283786684274673 ], "classes": [ "0", "1" ] } ] "deployedModelId": "2429510197" }
Java
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Node.js
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Python
如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档。
API:回归
gcloud
-
创建一个名为“request.json”且包含以下内容的文件:
{ "instances": [ { PREDICTION_DATA_ROW } ] }
替换以下内容:
-
PREDICTION_DATA_ROW:一个 JSON 对象,使用键作为特征名称,值作为相应的特征值。例如,对于包含数字、数字数组和类别的数据集,数据行可能类似于以下示例请求:
"age":3.6, "sq_ft":5392, "code": "90331"
必须为训练中包含的每个特征提供一个值。用于预测的数据格式必须与用于训练的格式匹配。如需了解详情,请参阅预测的数据格式。
-
-
运行以下命令:
gcloud ai endpoints predict ENDPOINT_ID \ --region=LOCATION_ID \ --json-request=request.json
替换以下内容:
- ENDPOINT_ID:端点的 ID。
- LOCATION_ID:您在其中使用 Vertex AI 的区域。
REST
您可以使用 endpoints.predict 方法请求在线预测。
在使用任何请求数据之前,请先进行以下替换:
-
LOCATION_ID:端点所在的区域。例如
us-central1
。 - PROJECT_ID:您的项目 ID。
- ENDPOINT_ID:端点的 ID。
-
PREDICTION_DATA_ROW:一个 JSON 对象,使用键作为特征名称,值作为相应的特征值。例如,对于包含数字、数字数组和类别的数据集,数据行可能类似于以下示例请求:
"age":3.6, "sq_ft":5392, "code": "90331"
必须为训练中包含的每个特征提供一个值。用于预测的数据格式必须与用于训练的格式匹配。如需了解详情,请参阅预测的数据格式。
- DEPLOYED_MODEL_ID:由
predict
方法输出。用于生成预测结果的模型的 ID。
HTTP 方法和网址:
POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict
请求 JSON 正文:
{ "instances": [ { PREDICTION_DATA_ROW } ] }
如需发送请求,请选择以下方式之一:
curl
将请求正文保存在名为 request.json
的文件中,然后执行以下命令:
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict"
PowerShell
将请求正文保存在名为 request.json
的文件中,然后执行以下命令:
$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }
Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/endpoints/ENDPOINT_ID:predict" | Select-Object -Expand Content
您应该收到类似以下内容的 JSON 响应:
{ "predictions": [ [ { "value": 65.14233 } ] ], "deployedModelId": "DEPLOYED_MODEL_ID" }
Java
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Node.js
在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档。
如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证。
Python
如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档。
解读预测结果
分类
分类模型会返回置信度分数。
置信度分数传达了模型将每个类列或标签与测试项相关联的强度。该数值越高,模型应用于该项的置信度就越高。您可以决定接受模型的置信度分数为多高。
回归
回归模型会返回预测值。
如果模型使用概率推理,则 value
字段包含优化目标的最小化器。例如,如果优化目标为 minimize-rmse
,则 value
字段包含平均值。如果优化目标为 minimize-mae
,则 value
字段包含中位数值。
如果模型将概率推理与分位数结合使用,则除了优化目标的最小化器之外,Vertex AI 还提供分位数值和预测。分位数值是在模型训练期间设置的。分位数预测是与分位数值关联的预测值。
TabNet 可让用户深入了解哪些特征有助于做出决策,从而提供固有的模型可解释性。该算法利用注意力,通过学习选择性地增强某些特征的影响,同时通过加权平均值降低其他特征的影响。对于特定的决策,TabNet 会按阶梯式决定每个特征的重要性。然后,它会组合每个步骤来创建最终预测。注意力是乘法,其中较大的值表示特征在预测中发挥较大的作用,值为零表示特征在决策中不起作用。由于 TabNet 使用多个决策步骤,因此在经过适当的扩缩后,对所有步骤中的特征的注意力会进行线性组合。所有 TabNet 决策步骤中的这种线性组合是 TabNet 为您提供的总特征重要性。
预测的输出示例
回归模型中具有特征重要性的在线预测返回的载荷类似于以下示例。
{
"predictions":[
{
"value":0.3723912537097931,
"feature_importance":{
"MSSubClass":0.12,
"MSZoning":0.33,
"LotFrontage":0.27,
"LotArea":0.06,
...
}
}
]
}