The ML.EXPLAIN_PREDICT function
The ML.EXPLAIN_PREDICT function generates a predicted value and a set of feature
attributions per instance of the input data. Feature attributions indicate how
much each feature in your model contributed to the final prediction for each given
instance.
ML.EXPLAIN_PREDICT can be viewed as an extended version of ML.PREDICT. For
information about Explainable AI, see Explainable AI overview.
For information about the supported model types of each SQL statement and function, and all of the supported SQL statements and functions for each model type, read the End-to-end user journey for each model.
ML.EXPLAIN_PREDICT syntax
ML.EXPLAIN_PREDICT(MODEL model_name, {TABLE table_name | (query_statement)} [, STRUCT<top_k_features INT64, threshold FLOAT64, num_integral_steps INT64> settings])
model_name
model_name is the name of the model used to generate explanations.
If a default project is not configured, then prepend the project ID to
the model name in the following format: `[project_id].[dataset].[model]`
(including the backticks). For example, `myproject.mydataset.mymodel`.
table_name
table_name is the name of the input table that contains the data to be evaluated.
If a default project is not configured, then prepend the project ID to
the table name in the following format: `[project_id].[dataset].[table]`
(including the backticks). For example, `myproject.mydataset.mytable`.
The input column names in the table must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.
If there are unused columns from the table, they are passed through to the output columns.
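As a minimal sketch of the naming rules above, the following call passes a fully qualified model and a fully qualified input table. The names `myproject`, `mydataset`, `mymodel`, and `mytable` are placeholders, not real resources.

```sql
-- Placeholder names: replace myproject, mydataset, mymodel, and mytable
-- with your own project, dataset, model, and table.
SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(
    MODEL `myproject.mydataset.mymodel`,
    TABLE `myproject.mydataset.mytable`);
```

Because no settings STRUCT is passed, the optional settings keep their default values.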
query_statement
The query_statement clause specifies the GoogleSQL query that is used to
generate the data to predict and explain the predicted results. See the
GoogleSQL Query Syntax page for the supported SQL syntax of the
query_statement clause.
The input column names from the query must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.
If there are unused columns from the query, they are passed through to output columns.
If a TRANSFORM clause was present in the CREATE MODEL statement that
created model_name, then only the input columns present in the
TRANSFORM clause can be present in the query_statement.
top_k_features
(Optional) top_k_features specifies how many top feature attribution pairs
are generated per row of input data. The features are ranked by the
absolute values of their attributions. top_k_features is of type INT64
and is part of the settings STRUCT.
By default, top_k_features is set to 5. If its value is greater than the
number of features in the training data, the attributions of all features are
returned.
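The ranking rule can be illustrated without a model. The sketch below uses invented attribution values (not output from a real model) and keeps the two features with the largest absolute attributions, which is the role top_k_features = 2 would play.

```sql
-- Toy attribution values invented for illustration only.
WITH attributions AS (
  SELECT 'column1' AS feature, 0.8 AS attribution UNION ALL
  SELECT 'column2', -2.5 UNION ALL
  SELECT 'column3', 0.1 UNION ALL
  SELECT 'column4', 1.7
)
SELECT
  feature,
  attribution
FROM
  attributions
ORDER BY
  ABS(attribution) DESC  -- rank by magnitude, ignoring sign
LIMIT 2;                 -- stands in for top_k_features = 2
```

The query returns column2 (-2.5) followed by column4 (1.7): a large negative attribution outranks a smaller positive one.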
threshold
(Optional) threshold can be set for binary classification models
and is used as the cutoff between the two labels. Predictions above the
threshold are treated as positive predictions. Predictions below the threshold
are treated as negative predictions. Feature attributions are returned only for the
predicted label. threshold is of type FLOAT64 and is part of the settings STRUCT.
The default value for threshold is 0.5. It can be set to any value between 0.0
and 1.0.
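To see how the cutoff changes the predicted label, consider the following sketch. The scores and the label names 'positive' and 'negative' are invented for illustration. The text above does not say how a score exactly equal to the threshold is handled, so the strict comparison used here is an assumption.

```sql
-- Toy positive-class scores invented for illustration only.
WITH scores AS (
  SELECT 0.65 AS positive_score UNION ALL
  SELECT 0.90
)
SELECT
  positive_score,
  -- Default cutoff of 0.5.
  IF(positive_score > 0.5, 'positive', 'negative') AS label_at_default,
  -- Custom cutoff of 0.7, as in STRUCT(0.7 AS threshold).
  IF(positive_score > 0.7, 'positive', 'negative') AS label_at_0_7
FROM
  scores;
```

The score 0.65 is labeled positive at the default cutoff but negative at 0.7, so the feature attributions returned for that row would explain a different label.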
num_integral_steps
(Optional) num_integral_steps specifies the number of steps to sample between
the example being explained and its baseline for approximating the integral in
integrated gradients attribution methods. Increasing the value improves the
precision of feature attributions, but can be slower and more computationally
expensive.
This option only applies to Deep Neural Network (DNN) models which use
integrated gradients attribution methods. By default, num_integral_steps is
set to 15.
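The trade-off between step count and precision can be sketched on a toy one-dimensional function. The query below is not how BigQuery ML computes attributions internally (the integration rule it uses is not described here). It assumes a hypothetical model f(x) = x^2, a baseline of 0, an input of 3, and a right-endpoint Riemann sum over the gradient 2x.

```sql
-- Toy setup (all values are assumptions): f(x) = x^2, baseline 0.0, input 3.0.
-- The exact attribution is f(3) - f(0) = 9.
WITH settings AS (
  SELECT num_steps
  FROM UNNEST([15, 30, 300]) AS num_steps
)
SELECT
  num_steps,
  -- (input - baseline) * average gradient along the path from baseline to input.
  (3.0 - 0.0) * AVG(2 * (0.0 + (k / num_steps) * (3.0 - 0.0))) AS approx_attribution,
  9.0 AS exact_attribution
FROM
  settings,
  UNNEST(GENERATE_ARRAY(1, num_steps)) AS k
GROUP BY
  num_steps
ORDER BY
  num_steps;
```

The estimate moves toward the exact value of 9 as num_steps grows, at the cost of evaluating the gradient at more points.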
ML.EXPLAIN_PREDICT output
The ML.EXPLAIN_PREDICT function generates the following columns in addition to
any passthrough columns:
- predicted_<label_column_name>: Same as the prediction label in ML.PREDICT, this column is either the predicted value of the label for regression models or the predicted label class for classification models.
- probability: The probability of the predicted label class. This field is only present for classification models.
- top_feature_attributions: An ARRAY of STRUCTs containing the attributions of the top k features to the final prediction.
  - top_feature_attributions.feature: The feature name.
  - top_feature_attributions.attribution: Attribution of the feature to the final prediction.
- baseline_prediction_value presents the following:
  - For linear models, the baseline_prediction_value is the intercept of the model.
  - For Deep Neural Network (DNN) models, the baseline_prediction_value is the mean across all numerical features and NULL for other types of features.
  - For boosted tree and random forest models, the baseline_prediction_value is equal to the bias term, which is the expected output of the model over the training dataset. See Tree SHAP documentation for more information.
- prediction_value: The raw prediction value.
- approximation_error:
  - The exact attribution methods like Tree SHAP have the property such that
    \(\texttt{baseline_prediction_value} + \sum{\texttt{feature_attributions}} = \texttt{prediction_value}\)
    Therefore there is no approximation error and this field is always 0. This applies to linear & logistic regression, boosted tree and random forest models.
  - In approximated attribution methods like Integrated Gradients, this field is greater than 0. The approximation error is defined as:
    \(\frac{|\texttt{prediction_value} - \texttt{baseline_prediction_value} - \sum{\texttt{feature_attributions}}|}{|\texttt{prediction_value} - \texttt{baseline_prediction_value}|}\)
    This applies to Deep Neural Network (DNN) models.
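The approximation error formula can be checked by hand on a single row. The sketch below uses invented values shaped like the output columns (not real model output) and assumes the attribution array holds every feature, that is, top_k_features did not truncate it.

```sql
-- One invented output row; the values are for illustration only.
WITH explained AS (
  SELECT
    10.0 AS prediction_value,
    4.0 AS baseline_prediction_value,
    [STRUCT('column1' AS feature, 3.5 AS attribution),
     STRUCT('column2' AS feature, 2.0 AS attribution)] AS top_feature_attributions
)
SELECT
  prediction_value,
  baseline_prediction_value,
  -- Sum of all feature attributions for the row.
  (SELECT SUM(a.attribution)
   FROM UNNEST(top_feature_attributions) AS a) AS attribution_sum,
  -- |prediction - baseline - sum| / |prediction - baseline|
  SAFE_DIVIDE(
    ABS(prediction_value - baseline_prediction_value
        - (SELECT SUM(a.attribution)
           FROM UNNEST(top_feature_attributions) AS a)),
    ABS(prediction_value - baseline_prediction_value)) AS approximation_error
FROM
  explained;
```

Here the attributions sum to 5.5 while the prediction sits 6.0 above the baseline, so the error is 0.5 / 6.0, or roughly 0.083. For an exact method the sum would equal 6.0 and the error would be 0.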
ML.EXPLAIN_PREDICT examples
The following examples assume that your model and input table are in your default project.
Explain a prediction generated by a linear regression model
The following query uses the ML.EXPLAIN_PREDICT function to explain a prediction
for a linear regression model by generating the top 3 attributions.
Assume a linear regression model stored in mydataset.mymodel was trained with the
table mydataset.mytable, which has the following columns:
label
column1
column2
column3
column4
column5
SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(
    MODEL `mydataset.mymodel`,
    (
      SELECT label, column1, column2, column3, column4, column5
      FROM `mydataset.mytable`
    ),
    STRUCT(3 AS top_k_features))
Explain a prediction generated by a boosted tree or a random forest binary classification model
The following query uses the ML.EXPLAIN_PREDICT function to explain a
prediction generated by a boosted tree or a random forest binary classification
model. It generates the top 3 attributions with a custom threshold.
Assume a boosted tree or a random forest binary classification model stored in
mydataset.mymodel was trained with the table mydataset.mytable, which has the
following columns:
label
column1
column2
column3
column4
column5
SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(
    MODEL `mydataset.mymodel`,
    (
      SELECT label, column1, column2, column3, column4, column5
      FROM `mydataset.mytable`
    ),
    STRUCT(3 AS top_k_features, 0.7 AS threshold))
Explain a prediction generated by a Deep Neural Network (DNN) classifier model
The following query uses the ML.EXPLAIN_PREDICT function to explain a prediction
generated by a DNN classifier model. It generates the top 3 attributions using
30 integral steps.
Assume a DNN classifier stored in mydataset.mymodel was trained with the
table mydataset.mytable, which has the following columns:
label
column1
column2
column3
column4
column5
SELECT
  *
FROM
  ML.EXPLAIN_PREDICT(
    MODEL `mydataset.mymodel`,
    (
      SELECT label, column1, column2, column3, column4, column5
      FROM `mydataset.mytable`
    ),
    STRUCT(3 AS top_k_features, 30 AS num_integral_steps))
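The queries above return the attributions as an array per input row. As a hedged sketch, the array can be flattened to one row per feature with UNNEST. The column name predicted_label assumes the label column is named label, as in the examples above, and attr is an alias introduced here for illustration.

```sql
-- Assumes the same placeholder model and table as the examples above.
SELECT
  predicted_label,
  prediction_value,
  baseline_prediction_value,
  attr.feature,
  attr.attribution
FROM
  ML.EXPLAIN_PREDICT(
    MODEL `mydataset.mymodel`,
    TABLE `mydataset.mytable`,
    STRUCT(3 AS top_k_features)),
  UNNEST(top_feature_attributions) AS attr  -- one output row per attributed feature
ORDER BY
  ABS(attr.attribution) DESC;
```

With top_k_features set to 3, each input row contributes up to three rows to the flattened result.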