The ML.EXPLAIN_PREDICT function

The ML.EXPLAIN_PREDICT function generates a predicted value and a set of feature attributions for each instance of the input data. Feature attributions indicate how much each feature in your model contributed to the final prediction for each given instance.

ML.EXPLAIN_PREDICT can be viewed as an extended version of ML.PREDICT. For information about Explainable AI, see Explainable AI Overview.

For information about the supported model types of each SQL statement and function, and all of the supported SQL statements and functions for each model type, read the End-to-end user journey for each model.

ML.EXPLAIN_PREDICT syntax

ML.EXPLAIN_PREDICT(MODEL model_name,
          {TABLE table_name | (query_statement)}
                 [, STRUCT<top_k_features INT64,
                           threshold FLOAT64,
                           num_integral_steps INT64> settings])

model_name

model_name is the name of the model used to generate explanations. If a default project is not configured, then prepend the project ID to the model name in the following format: `[project_id].[dataset].[model]` (including the backticks). For example, `myproject.mydataset.mymodel`.

table_name

table_name is the name of the input table that contains the data to be evaluated. If a default project is not configured, then prepend the project ID to the table name in the following format: `[project_id].[dataset].[table]` (including the backticks). For example, `myproject.mydataset.mytable`.

The input column names in the table must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.

Any unused columns from the table are passed through to the output.
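To make the column-matching rule concrete, the following Python sketch splits one input row into the model's feature columns and the passthrough columns. It is a conceptual model only, not BigQuery's implementation: the helper name `split_input_row` and the sample column names are hypothetical, and type coercion is ignored.

```python
from typing import Any


def split_input_row(
    row: dict[str, Any], model_features: list[str]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split one input row into model features and passthrough columns.

    Hypothetical helper: the input must contain every model column, and any
    other column is carried through to the output unchanged.
    """
    missing = [name for name in model_features if name not in row]
    if missing:
        raise ValueError(f"input is missing model columns: {missing}")
    features = {name: row[name] for name in model_features}
    passthrough = {k: v for k, v in row.items() if k not in features}
    return features, passthrough


features, passthrough = split_input_row(
    {"column1": 1.0, "column2": 2.0, "label": 0, "row_id": 42},
    model_features=["column1", "column2"],
)
print(features)     # {'column1': 1.0, 'column2': 2.0}
print(passthrough)  # {'label': 0, 'row_id': 42}
```

The same rule applies when the input comes from query_statement instead of a table.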

query_statement

The query_statement clause specifies the standard SQL query that is used to generate the data to predict and explain the predicted results. See the Standard SQL Query Syntax page for the supported SQL syntax of the query_statement clause.

The input column names from the query must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.

Any unused columns from the query are passed through to the output.

If a TRANSFORM clause was present in the CREATE MODEL statement that created the model_name, then only the input columns present in the TRANSFORM clause can be present in the query_statement.

top_k_features

(Optional) top_k_features specifies how many top feature attribution pairs are generated per row of input data. The features are ranked by the absolute values of their attributions. top_k_features is of type INT64 and is part of the settings STRUCT.

By default, top_k_features is set to 5. If its value is greater than the number of features in the training data, the attributions of all features are returned.
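The ranking rule can be sketched in Python as follows: sort the (feature, attribution) pairs by absolute attribution in descending order and keep the first k. This is an illustration of the described behavior, not BigQuery's code; the function name is hypothetical, and how BigQuery orders ties is not stated here (Python's stable sort simply keeps input order).

```python
def top_k_attributions(
    attributions: dict[str, float], k: int = 5
) -> list[tuple[str, float]]:
    """Return the k (feature, attribution) pairs with the largest |attribution|.

    If k exceeds the number of features, every feature is returned.
    """
    ranked = sorted(attributions.items(), key=lambda item: abs(item[1]), reverse=True)
    return ranked[:k]


attrs = {"column1": 0.4, "column2": -1.2, "column3": 0.05, "column4": 0.9}
print(top_k_attributions(attrs, k=2))        # [('column2', -1.2), ('column4', 0.9)]
print(len(top_k_attributions(attrs, k=10)))  # 4
```

Note that a large negative attribution (column2) outranks a smaller positive one, because ranking uses the absolute value.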

threshold

(Optional) threshold can be set for binary classification models and is used as the cutoff between the two labels. Predictions above the threshold are treated as positive predictions. Predictions below the threshold are treated as negative predictions. Feature attributions are returned only for the predicted label. threshold is of type FLOAT64 and is part of the settings STRUCT.

The default value for threshold is 0.5 and can be set to any float between 0.0 and 1.0.
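A minimal Python sketch of the cutoff, assuming the value compared against the threshold is the predicted probability of the positive class. The function name and label strings are hypothetical, and the text does not say how a probability exactly equal to the threshold is handled, so treating it as negative is an assumption.

```python
def predicted_label(
    positive_probability: float,
    threshold: float = 0.5,
    labels: tuple[str, str] = ("negative", "positive"),
) -> str:
    """Map a positive-class probability to a label using a cutoff threshold."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0.0 and 1.0")
    # Assumption: a probability exactly equal to the threshold counts as negative.
    return labels[1] if positive_probability > threshold else labels[0]


print(predicted_label(0.65))                 # positive (default threshold 0.5)
print(predicted_label(0.65, threshold=0.7))  # negative
```

Raising the threshold, as in the second call, makes the model more conservative about predicting the positive label.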

num_integral_steps

(Optional) num_integral_steps specifies the number of steps to sample between the example being explained and its baseline for approximating the integral in integrated gradients attribution methods. Increasing the value improves the precision of feature attributions, but can be slower and more computationally expensive.

This option only applies to Deep Neural Network (DNN) models that use integrated gradients attribution methods. By default, num_integral_steps is set to 50.
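To show what the number of steps controls, the following Python sketch approximates integrated gradients for a toy function with a right Riemann sum over num_integral_steps points on the straight path from the baseline to the input. It is a generic illustration of the technique, not BigQuery's implementation: the Riemann variant, the toy model `f`, its gradient `grad_f`, and the zero baseline are all assumptions. The relative gap between f(x) − f(baseline) and the sum of the attributions shrinks as the step count grows.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def integrated_gradients(
    grad: Callable[[list[float]], list[float]],
    x: Vector,
    baseline: Vector,
    num_integral_steps: int = 50,
) -> list[float]:
    """Approximate integrated-gradients attributions with a right Riemann sum.

    grad(point) must return the gradient of the model output at `point`.
    """
    totals = [0.0] * len(x)
    for step in range(1, num_integral_steps + 1):
        alpha = step / num_integral_steps
        # Sample a point on the straight line from the baseline to x.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            totals[i] += g
    # Average gradient along the path, scaled by the distance from the baseline.
    return [(xi - b) * t / num_integral_steps for xi, b, t in zip(x, baseline, totals)]


# Hypothetical toy model f(x) = x0**2 + 3*x1 and its analytic gradient.
def f(v: Vector) -> float:
    return v[0] ** 2 + 3 * v[1]


def grad_f(v: list[float]) -> list[float]:
    return [2 * v[0], 3.0]


x, baseline = [2.0, 1.0], [0.0, 0.0]
for steps in (10, 50):
    attrs = integrated_gradients(grad_f, x, baseline, steps)
    gap = f(x) - f(baseline)
    # Relative error of the attribution sum; it shrinks as steps increase.
    error = abs(gap - sum(attrs)) / abs(gap)
    print(steps, [round(a, 4) for a in attrs], round(error, 4))
```

Each extra step costs one more gradient evaluation per instance, which is why higher values are slower.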

ML.EXPLAIN_PREDICT output

The ML.EXPLAIN_PREDICT function generates the following columns in addition to any passthrough columns:

  • predicted_<label_column_name>: Same as the prediction label in ML.PREDICT, this column is either the predicted value of the label for regression models or the predicted label class for classification models.
  • probability: The probability of the predicted label class. This field is only present for classification models.
  • top_feature_attributions: An ARRAY of STRUCTs containing the attributions of the top k features to the final prediction.
    • top_feature_attributions.feature: The feature name.
    • top_feature_attributions.attribution: Attribution of the feature to the final prediction.
  • baseline_prediction_value: The baseline that the feature attributions are measured against. Its value depends on the model type:
    • For linear models, the baseline_prediction_value is the intercept of the model.
    • For Deep Neural Network (DNN) models, the baseline_prediction_value is the mean across all numerical features and NULL for other types of features.
    • For boosted tree models, the baseline_prediction_value is equal to the bias term, which is the expected output of the model over the training dataset. See Tree SHAP documentation for more information.
  • prediction_value: The raw prediction value.
    • For regression, this is equal to the value of predicted_<label_column_name>.
    • For classification models, this is the logit value (that is, the log-odds) for the predicted class. Applying the softmax transformation to the logit values yields the predicted class probabilities.
  • approximation_error:
    • Exact attribution methods, such as Tree SHAP, satisfy
      \(\texttt{baseline\_prediction\_value} + \sum \texttt{feature\_attributions} = \texttt{prediction\_value}\)
      so there is no approximation error and this field is always 0. This applies to linear regression, logistic regression, and boosted tree models.
    • For approximate attribution methods, such as integrated gradients, this field is greater than 0. The approximation error is defined as:
      \(\frac{\lvert \texttt{prediction\_value} - \texttt{baseline\_prediction\_value} - \sum \texttt{feature\_attributions} \rvert}{\lvert \texttt{prediction\_value} - \texttt{baseline\_prediction\_value} \rvert}\)
      This applies to Deep Neural Network (DNN) models.
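The relationships among the output columns can be checked with a short Python sketch: a softmax over per-class logits, and the approximation error formula applied to an exact and an approximate case. All numbers are hypothetical, and the toy linear model assumes each attribution equals weight × feature value, which is consistent with the intercept being the baseline but is not a statement of how BigQuery computes linear attributions.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert per-class logits to probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def approximation_error(
    prediction_value: float,
    baseline_prediction_value: float,
    attributions: list[float],
) -> float:
    """Relative gap between the attribution sum and prediction - baseline.

    Undefined (ZeroDivisionError) when the prediction equals the baseline.
    """
    gap = prediction_value - baseline_prediction_value
    return abs(gap - sum(attributions)) / abs(gap)


# Exact case: toy linear model, attributions assumed to be weight * value,
# so intercept + sum(attributions) == prediction.
intercept, weights, values = 1.5, [2.0, -1.0], [3.0, 4.0]
linear_attrs = [w * v for w, v in zip(weights, values)]
prediction = intercept + sum(linear_attrs)
print(approximation_error(prediction, intercept, linear_attrs))  # 0.0

# Approximate case: hypothetical values with a small residual.
print(approximation_error(1.0, 0.2, [0.5, 0.28]))  # roughly 0.025

# Classification: softmax over hypothetical per-class logits.
probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])
```

For two classes, softmax([z, 0])[0] equals the sigmoid 1 / (1 + e^(-z)), which is why a single log-odds value is enough to recover a binary probability.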

ML.EXPLAIN_PREDICT examples

The following examples assume that your model and input table are in your default project.

Explaining a prediction generated by a linear regression model

The following query uses the ML.EXPLAIN_PREDICT function to explain a prediction for a linear regression model by generating the top 3 attributions.

Assume that a linear regression model stored in mydataset.mymodel was trained on the table mydataset.mytable, which has the following columns:

  • label
  • column1
  • column2
  • column3
  • column4
  • column5

SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(MODEL `mydataset.mymodel`,
    (
    SELECT
      label,
      column1,
      column2,
      column3,
      column4,
      column5
    FROM
      `mydataset.mytable`), STRUCT(3 AS top_k_features))

Explaining a prediction generated by a boosted tree binary classification model

The following query uses the ML.EXPLAIN_PREDICT function to explain a prediction generated by a boosted tree binary classification model. It generates the top 3 attributions and uses a custom threshold of 0.7.

Assume that a boosted tree binary classification model stored in mydataset.mymodel was trained on the table mydataset.mytable, which has the following columns:

  • label
  • column1
  • column2
  • column3
  • column4
  • column5

SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(MODEL `mydataset.mymodel`,
    (
    SELECT
      label,
      column1,
      column2,
      column3,
      column4,
      column5
    FROM
      `mydataset.mytable`), STRUCT(3 AS top_k_features, 0.7 AS threshold))

Explaining a prediction generated by a Deep Neural Network (DNN) classifier model

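The following query uses the ML.EXPLAIN_PREDICT function to explain a prediction generated by a DNN classifier model. It generates the top 3 attributions and uses 30 integral steps, instead of the default 50, to approximate the integrated gradients.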

Assume that a DNN classifier stored in mydataset.mymodel was trained on the table mydataset.mytable, which has the following columns:

  • label
  • column1
  • column2
  • column3
  • column4
  • column5

SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(MODEL `mydataset.mymodel`,
    (
    SELECT
      label,
      column1,
      column2,
      column3,
      column4,
      column5
    FROM
      `mydataset.mytable`), STRUCT(3 AS top_k_features, 30 AS num_integral_steps))