The ML.PREDICT function

The ML.PREDICT function predicts outcomes using a model. You can run prediction during model creation, after model creation, or after a failure, as long as at least one iteration has finished. ML.PREDICT always uses the model weights from the last successful iteration.

The output of the ML.PREDICT function has as many rows as the input table, and it includes all columns from the input table and all output columns from the model. The output column names for the model are predicted_<label_column_name> and (for logistic regression models) predicted_<label_column_name>_probs. In both columns, label_column_name is the name of the input label column used during training.

  • For binary logistic regression models:

    • The predicted_<label_column_name>_probs output column is an array of STRUCTs of type [<label, prob>] that contains the predicted probability of each label.
    • The predicted_<label_column_name> output column is one of the two input labels, depending on which label has the higher predicted probability.
  • For multiclass logistic regression models:

    • The predicted_<label_column_name>_probs output column is the probability for each class label calculated using a softmax function.
    • The predicted_<label_column_name> output column is the label with the highest predicted probability score.
  • For linear regression models:

    • The predicted_<label_column_name> output column is the predicted value of the label.
  • For k-means models:

    • The output includes the columns centroid_id and nearest_centroids_distance. The nearest_centroids_distance column contains an ARRAY of STRUCTs that holds the distances to the nearest k clusters, where k is the lesser of num_clusters and 5. If the model was created with the option standardize_features set to TRUE, the model computes these distances using standardized features; otherwise, it computes them on non-standardized features.
  • For TensorFlow models:

    • The input must be convertible to the type expected by the model.
    • The output is the output of the TensorFlow model's predict method.
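
The output shapes listed above can be illustrated with a small plain-Python sketch. It is not BigQuery ML's implementation: the function names, the STRUCT field names for k-means ("centroid_id", "distance"), and the use of Euclidean distance are assumptions made for illustration only.

```python
import math


def softmax(scores):
    """Turn raw per-class scores into probabilities that sum to 1."""
    highest = max(scores.values())  # subtract the max for numerical stability
    exps = {label: math.exp(s - highest) for label, s in scores.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}


def predict_multiclass(scores, label_column_name="label"):
    """Build the two output columns described for multiclass models."""
    probs = softmax(scores)
    best = max(probs, key=probs.get)  # label with the highest probability
    return {
        f"predicted_{label_column_name}": best,
        f"predicted_{label_column_name}_probs": [
            {"label": lbl, "prob": p} for lbl, p in probs.items()
        ],
    }


def nearest_centroids(point, centroids, max_k=5):
    """Return the closest centroid and the distances to the nearest k centroids.

    k is the lesser of the number of clusters and 5, as described above.
    Euclidean distance is an assumption of this sketch.
    """
    ranked = sorted((math.dist(point, c), cid) for cid, c in centroids.items())
    k = min(len(centroids), max_k)
    nearest = [{"centroid_id": cid, "distance": d} for d, cid in ranked[:k]]
    return {
        "centroid_id": nearest[0]["centroid_id"],
        "nearest_centroids_distance": nearest,
    }
```

For example, predict_multiclass({"a": 1.0, "b": 3.0, "c": 2.0}) selects "b" as the predicted label, and the probabilities in the _probs array sum to 1.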

ML.PREDICT syntax

ML.PREDICT(MODEL model_name,
          {TABLE table_name | (query_statement)}
          [, STRUCT<threshold FLOAT64> settings])

model_name

model_name is the name of the model you're using for prediction. If you do not have a default project configured, prepend the project ID to the model name in the following format: `[PROJECT_ID].[DATASET].[MODEL]` (including the backticks); for example, `myproject.mydataset.mymodel`.

table_name

table_name is the name of the input table that contains the data to predict on. If you do not have a default project configured, prepend the project ID to the table name in the following format: `[PROJECT_ID].[DATASET].[TABLE]` (including the backticks); for example, `myproject.mydataset.mytable`.

The input column names in the table must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.

If there are unused columns from the table, they will be passed through to output columns.

query_statement

The query_statement clause specifies the standard SQL query that is used to generate the data to predict on. See the Standard SQL Query Syntax page for the supported SQL syntax of the query_statement clause.

The input column names from the query must contain the column names in the model, and their types should be compatible according to BigQuery implicit coercion rules.

If there are unused columns from the query, they will be passed through to output columns.

If the TRANSFORM clause was present in the CREATE MODEL statement that created model_name, then only the input columns present in the TRANSFORM clause must be present in query_statement.
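
The column matching and pass-through behavior described above can be sketched in plain Python. This is only an illustration under stated assumptions: predict_rows, feature_columns, and predict_fn are hypothetical names, rows are modeled as dictionaries, and placing the predicted column first follows the column order shown in the examples on this page.

```python
def predict_rows(rows, feature_columns, predict_fn, label_column_name="label"):
    """Attach a prediction to each input row and pass every input column through.

    rows            -- list of dicts, one per input row (hypothetical representation)
    feature_columns -- column names the model expects
    predict_fn      -- hypothetical callable mapping a feature dict to a prediction
    """
    output = []
    for row in rows:
        # The input must contain every column the model was trained on.
        missing = [c for c in feature_columns if c not in row]
        if missing:
            raise ValueError(f"input is missing model columns: {missing}")
        features = {c: row[c] for c in feature_columns}
        # One output row per input row: the prediction plus all input columns,
        # including columns the model does not use.
        out_row = {f"predicted_{label_column_name}": predict_fn(features)}
        out_row.update(row)
        output.append(out_row)
    return output
```

A row such as {"column1": 1, "column2": 2, "extra": "x"} keeps its extra column in the output even though the model ignores it.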

threshold

(Optional) threshold is a custom threshold for your binary logistic regression model and is used as the cutoff between the two labels. Predictions above the threshold are treated as positive predictions. Predictions below the threshold are treated as negative predictions. The threshold value is of type FLOAT64 and is part of the settings STRUCT. The default value is 0.5.
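
The cutoff logic can be sketched as follows. The function and parameter names are hypothetical, and the handling of a probability exactly equal to the threshold is an assumption, since the description above only covers values above and below it.

```python
def apply_threshold(positive_prob, positive_label, negative_label, threshold=0.5):
    """Pick a label for a binary prediction given the positive-class probability.

    Probabilities above the threshold map to the positive label and
    probabilities below it map to the negative label. Treating a tie as
    negative is an assumption of this sketch.
    """
    if positive_prob > threshold:
        return positive_label
    return negative_label
```

With the default threshold, a positive-class probability of 0.52 yields the positive label; raising the threshold to 0.55 turns the same prediction into the negative label.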

Imputation

In statistics, imputation is used to replace missing data with substituted values. When you train a model in BigQuery ML, NULL values are treated as missing data. When you predict outcomes in BigQuery ML, missing values can occur when BigQuery ML encounters a NULL value or a previously unseen value. BigQuery ML handles missing data based on whether the column is numeric, one-hot encoded, or a timestamp.

Numerical columns

In both training and prediction, NULL values in numeric columns are replaced with the mean value calculated from the feature column in the original input data.
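
As a minimal sketch of mean imputation, assuming NULL is modeled as Python's None and using hypothetical helper names, the mean is computed once from the original input data and then reused at prediction time:

```python
def fit_mean(training_values):
    """Compute the mean of the non-missing values in the original input data."""
    observed = [v for v in training_values if v is not None]
    return sum(observed) / len(observed)


def impute_numeric(values, training_mean):
    """Replace missing values (None) with the mean learned from the original data."""
    return [training_mean if v is None else v for v in values]
```

Note that the prediction-time input does not change the mean: impute_numeric always substitutes the value derived from the original data.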

One-hot/Multi-hot encoded columns

In both training and prediction, NULL values in the encoded columns are mapped to an additional category that is added to the data. Previously unseen data is assigned a weight of 0 during prediction.
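
The following sketch illustrates one way to read this behavior. It is not BigQuery ML's implementation: the "__NULL__" marker and the helper names are hypothetical, and representing "a weight of 0" as an all-zero vector (so the unseen value contributes nothing to a linear model) is an interpretation.

```python
NULL_CATEGORY = "__NULL__"  # hypothetical marker for the additional NULL category


def fit_vocabulary(training_values):
    """Collect the categories seen in training, plus a dedicated NULL category."""
    vocab = []
    for value in training_values:
        key = NULL_CATEGORY if value is None else value
        if key not in vocab:
            vocab.append(key)
    if NULL_CATEGORY not in vocab:
        vocab.append(NULL_CATEGORY)
    return vocab


def one_hot(value, vocab):
    """Encode a value against the training vocabulary.

    None maps to the NULL category. A value that was never seen in training
    matches no category, so every position is 0.
    """
    key = NULL_CATEGORY if value is None else value
    return [1 if key == category else 0 for category in vocab]
```

With a vocabulary built from ["red", "blue", None, "red"], a NULL input activates the extra category, while an unseen value such as "green" produces all zeros.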

Timestamp columns

TIMESTAMP columns use a mixture of imputation methods from both standardized and one-hot encoded columns. For the generated unix time column, BigQuery ML replaces values with the mean unix time across the original columns. For other generated values, BigQuery ML assigns them to the respective NULL category for each extracted feature.

STRUCT columns

In both training and prediction, each field of the STRUCT is imputed according to its type.

ML.PREDICT examples

The following examples assume your model and input table are in your default project.

Predicting an outcome

The following query uses the ML.PREDICT function to predict an outcome. The query returns these columns:

  • predicted_label
  • label
  • column1
  • column2

SELECT
  *
FROM
  ML.PREDICT(MODEL `mydataset.mymodel`,
    (
    SELECT
      label,
      column1,
      column2
    FROM
      `mydataset.mytable`))

Comparing predictions from two different models

In this example, the following query is used to create the first model.

CREATE MODEL
  `mydataset.mymodel1`
OPTIONS
  (model_type='linear_reg',
    input_label_cols=['label']
  ) AS
SELECT
  label,
  input_column1
FROM
  `mydataset.mytable`

The following query is used to create the second model.

CREATE MODEL
  `mydataset.mymodel2`
OPTIONS
  (model_type='linear_reg',
    input_label_cols=['label']
  ) AS
SELECT
  label,
  input_column2
FROM
  `mydataset.mytable`

The following query uses the ML.PREDICT function to compare the output of the two models.

SELECT
  label,
  predicted_label1,
  predicted_label AS predicted_label2
FROM
  ML.PREDICT(MODEL `mydataset.mymodel2`,
    (
    SELECT
      * EXCEPT (predicted_label),
      predicted_label AS predicted_label1
    FROM
      ML.PREDICT(MODEL `mydataset.mymodel1`,
        TABLE `mydataset.mytable`)))

Specifying a custom threshold

The following query uses the ML.PREDICT function by specifying input data and a custom threshold of 0.55.

SELECT
  *
FROM
  ML.PREDICT(MODEL `mydataset.mymodel`,
    (
    SELECT
      custom_label,
      column1,
      column2
    FROM
      `mydataset.mytable`),
    STRUCT(0.55 AS threshold))

Predicting an outcome with an imported TensorFlow model

The following query uses the ML.PREDICT function to predict outcomes using an imported TensorFlow model. The input_data table contains inputs in the schema expected by my_model. See the CREATE MODEL statement for TensorFlow models for more information.

SELECT *
FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model,
               (SELECT * FROM input_data))

Predicting an outcome with a model trained with the TRANSFORM clause

The following query trains a model using the TRANSFORM clause:

CREATE MODEL m
  TRANSFORM(f1 + f2 as c, label)
  OPTIONS(...)
AS SELECT f1, f2, f3, label FROM t;

Because the f3 column doesn't appear in the TRANSFORM clause, the following prediction query omits that column in the query_statement:

SELECT * FROM ML.PREDICT(
  MODEL m, (SELECT f1, f2 FROM t1));

If f3 is provided in the SELECT statement, it isn't used for calculating predictions but is instead passed through for use in the rest of the SQL statement.
