ML.WEIGHTS function
The ML.WEIGHTS function lets you see the underlying weights that a model uses during prediction. This function applies to linear and logistic regression models and to matrix factorization models.
For information about model weights support in BigQuery ML, see Model weights overview.
For information about supported model types of each SQL statement and function, and all supported SQL statements and functions for each model type, read End-to-end user journey for each model.
ML.WEIGHTS syntax
In the following example syntax, standardize is an optional parameter that determines whether the model weights should be standardized to assume that all features have a mean of zero and a standard deviation of one. Standardizing the weights makes their absolute magnitudes directly comparable to each other. The default value is false. The value that is supplied must be the only field in a STRUCT.
ML.WEIGHTS(MODEL `project-id.dataset.model` [, STRUCT(<T> as standardize)])
Replace the following:
project-id: Your project ID.
dataset: The BigQuery dataset that contains the model.
model: The name of the model.
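As a concrete instance of the syntax above, a call with every placeholder filled in might look like the following sketch. The names my-project, mydataset, and mymodel are hypothetical; substitute your own.

```sql
-- Hypothetical names: my-project, mydataset, mymodel.
-- FALSE matches the documented default, so the STRUCT argument
-- could be omitted here without changing the result.
SELECT
  *
FROM
  ML.WEIGHTS(MODEL `my-project.mydataset.mymodel`,
    STRUCT(FALSE AS standardize));
```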
ML.WEIGHTS output
ML.WEIGHTS returns different output columns for different model types.
Linear and logistic regression models
For linear and logistic regression models, ML.WEIGHTS returns the following columns:
processed_input: The name of the model feature input. The value of this column matches the name of the column in the SELECT statement used during training.
weight: The weight of each feature. For numerical columns, weight contains a value and the category_weights column is NULL. For non-numeric columns that are converted to one-hot encoding, the weight column is NULL and the category_weights column is an ARRAY of category names and weights for each category.
category_weights.category: The category name if the input column is non-numeric.
category_weights.weight: The category's weight if the input column is non-numeric.
class_label: For multiclass models, class_label is the label for a given weight. The output includes one row per <class_label, processed_input> combination.
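Because numeric features carry their value in weight while one-hot encoded features carry theirs in category_weights, it can be convenient to flatten both into a single column. The following is a minimal sketch of one way to do that, assuming a hypothetical regression or binary logistic model named mymodel in mydataset. For a multiclass model you would likely also want to select class_label.

```sql
-- Hypothetical model: mydataset.mymodel.
WITH weights AS (
  SELECT * FROM ML.WEIGHTS(MODEL `mydataset.mymodel`)
)
SELECT
  w.processed_input,
  cw.category,                              -- NULL for numeric features
  COALESCE(cw.weight, w.weight) AS weight   -- category weight if present, else feature weight
FROM
  weights AS w
-- LEFT JOIN keeps rows whose category_weights array is NULL (numeric features).
LEFT JOIN
  UNNEST(w.category_weights) AS cw
ORDER BY
  w.processed_input, cw.category;
```

The LEFT JOIN matters here: a plain comma join or CROSS JOIN against UNNEST would drop every numeric feature, because its category_weights array is NULL.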
If the TRANSFORM clause was present in the CREATE MODEL statement that created model, ML.WEIGHTS outputs the weights of the TRANSFORM output features. The weights are denormalized by default, with the option to get normalized weights, exactly as for models that are created without TRANSFORM.
Matrix factorization models
For matrix factorization models, ML.WEIGHTS returns the following columns:
processed_input: The name of the user or item column. The value of this column matches the name of the column in the SELECT statement used during training.
feature: The names of the specific users or items used during training.
factor_weights: An ARRAY of factors and weights for each factor.
factor_weights.factor: A latent factor from training. An INT64 from 1 to NUM_FACTORS.
factor_weights.weight: The weight of the respective factor and feature.
intercept: The intercept or bias term for a feature.
Finally, the output contains one additional row that holds the global__intercept__ calculated from the input data. This row has a NULL value for processed_input and factor_weights. For implicit feedback models, global__intercept__ is always 0.
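The nested factor_weights column described above can be expanded into one row per user or item and latent factor. The following is a minimal sketch, assuming a hypothetical matrix factorization model named my_mf_model in mydataset.

```sql
-- Hypothetical model: mydataset.my_mf_model.
WITH weights AS (
  SELECT * FROM ML.WEIGHTS(MODEL `mydataset.my_mf_model`)
)
SELECT
  w.processed_input,   -- user or item column name
  w.feature,           -- specific user or item
  fw.factor,           -- latent factor index
  fw.weight,           -- weight of that factor for this feature
  w.intercept
FROM
  weights AS w
-- CROSS JOIN drops rows whose factor_weights is NULL,
-- which excludes the global intercept row.
CROSS JOIN
  UNNEST(w.factor_weights) AS fw
ORDER BY
  w.processed_input, w.feature, fw.factor;
```

To keep the global intercept row in the result, a LEFT JOIN against UNNEST could be used instead.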
ML.WEIGHTS examples
ML.WEIGHTS without standardization
The following example retrieves weight information from mymodel in mydataset. The dataset is in your default project.
The query returns the weights associated with each one-hot encoded category for the input column input_col.
SELECT category, weight FROM UNNEST(( SELECT category_weights FROM ML.WEIGHTS(MODEL `mydataset.mymodel`) WHERE processed_input = 'input_col'))
This query uses the UNNEST operator because the category_weights column is a nested repeated column.
ML.WEIGHTS with standardization
The following example retrieves weight information from mymodel in mydataset. The dataset is in your default project.
The query retrieves standardized weights, which assume all features have a mean of zero and a standard deviation of one.
SELECT * FROM ML.WEIGHTS(MODEL `mydataset.mymodel`, STRUCT(true AS standardize))
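Since standardized weights are meant to be compared by absolute magnitude, one possible follow-up is to rank the numeric features that way. The following is a sketch only, reusing the hypothetical mymodel from the example above. It skips one-hot encoded features, whose weight column is NULL.

```sql
-- Hypothetical model: mydataset.mymodel.
SELECT
  processed_input,
  weight,
  ABS(weight) AS abs_weight
FROM
  ML.WEIGHTS(MODEL `mydataset.mymodel`, STRUCT(TRUE AS standardize))
WHERE
  weight IS NOT NULL      -- numeric features only
ORDER BY
  abs_weight DESC;        -- largest standardized magnitude first
```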