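Explaining predictions

This page describes how to use feature importance to gain visibility into how a model makes its predictions.

For more information about AI Explanations, see Introduction to AI Explanations for AI Platform.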

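Introduction

When you use a machine learning model to make business decisions, it is important to understand how the training data contributed to the final model and how the model arrived at individual predictions. This understanding helps you confirm that the model is reasonable and accurate.

AutoML Tables provides feature importance, sometimes called feature attributions, which shows which features contributed most to model training (model feature importance) and to individual predictions (local feature importance).

AutoML Tables computes feature importance by using the Sampled Shapley method. For more information about model explainability, see Introduction to AI Explanations.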
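To make the idea behind Sampled Shapley more concrete, the following is a minimal sketch of Shapley estimation by sampling random feature orderings on a toy model. It illustrates the general technique only and is not the AutoML Tables implementation; the function name `sampled_shapley`, the toy linear model, and the baseline values are all assumptions made for this example.

```python
import random


def sampled_shapley(predict, instance, baseline, num_samples=200, seed=0):
    """Estimate per-feature attributions by averaging marginal contributions
    over randomly sampled feature orderings (illustrative only)."""
    rng = random.Random(seed)
    n = len(instance)
    attributions = [0.0] * n
    for _ in range(num_samples):
        order = list(range(n))
        rng.shuffle(order)
        current = list(baseline)  # start from the baseline row
        previous_score = predict(current)
        for i in order:
            current[i] = instance[i]  # switch feature i to its actual value
            score = predict(current)
            attributions[i] += score - previous_score
            previous_score = score
    return [total / num_samples for total in attributions]


def toy_model(row):
    # Hypothetical linear model, used only to make the sketch runnable.
    return 2.0 * row[0] - 1.0 * row[1] + 0.5 * row[2]


instance = [3.0, 1.0, 4.0]
baseline = [1.0, 1.0, 0.0]
attributions = sampled_shapley(toy_model, instance, baseline)
print(attributions)
# The attributions add up to the difference between the prediction for the
# instance and the prediction for the baseline.
print(sum(attributions), toy_model(instance) - toy_model(baseline))
```

In this sketch the second feature receives an attribution of zero because its value equals the baseline value, which is why the choice of baseline matters when reading attributions.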

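Model feature importance

Model feature importance helps you ensure that the features that informed model training make sense for your data and business problem. Every feature with a high feature importance value should represent a valid signal for prediction, and should be consistently available in your prediction requests.

Model feature importance is provided as a percentage for each feature: the higher the percentage, the more strongly that feature affected model training.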
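As a small illustration, the sketch below ranks columns by importance and formats each value as a percentage. It assumes that the `featureImportance` values returned by the API are fractions of 1, as the sample values on this page (for example 0.215029223) suggest; the helper name `format_importances` is hypothetical.

```python
def format_importances(column_info):
    """Return (display name, percentage string) pairs, highest importance first."""
    ranked = sorted(
        column_info, key=lambda c: c.get("featureImportance", 0.0), reverse=True
    )
    return [
        (c["columnDisplayName"], "{:.1f}%".format(100 * c.get("featureImportance", 0.0)))
        for c in ranked
    ]


# Values copied from the sample model.get response on this page.
sample = [
    {"columnDisplayName": "Contact", "featureImportance": 0.093201876},
    {"columnDisplayName": "Month", "featureImportance": 0.215029223},
]
report = format_importances(sample)
for name, pct in report:
    print(name, pct)
```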

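Getting model feature importance

Console

To view your model's feature importance values by using the Google Cloud Console:

  1. Go to the AutoML Tables page in the Google Cloud Console.

    Go to the AutoML Tables page

  2. Select the Models tab in the left navigation pane, and then select the model whose evaluation metrics you want to view.

  3. Open the Evaluate tab.

  4. Scroll down to view the feature importance graph.

AutoML Tables evaluate page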

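REST & command line

To get the feature importance values for a model, use the model.get method.

Before using any of the request data, make the following replacements:

  • endpoint: automl.googleapis.com for the global location, and eu-automl.googleapis.com for the EU region.
  • project-id: your Google Cloud project ID.
  • location: the location for the resource: us-central1 for the global location, or eu for the EU.
  • model-id: the ID of the model whose feature importance information you want to get. For example, TBL543.

HTTP method and URL: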

GET https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id
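The endpoint and location substitutions listed above can be combined in a small helper. The following is only a sketch: the function names `automl_endpoint` and `model_url` and the project ID `my-project` are hypothetical, and the sketch assumes that `us-central1` and `eu` are the only locations in use, as described on this page.

```python
def automl_endpoint(location):
    """Pick the service endpoint for a resource location."""
    return "eu-automl.googleapis.com" if location == "eu" else "automl.googleapis.com"


def model_url(project_id, location, model_id, method=None):
    """Build the v1beta1 model URL; `method` is an optional suffix such as 'predict'."""
    url = "https://{}/v1beta1/projects/{}/locations/{}/models/{}".format(
        automl_endpoint(location), project_id, location, model_id
    )
    return url if method is None else "{}:{}".format(url, method)


print(model_url("my-project", "us-central1", "TBL543"))
print(model_url("my-project", "eu", "TBL543", method="predict"))
```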

To send your request, choose one of these options:

curl

Execute the following command:

curl -X GET \
-H "Authorization: Bearer "$(gcloud auth application-default print-access-token) \
"https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id"

PowerShell

Execute the following command:

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method GET `
-Headers $headers `
-Uri "https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id" | Select-Object -Expand Content

The feature importance values for each column are returned in the TablesModelColumnInfo object.

{
  "name": "projects/292381/locations/us-central1/models/TBL543",
  "displayName": "Quickstart_Model",
  ...
  "tablesModelMetadata": {
    "targetColumnSpec": {
    ...
    },
    "inputFeatureColumnSpecs": [
    ...
    ],
    "optimizationObjective": "MAXIMIZE_AU_ROC",
    "tablesModelColumnInfo": [
      {
        "columnSpecName": "projects/292381/locations/us-central1/datasets/TBL543/tableSpecs/246/columnSpecs/331",
        "columnDisplayName": "Contact",
        "featureImportance": 0.093201876
      },
      {
        "columnSpecName": "projects/292381/locations/us-central1/datasets/TBL543/tableSpecs/246/columnSpecs/638",
        "columnDisplayName": "Month",
        "featureImportance": 0.215029223
      },
      ...
    ],
    "trainBudgetMilliNodeHours": "1000",
    "trainCostMilliNodeHours": "1000",
    "classificationType": "BINARY",
    "predictionSampleRows": [
    ...
    ],
    "splitPercentageConfig": {
    ...
    }
  },
  "creationState": "CREATED",
  "deployedModelSizeBytes": "1160941568"
}

Java

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.


import com.google.cloud.automl.v1beta1.AutoMlClient;
import com.google.cloud.automl.v1beta1.Model;
import com.google.cloud.automl.v1beta1.ModelName;
import com.google.cloud.automl.v1beta1.TablesModelColumnInfo;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;

public class TablesGetModel {

  public static void main(String[] args) throws IOException, StatusRuntimeException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String region = "YOUR_REGION";
    String modelId = "YOUR_MODEL_ID";
    getModel(projectId, region, modelId);
  }

  // Demonstrates using the AutoML client to get model details.
  public static void getModel(String projectId, String computeRegion, String modelId)
      throws IOException, StatusRuntimeException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {

      // Get the full path of the model.
      ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId);

      // Get complete detail of the model.
      Model model = client.getModel(modelFullId);

      // Display the model information.
      System.out.format("Model name: %s%n", model.getName());
      System.out.format(
          "Model Id: %s\n", model.getName().split("/")[model.getName().split("/").length - 1]);
      System.out.format("Model display name: %s%n", model.getDisplayName());
      System.out.format("Dataset Id: %s%n", model.getDatasetId());
      System.out.println("Tables Model Metadata: ");
      System.out.format(
          "\tTraining budget: %s%n", model.getTablesModelMetadata().getTrainBudgetMilliNodeHours());
      System.out.format(
          "\tTraining cost: %s%n", model.getTablesModelMetadata().getTrainCostMilliNodeHours());

      DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ");
      String createTime =
          dateFormat.format(new java.util.Date(model.getCreateTime().getSeconds() * 1000));
      System.out.format("Model create time: %s%n", createTime);

      System.out.format("Model deployment state: %s%n", model.getDeploymentState());

      // Get features of top importance
      for (TablesModelColumnInfo info :
          model.getTablesModelMetadata().getTablesModelColumnInfoList()) {
        System.out.format(
            "Column: %s - Importance: %.2f%n",
            info.getColumnDisplayName(), info.getFeatureImportance());
      }
    }
  }
}

Node.js

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.

const automl = require('@google-cloud/automl');
const client = new automl.v1beta1.AutoMlClient();

/**
 * Demonstrates using the AutoML client to get model details.
 * TODO(developer): Uncomment the following lines before running the sample.
 */
// const projectId = '[PROJECT_ID]' e.g., "my-gcloud-project";
// const computeRegion = '[REGION_NAME]' e.g., "us-central1";
// const modelId = '[MODEL_ID]' e.g., "TBL4704590352927948800";

// Get the full path of the model.
const modelFullId = client.modelPath(projectId, computeRegion, modelId);

// Get complete detail of the model.
client
  .getModel({name: modelFullId})
  .then(responses => {
    const model = responses[0];

    // Display the model information.
    console.log(`Model name: ${model.name}`);
    console.log(`Model Id: ${model.name.split('/').pop()}`);
    console.log(`Model display name: ${model.displayName}`);
    console.log(`Dataset Id: ${model.datasetId}`);
    console.log('Tables model metadata: ');
    console.log(
      `\tTraining budget: ${model.tablesModelMetadata.trainBudgetMilliNodeHours}`
    );
    console.log(
      `\tTraining cost: ${model.tablesModelMetadata.trainCostMilliNodeHours}`
    );
    console.log(`Model deployment state: ${model.deploymentState}`);
  })
  .catch(err => {
    console.error(err);
  });

Python

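The client library for AutoML Tables includes additional Python methods that simplify using the AutoML Tables API. These methods refer to datasets and models by name instead of by ID. Your dataset and model names must be unique. For more information, see the Client reference.

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.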

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_display_name = 'MODEL_DISPLAY_NAME_HERE'

from google.cloud import automl_v1beta1 as automl

client = automl.TablesClient(project=project_id, region=compute_region)

# Get complete detail of the model.
model = client.get_model(model_display_name=model_display_name)

# Retrieve deployment state.
if model.deployment_state == automl.Model.DeploymentState.DEPLOYED:
    deployment_state = "deployed"
else:
    deployment_state = "undeployed"

# get features of top importance
feat_list = [
    (column.feature_importance, column.column_display_name)
    for column in model.tables_model_metadata.tables_model_column_info
]
feat_list.sort(reverse=True)
if len(feat_list) < 10:
    feat_to_show = len(feat_list)
else:
    feat_to_show = 10

# Display the model information.
print("Model name: {}".format(model.name))
print("Model id: {}".format(model.name.split("/")[-1]))
print("Model display name: {}".format(model.display_name))
print("Features of top importance:")
for feat in feat_list[:feat_to_show]:
    print(feat)
print("Model create time: {}".format(model.create_time))
print("Model deployment state: {}".format(deployment_state))

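Local feature importance

Local feature importance gives you visibility into how the individual features in a specific prediction request affected the resulting prediction.

To arrive at each local feature importance value, the baseline prediction score is calculated first. Baseline values are computed from the training data, using the median value for numeric features and the mode for categorical features. The prediction generated from the baseline values is the baseline prediction score.

For classification models, local feature importance tells you how much each feature added to or subtracted from the probability assigned to the class with the highest score, as compared with the baseline prediction score. Score values are between 0.0 and 1.0, so local feature importance values for classification models are always between -1.0 and 1.0 (inclusive).

For regression models, local feature importance for a prediction tells you how much each feature added to or subtracted from the result, as compared with the baseline prediction score.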
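As a worked illustration of how these values relate to the baseline, the sketch below adds hypothetical local feature importance values to a hypothetical baseline prediction score. It assumes the contributions are additive, so that the baseline plus the sum of the contributions approximates the prediction; because Sampled Shapley is an approximation method, real values may not add up exactly. All names and numbers here are invented for the example.

```python
def explain(baseline_score, local_importance):
    """Rebuild an approximate prediction from a baseline and per-feature contributions."""
    return baseline_score + sum(local_importance.values())


# Hypothetical values for one regression prediction (not real output).
baseline_score = 5000.0
local_importance = {"Promo": 1500.0, "Open": -2000.0, "StateHoliday": 0.0}

approx_prediction = explain(baseline_score, local_importance)
print(approx_prediction)  # 4500.0

# For a classification model, scores lie in [0.0, 1.0], so a single feature's
# contribution cannot fall outside [-1.0, 1.0].
class_importance = {"Contact": 0.12, "Month": -0.05}
in_range = all(-1.0 <= v <= 1.0 for v in class_importance.values())
print(in_range)  # True
```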

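Local feature importance is available for both online and batch predictions.

Getting local feature importance for online predictions

Console

To get local feature importance values for an online prediction by using the Google Cloud Console, follow the steps in Getting an online prediction, and make sure that you select the Generate feature importance checkbox.

AutoML Tables feature importance checkbox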

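REST & command line

To get local feature importance for an online prediction request, use the model.predict method and set the feature_importance parameter to true.

Before using any of the request data, make the following replacements:

  • endpoint: automl.googleapis.com for the global location, and eu-automl.googleapis.com for the EU region.
  • project-id: your Google Cloud project ID.
  • location: the location for the resource: us-central1 for the global location, or eu for the EU.
  • model-id: the ID of the model. For example, TBL543.
  • valueN: the values for each column, in the correct order.

HTTP method and URL:

POST https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:predict

Request JSON body: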

{
  "payload": {
    "row": {
      "values": [
        value1, value2,...
      ]
    }
  },
  "params": {
    "feature_importance": "true"
  }
}
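If you prefer to generate the request body rather than write it by hand, a minimal sketch follows. It builds the same structure as the JSON above and serializes it with the standard `json` module; the helper name `build_predict_body` and the row values are hypothetical, and the values must match the columns your own model expects.

```python
import json


def build_predict_body(values, feature_importance=True):
    """Build the JSON body for an online prediction request."""
    body = {"payload": {"row": {"values": list(values)}}}
    if feature_importance:
        # Parameter values are sent as strings, as in the request body above.
        body["params"] = {"feature_importance": "true"}
    return body


# Hypothetical row values, for illustration only.
body = build_predict_body([39, "technician", "married"])
request_json = json.dumps(body, indent=2)
print(request_json)
```

The printed text can be saved as request.json for use with the commands that follow.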

To send your request, choose one of these options:

curl

Save the request body in a file called request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer "$(gcloud auth application-default print-access-token) \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:predict"

PowerShell

Save the request body in a file called request.json, and execute the following command:

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:predict" | Select-Object -Expand Content

The feature importance results are returned in the tablesModelColumnInfo object.

"tablesModelColumnInfo": [
  {
     "columnSpecName": "projects/2381/locations/us-central1/datasets/TBL8440/tableSpecs/766336/columnSpecs/4704",
     "columnDisplayName": "Promo",
     "featureImportance": 1626.5464
  },
  {
     "columnSpecName": "projects/2381/locations/us-central1/datasets/TBL8440/tableSpecs/766336/columnSpecs/6800",
     "columnDisplayName": "Open",
     "featureImportance": -7496.5405
  },
  {
     "columnSpecName": "projects/2381/locations/us-central1/datasets/TBL8440/tableSpecs/766336/columnSpecs/9824",
     "columnDisplayName": "StateHoliday"
  }
],

If a column has a feature importance value of 0, the feature importance is not displayed for that column.
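Because a value of 0 is omitted from the response, client code should not assume that every entry has a `featureImportance` key. The following sketch, which uses a hypothetical helper name `importance_by_column`, treats a missing value as 0.0 and then finds the column with the largest absolute contribution.

```python
def importance_by_column(column_info):
    """Map column display names to importance, treating an omitted value as 0.0."""
    return {
        c["columnDisplayName"]: c.get("featureImportance", 0.0) for c in column_info
    }


# Entries copied from the sample response above; StateHoliday has no value.
column_info = [
    {"columnDisplayName": "Promo", "featureImportance": 1626.5464},
    {"columnDisplayName": "Open", "featureImportance": -7496.5405},
    {"columnDisplayName": "StateHoliday"},
]
importances = importance_by_column(column_info)
strongest = max(importances, key=lambda name: abs(importances[name]))
print(importances)
print(strongest)
```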

Java

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.

import com.google.cloud.automl.v1beta1.AnnotationPayload;
import com.google.cloud.automl.v1beta1.ExamplePayload;
import com.google.cloud.automl.v1beta1.ModelName;
import com.google.cloud.automl.v1beta1.PredictRequest;
import com.google.cloud.automl.v1beta1.PredictResponse;
import com.google.cloud.automl.v1beta1.PredictionServiceClient;
import com.google.cloud.automl.v1beta1.Row;
import com.google.cloud.automl.v1beta1.TablesAnnotation;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

class TablesPredict {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    // Values should match the input expected by your model.
    List<Value> values = new ArrayList<>();
    // values.add(Value.newBuilder().setBoolValue(true).build());
    // values.add(Value.newBuilder().setNumberValue(10).build());
    // values.add(Value.newBuilder().setStringValue("YOUR_STRING").build());
    predict(projectId, modelId, values);
  }

  static void predict(String projectId, String modelId, List<Value> values) throws IOException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, "us-central1", modelId);
      Row row = Row.newBuilder().addAllValues(values).build();
      ExamplePayload payload = ExamplePayload.newBuilder().setRow(row).build();

      // Feature importance gives you visibility into how the features in a specific prediction
      // request informed the resulting prediction. For more info, see:
      // https://cloud.google.com/automl-tables/docs/features#local
      PredictRequest request =
          PredictRequest.newBuilder()
              .setName(name.toString())
              .setPayload(payload)
              .putParams("feature_importance", "true")
              .build();

      PredictResponse response = client.predict(request);

      System.out.println("Prediction results:");
      for (AnnotationPayload annotationPayload : response.getPayloadList()) {
        TablesAnnotation tablesAnnotation = annotationPayload.getTables();
        System.out.format(
            "Classification label: %s%n", tablesAnnotation.getValue().getStringValue());
        System.out.format("Classification score: %.3f%n", tablesAnnotation.getScore());
        // Get features of top importance
        tablesAnnotation
            .getTablesModelColumnInfoList()
            .forEach(
                info ->
                    System.out.format(
                        "\tColumn: %s - Importance: %.2f%n",
                        info.getColumnDisplayName(), info.getFeatureImportance()));
      }
    }
  }
}

Node.js

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.


/**
 * Demonstrates using the AutoML client to request prediction from
 * automl tables using csv.
 * TODO(developer): Uncomment the following lines before running the sample.
 */
// const projectId = '[PROJECT_ID]' e.g., "my-gcloud-project";
// const computeRegion = '[REGION_NAME]' e.g., "us-central1";
// const modelId = '[MODEL_ID]' e.g., "TBL000000000000";
// const inputs = [{ numberValue: 1 }, { stringValue: 'value' }, { stringValue: 'value2' } ...]

const automl = require('@google-cloud/automl');

// Create client for prediction service.
const automlClient = new automl.v1beta1.PredictionServiceClient();

// Get the full path of the model.
const modelFullId = automlClient.modelPath(projectId, computeRegion, modelId);

inputs = JSON.parse(inputs);

async function predict() {
  // Set the payload by giving the row values.
  const payload = {
    row: {
      values: inputs,
    },
  };

  // Params holds additional domain-specific parameters.
  // Setting feature_importance requests local feature importance for the prediction.
  const [response] = await automlClient.predict({
    name: modelFullId,
    payload: payload,
    params: {feature_importance: true},
  });
  console.log('Prediction results:');

  for (const result of response.payload) {
    console.log(`Predicted class name: ${result.displayName}`);
    console.log(`Predicted class score: ${result.tables.score}`);

    // Get features of top importance
    const featureList = result.tables.tablesModelColumnInfo.map(
      columnInfo => {
        return {
          importance: columnInfo.featureImportance,
          displayName: columnInfo.columnDisplayName,
        };
      }
    );
    // Sort features by their importance, highest importance first
    featureList.sort((a, b) => {
      return b.importance - a.importance;
    });

    // Print top 10 important features
    console.log('Features of top importance');
    console.log(featureList.slice(0, 10));
  }
}
predict();

Python

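The client library for AutoML Tables includes additional Python methods that simplify using the AutoML Tables API. These methods refer to datasets and models by name instead of by ID. Your dataset and model names must be unique. For more information, see the Client reference.

If your resources are located in the EU region, you must explicitly set the endpoint. Learn more.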

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_display_name = 'MODEL_DISPLAY_NAME_HERE'
# inputs = {'value': 3, ...}
# feature_importance = True  # set to False to skip local feature importance

from google.cloud import automl_v1beta1 as automl

client = automl.TablesClient(project=project_id, region=compute_region)

if feature_importance:
    response = client.predict(
        model_display_name=model_display_name,
        inputs=inputs,
        feature_importance=True,
    )
else:
    response = client.predict(
        model_display_name=model_display_name, inputs=inputs
    )

print("Prediction results:")
for result in response.payload:
    print(
        "Predicted class name: {}".format(result.tables.value)
    )
    print("Predicted class score: {}".format(result.tables.score))

    if feature_importance:
        # get features of top importance
        feat_list = [
            (column.feature_importance, column.column_display_name)
            for column in result.tables.tables_model_column_info
        ]
        feat_list.sort(reverse=True)
        if len(feat_list) < 10:
            feat_to_show = len(feat_list)
        else:
            feat_to_show = 10

        print("Features of top importance:")
        for feat in feat_list[:feat_to_show]:
            print(feat)

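Getting local feature importance for batch predictions

Console

To get local feature importance values for a batch prediction by using the Google Cloud Console, follow the steps in Requesting a batch prediction, and make sure that you select the Generate feature importance checkbox.

AutoML Tables feature importance checkbox

Feature importance is returned by adding a new column named feature_importance.<feature_name> for each feature.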

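REST & command line

To get local feature importance for a batch prediction request, use the model.batchPredict method and set the feature_importance parameter to true.

The following example uses BigQuery for the request data and the results; use the same parameter for requests that use Cloud Storage.

Before using any of the request data, make the following replacements:

  • endpoint: automl.googleapis.com for the global location, and eu-automl.googleapis.com for the EU region.
  • project-id: your Google Cloud project ID.
  • location: the location for the resource: us-central1 for the global location, or eu for the EU.
  • model-id: the ID of the model. For example, TBL543.
  • dataset-id: the ID of the BigQuery dataset where the prediction data is located.
  • table-id: the ID of the BigQuery table where the prediction data is located.

    AutoML Tables creates a subfolder for the prediction results named prediction-<model_name>-<timestamp> in project-id.dataset-id.table-id.

HTTP method and URL:

POST https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:batchPredict

Request JSON body: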

{
  "inputConfig": {
    "bigquerySource": {
      "inputUri": "bq://project-id.dataset-id.table-id"
    }
  },
  "outputConfig": {
    "bigqueryDestination": {
      "outputUri": "bq://project-id"
    }
  },
  "params": {"feature_importance": "true"}
}
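Generating the body programmatically is one way to avoid JSON syntax slips such as stray commas. The sketch below builds the same structure as the request body above for a BigQuery source and destination; the helper name `build_batch_predict_body` and the identifiers passed to it are hypothetical.

```python
import json


def build_batch_predict_body(project_id, dataset_id, table_id):
    """Build a batch prediction request body that enables feature importance."""
    return {
        "inputConfig": {
            "bigquerySource": {
                "inputUri": "bq://{}.{}.{}".format(project_id, dataset_id, table_id)
            }
        },
        "outputConfig": {
            "bigqueryDestination": {"outputUri": "bq://{}".format(project_id)}
        },
        "params": {"feature_importance": "true"},
    }


# Hypothetical identifiers, for illustration only.
batch_body = build_batch_predict_body("my-project", "my_dataset", "my_table")
batch_json = json.dumps(batch_body, indent=2)
print(batch_json)
```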

To send your request, choose one of these options:

curl

Save the request body in a file called request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer "$(gcloud auth application-default print-access-token) \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:batchPredict"

PowerShell

Save the request body in a file called request.json, and execute the following command:

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://endpoint/v1beta1/projects/project-id/locations/location/models/model-id:batchPredict" | Select-Object -Expand Content
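
Batch prediction is a long-running operation. You can poll for the operation status or wait for the operation to return. Learn more.

Feature importance is returned by adding a new column named feature_importance.<feature_name> for each feature.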
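When you read the results, the importance columns can be separated from the other columns by their name prefix. The following sketch assumes that each result row has been loaded into a flat dictionary keyed by column name; the row contents, including the `predicted_Sales` column, are hypothetical, and the actual layout of your results table may differ.

```python
PREFIX = "feature_importance."


def split_importance_columns(row):
    """Separate feature importance columns from the other columns of a result row."""
    importance = {
        key[len(PREFIX):]: value
        for key, value in row.items()
        if key.startswith(PREFIX)
    }
    other = {key: value for key, value in row.items() if not key.startswith(PREFIX)}
    return importance, other


# Hypothetical result row, for illustration only.
row = {
    "Promo": 1,
    "Open": 0,
    "predicted_Sales": 4210.5,
    "feature_importance.Promo": 1626.5464,
    "feature_importance.Open": -7496.5405,
}
importance, other = split_importance_columns(row)
print(importance)
print(other)
```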

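Considerations for using local feature importance:

  • Local feature importance results are available only for models trained on or after November 15, 2019.

  • Enabling local feature importance is not supported for batch prediction requests with more than 1,000,000 rows or 300 columns.

  • Each local feature importance value shows only how much the feature affected the prediction for that row. To understand the overall behavior of the model, use model feature importance.

  • Local feature importance values are always relative to the baseline value. Make sure that you reference the baseline value when you evaluate your local feature importance results. The baseline value is available only from the Cloud Console.

  • Local feature importance values depend entirely on the model and the data used to train the model. They can describe only the patterns that the model found in the data, and cannot detect any fundamental relationships in the data. A high feature importance for a feature therefore does not demonstrate a relationship between that feature and the target; it shows only that the model uses the feature in its predictions.

  • If a prediction includes data that falls completely outside the range of the training data, local feature importance might not provide meaningful results.

  • Generating feature importance increases the time and compute resources required for your prediction. In addition, your request uses a different quota than prediction requests without feature importance. Learn more.

  • Feature importance values alone do not tell you whether your model is fair, unbiased, or of sound quality. In addition to feature importance, carefully evaluate your training dataset, procedures, and evaluation metrics.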
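The size limit in the list above can be checked before a batch request is sent. The sketch below is a client-side convenience only, not part of the AutoML Tables API; the function name `supports_local_feature_importance` is hypothetical, and the limits are the ones stated in the list.

```python
MAX_ROWS = 1_000_000  # limit stated in the considerations above
MAX_COLUMNS = 300  # limit stated in the considerations above


def supports_local_feature_importance(num_rows, num_columns):
    """Return True if a batch request is within the documented size limits."""
    return num_rows <= MAX_ROWS and num_columns <= MAX_COLUMNS


print(supports_local_feature_importance(250_000, 40))  # True
print(supports_local_feature_importance(2_000_000, 40))  # False
print(supports_local_feature_importance(250_000, 301))  # False
```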