Get TabNet batch predictions

This page shows you how to make a batch prediction request to your trained classification or regression model using the Google Cloud console or the Vertex AI API.

A batch prediction request is an asynchronous request (as opposed to online prediction, which is a synchronous request). You request batch predictions directly from the model resource without needing to deploy the model to an endpoint. For tabular data, use batch predictions when you don't require an immediate response and want to process accumulated data by using a single request.

To make a batch prediction request, you specify an input source and an output destination where Vertex AI stores the prediction results.

Before you begin

Before you can make a batch prediction request, you must first train a model.

Input data

The input data for batch prediction requests is the data that your model uses to make predictions. For classification or regression models, you can provide input data in one of two formats:

  • BigQuery tables
  • CSV objects in Cloud Storage

We recommend that you use the same format for your input data as you used for training the model. For example, if you trained your model using data in BigQuery, it is best to use a BigQuery table as the input for your batch prediction. Because Vertex AI treats all CSV input fields as strings, mixing training and input data formats may cause errors.

Your data source must contain tabular data that includes all of the columns, in any order, that were used to train the model. You can include columns that were not in the training data, or that were in the training data but excluded from use for training. These extra columns are included in the output but don't affect the prediction results.
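
As a minimal sketch of the column requirement above, the following standard-library Python checks a CSV header against the list of training columns before you submit a job. The function name `missing_training_columns` and the sample column names are hypothetical; they are not part of Vertex AI.

```python
import csv
import io


def missing_training_columns(csv_text, training_columns):
    """Return the training columns absent from the CSV header, in sorted order."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])
    return sorted(set(training_columns) - set(header))


# Hypothetical input: column order differs from training and one extra column is present.
sample = '"age","income","customer_id"\n"42","55000","c-001"\n'
print(missing_training_columns(sample, ["income", "age"]))            # []
print(missing_training_columns(sample, ["income", "age", "tenure"]))  # ['tenure']
```

Column order and extra columns do not matter to the check, which mirrors the rule that only the presence of every training column is required.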

Input data requirements

BigQuery table

If you choose a BigQuery table as the input, you must ensure the following:

  • BigQuery data source tables must be no larger than 100 GB.
  • If the table is in a different project, you must grant the BigQuery Data Editor role to the Vertex AI service account in that project.

CSV file

If you choose a CSV object in Cloud Storage as the input, you must ensure the following:

  • The data source must begin with a header row with the column names.
  • Each data source object must not be larger than 10 GB. You can include multiple files, up to a maximum amount of 100 GB.
  • If the Cloud Storage bucket is in a different project, you must grant the Storage Object Creator role to the Vertex AI service account in that project.
  • You must enclose all strings in double quotation marks (").
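
One way to produce a CSV object that starts with a header row and encloses every string in double quotation marks is Python's `csv` module with `csv.QUOTE_NONNUMERIC`. This is a sketch only: the helper name `to_quoted_csv` and the sample columns are assumptions, and it does not check the size limits listed above.

```python
import csv
import io


def to_quoted_csv(column_names, rows):
    """Serialize rows to CSV text with a header row and every string field double-quoted."""
    buffer = io.StringIO()
    # QUOTE_NONNUMERIC quotes all string fields and leaves ints and floats bare.
    writer = csv.writer(buffer, quoting=csv.QUOTE_NONNUMERIC, lineterminator="\n")
    writer.writerow(column_names)
    writer.writerows(rows)
    return buffer.getvalue()


# Hypothetical data with one string column and one numeric column.
print(to_quoted_csv(["city", "rooms"], [["Oslo", 3], ["Lima", 2]]))
```

If you prefer every field quoted regardless of type, `csv.QUOTE_ALL` is the alternative setting.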

Output format

The output format of your batch prediction request doesn't need to be the same as the format that you used for the input. For example, if you used a BigQuery table as the input, you can output the results to a CSV object in Cloud Storage.

Make a batch prediction request to your model

To make batch prediction requests, you can use the Google Cloud console or the Vertex AI API. The input data source can be CSV objects stored in a Cloud Storage bucket or BigQuery tables. Depending on the amount of data that you submit as input, a batch prediction task can take some time to complete.

Google Cloud console

Use the Google Cloud console to request a batch prediction.

  1. In the Google Cloud console, in the Vertex AI section, go to the Batch predictions page.

  2. Click Create to open the New batch prediction window.

  3. For Define your batch prediction, complete the following steps:

    1. Enter a name for the batch prediction.
    2. For Model name, select the name of the model to use for this batch prediction.
    3. For Version, select the model version to use for this batch prediction.
    4. For Select source, select whether your source input data is a CSV file on Cloud Storage or a table in BigQuery.
      • For CSV files, specify the Cloud Storage location where your CSV input file is located.
      • For BigQuery tables, specify the project ID where the table is located, the BigQuery dataset ID, and the BigQuery table or view ID.
    5. For the Output, select CSV or BigQuery.
      • For CSV, specify the Cloud Storage bucket where Vertex AI stores your output.
      • For BigQuery, you can specify a project ID or an existing dataset:
        • To specify the project ID, enter the project ID in the Google Cloud project ID field. Vertex AI creates a new output dataset for you.
        • To specify an existing dataset, enter its BigQuery path in the Google Cloud project ID field, such as bq://projectid.datasetid.
  4. Optional: Model Monitoring analysis is available in Preview. See the Prerequisites for adding skew detection configuration to your batch prediction job.

    1. Click to toggle on Enable model monitoring for this batch prediction.

    2. Select a Training data source.

    3. Enter the Training data path for the training data source that you selected.

  5. Click Create.

API: BigQuery

REST

You use the batchPredictionJobs.create method to request a batch prediction.

Before using any of the request data, make the following replacements:

  • LOCATION_ID: Region where Model is stored and batch prediction job is executed. For example, us-central1.
  • PROJECT_ID: Your project ID
  • BATCH_JOB_NAME: Display name for the batch job
  • MODEL_ID: The ID for the model to use for making predictions
  • INPUT_URI: Reference to the BigQuery data source. In the form:
    bq://bqprojectId.bqDatasetId.bqTableId
    
  • OUTPUT_URI: Reference to the BigQuery destination (where the predictions will be written). Specify the project ID and, optionally, an existing dataset ID. If you specify just the project ID, Vertex AI creates a new output dataset for you. Use the following form:
    bq://bqprojectId.bqDatasetId
    
  • MACHINE_TYPE: The machine resources to be used for this batch prediction job. Learn more.
  • STARTING_REPLICA_COUNT: The starting number of nodes for this batch prediction job. The node count can be increased or decreased as required by load, up to the maximum number of nodes, but will never fall below this number.
  • MAX_REPLICA_COUNT: The maximum number of nodes for this batch prediction job. The node count can be increased or decreased as required by load, but will never exceed the maximum. Optional, defaults to 10.

HTTP method and URL:

POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs

Request JSON body:

{
  "displayName": "BATCH_JOB_NAME",
  "model": "MODEL_ID",
  "inputConfig": {
    "instancesFormat": "bigquery",
    "bigquerySource": {
      "inputUri": "INPUT_URI"
    }
  },
  "outputConfig": {
    "predictionsFormat": "bigquery",
    "bigqueryDestination": {
      "outputUri": "OUTPUT_URI"
    }
  },
  "dedicatedResources": {
    "machineSpec": {
      "machineType": "MACHINE_TYPE",
      "acceleratorCount": "0"
    },
    "startingReplicaCount": STARTING_REPLICA_COUNT,
    "maxReplicaCount": MAX_REPLICA_COUNT
  }
}
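
If you generate the request body programmatically, a small builder can catch malformed values before the request is sent. The sketch below mirrors the JSON body above; the function name `bigquery_batch_request` and its validation rules are illustrative assumptions, not part of the Vertex AI API.

```python
import json


def bigquery_batch_request(display_name, model, input_uri, output_uri,
                           machine_type, starting_replicas, max_replicas=10):
    """Build the batchPredictionJobs.create body for BigQuery input and output."""
    for uri in (input_uri, output_uri):
        if not uri.startswith("bq://"):
            raise ValueError(f"expected a bq:// URI, got {uri!r}")
    if starting_replicas > max_replicas:
        raise ValueError("starting replica count exceeds the maximum")
    return {
        "displayName": display_name,
        "model": model,
        "inputConfig": {
            "instancesFormat": "bigquery",
            "bigquerySource": {"inputUri": input_uri},
        },
        "outputConfig": {
            "predictionsFormat": "bigquery",
            "bigqueryDestination": {"outputUri": output_uri},
        },
        "dedicatedResources": {
            "machineSpec": {"machineType": machine_type, "acceleratorCount": "0"},
            "startingReplicaCount": starting_replicas,
            "maxReplicaCount": max_replicas,
        },
    }


# Placeholder values taken from the replacement list above.
body = bigquery_batch_request(
    "BATCH_JOB_NAME", "MODEL_ID",
    "bq://bqprojectId.bqDatasetId.bqTableId", "bq://bqprojectId",
    "n1-standard-32", 2, 6)
print(json.dumps(body, indent=2))
```

Writing the printed JSON to request.json gives a body you can send with the commands that follow.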

To send your request, choose one of these options:

curl

Save the request body in a file called request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs"

PowerShell

Save the request body in a file called request.json, and execute the following command:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs" | Select-Object -Expand Content

You should receive a JSON response similar to the following:

{
  "name": "projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs/67890",
  "displayName": "batch_job_1 202005291958",
  "model": "projects/12345/locations/us-central1/models/5678",
  "state": "JOB_STATE_PENDING",
  "inputConfig": {
    "instancesFormat": "bigquery",
    "bigquerySource": {
      "inputUri": "INPUT_URI"
    }
  },
  "outputConfig": {
    "predictionsFormat": "bigquery",
    "bigqueryDestination": {
        "outputUri": "bq://12345"
    }
  },
  "dedicatedResources": {
    "machineSpec": {
      "machineType": "n1-standard-32",
      "acceleratorCount": "0"
    },
    "startingReplicaCount": 2,
    "maxReplicaCount": 6
  },
  "manualBatchTuningParameters": {
    "batchSize": 4
  },
  "generateExplanation": false,
  "outputInfo": {
    "bigqueryOutputDataset": "bq://12345.reg_model_2020_10_02_06_04"
  },
  "createTime": "2020-09-30T02:58:44.341643Z",
  "updateTime": "2020-09-30T02:58:44.341643Z"
}
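
Because the job runs asynchronously, the response only tells you that the job was accepted. The following sketch reads the job ID and state out of a response like the one above. The helper name `summarize_job` is hypothetical, and treating only these three JobState values as finished is a simplifying assumption; the enum contains other values.

```python
import json

# Assumption: these JobState values are treated as "finished" for this sketch.
TERMINAL_STATES = {"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"}


def summarize_job(response_text):
    """Return (job_id, state, is_finished) from a batch prediction job response."""
    job = json.loads(response_text)
    job_id = job["name"].rsplit("/", 1)[-1]  # last segment of the resource name
    return job_id, job["state"], job["state"] in TERMINAL_STATES


sample_response = (
    '{"name": "projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs/67890",'
    ' "state": "JOB_STATE_PENDING"}'
)
print(summarize_job(sample_response))  # ('67890', 'JOB_STATE_PENDING', False)
```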

Java

To learn how to install and use the client library for Vertex AI, see Vertex AI client libraries. For more information, see the Vertex AI Java API reference documentation.

In the following sample, replace INSTANCES_FORMAT and PREDICTIONS_FORMAT with `bigquery`. To learn how to replace the other placeholders, see the REST tab of this section.

import com.google.cloud.aiplatform.v1.BatchPredictionJob;
import com.google.cloud.aiplatform.v1.BigQueryDestination;
import com.google.cloud.aiplatform.v1.BigQuerySource;
import com.google.cloud.aiplatform.v1.JobServiceClient;
import com.google.cloud.aiplatform.v1.JobServiceSettings;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.ModelName;
import com.google.gson.JsonObject;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;

public class CreateBatchPredictionJobBigquerySample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "PROJECT";
    String displayName = "DISPLAY_NAME";
    String modelName = "MODEL_NAME";
    String instancesFormat = "INSTANCES_FORMAT";
    String bigquerySourceInputUri = "BIGQUERY_SOURCE_INPUT_URI";
    String predictionsFormat = "PREDICTIONS_FORMAT";
    String bigqueryDestinationOutputUri = "BIGQUERY_DESTINATION_OUTPUT_URI";
    createBatchPredictionJobBigquerySample(
        project,
        displayName,
        modelName,
        instancesFormat,
        bigquerySourceInputUri,
        predictionsFormat,
        bigqueryDestinationOutputUri);
  }

  static void createBatchPredictionJobBigquerySample(
      String project,
      String displayName,
      String model,
      String instancesFormat,
      String bigquerySourceInputUri,
      String predictionsFormat,
      String bigqueryDestinationOutputUri)
      throws IOException {
    JobServiceSettings settings =
        JobServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    String location = "us-central1";

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (JobServiceClient client = JobServiceClient.create(settings)) {
      JsonObject jsonModelParameters = new JsonObject();
      Value.Builder modelParametersBuilder = Value.newBuilder();
      JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder);
      Value modelParameters = modelParametersBuilder.build();
      BigQuerySource bigquerySource =
          BigQuerySource.newBuilder().setInputUri(bigquerySourceInputUri).build();
      BatchPredictionJob.InputConfig inputConfig =
          BatchPredictionJob.InputConfig.newBuilder()
              .setInstancesFormat(instancesFormat)
              .setBigquerySource(bigquerySource)
              .build();
      BigQueryDestination bigqueryDestination =
          BigQueryDestination.newBuilder().setOutputUri(bigqueryDestinationOutputUri).build();
      BatchPredictionJob.OutputConfig outputConfig =
          BatchPredictionJob.OutputConfig.newBuilder()
              .setPredictionsFormat(predictionsFormat)
              .setBigqueryDestination(bigqueryDestination)
              .build();
      String modelName = ModelName.of(project, location, model).toString();
      BatchPredictionJob batchPredictionJob =
          BatchPredictionJob.newBuilder()
              .setDisplayName(displayName)
              .setModel(modelName)
              .setModelParameters(modelParameters)
              .setInputConfig(inputConfig)
              .setOutputConfig(outputConfig)
              .build();
      LocationName parent = LocationName.of(project, location);
      BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
      System.out.format("response: %s\n", response);
      System.out.format("\tName: %s\n", response.getName());
    }
  }
}

Python

To learn how to install and use the client library for Vertex AI, see Vertex AI client libraries. For more information, see the Vertex AI Python API reference documentation.

In the following sample, set the `instances_format` and `predictions_format` parameters to `"bigquery"`. To learn how to set the other parameters, see the REST tab of this section.

from google.cloud import aiplatform_v1beta1
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_bigquery_sample(
    project: str,
    display_name: str,
    model_name: str,
    instances_format: str,
    bigquery_source_input_uri: str,
    predictions_format: str,
    bigquery_destination_output_uri: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform_v1beta1.JobServiceClient(client_options=client_options)
    model_parameters_dict = {}
    model_parameters = json_format.ParseDict(model_parameters_dict, Value())

    batch_prediction_job = {
        "display_name": display_name,
        # Format: 'projects/{project}/locations/{location}/models/{model_id}'
        "model": model_name,
        "model_parameters": model_parameters,
        "input_config": {
            "instances_format": instances_format,
            "bigquery_source": {"input_uri": bigquery_source_input_uri},
        },
        "output_config": {
            "predictions_format": predictions_format,
            "bigquery_destination": {"output_uri": bigquery_destination_output_uri},
        },
        # optional
        "generate_explanation": True,
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_batch_prediction_job(
        parent=parent, batch_prediction_job=batch_prediction_job
    )
    print("response:", response)

API: Cloud Storage

REST

You use the batchPredictionJobs.create method to request a batch prediction.

Before using any of the request data, make the following replacements:

  • LOCATION_ID: Region where Model is stored and batch prediction job is executed. For example, us-central1.
  • PROJECT_ID: Your project ID
  • BATCH_JOB_NAME: Display name for the batch job
  • MODEL_ID: The ID for the model to use for making predictions
  • URI: Paths (URIs) to the Cloud Storage objects that contain the input data. There can be more than one. Each URI has the form:
    gs://bucketName/pathToFileName
    
  • OUTPUT_URI_PREFIX: Path to a Cloud Storage destination where the predictions will be written. Vertex AI writes batch predictions to a timestamped subdirectory of this path. Set this value to a string with the following format:
    gs://bucketName/pathToOutputDirectory
    
  • MACHINE_TYPE: The machine resources to be used for this batch prediction job. Learn more.
  • STARTING_REPLICA_COUNT: The starting number of nodes for this batch prediction job. The node count can be increased or decreased as required by load, up to the maximum number of nodes, but will never fall below this number.
  • MAX_REPLICA_COUNT: The maximum number of nodes for this batch prediction job. The node count can be increased or decreased as required by load, but will never exceed the maximum. Optional, defaults to 10.

HTTP method and URL:

POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs

Request JSON body:

{
  "displayName": "BATCH_JOB_NAME",
  "model": "MODEL_ID",
  "inputConfig": {
    "instancesFormat": "csv",
    "gcsSource": {
      "uris": [
        "URI1",...
      ]
    }
  },
  "outputConfig": {
    "predictionsFormat": "csv",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_URI_PREFIX"
    }
  },
  "dedicatedResources": {
    "machineSpec": {
      "machineType": "MACHINE_TYPE",
      "acceleratorCount": "0"
    },
    "startingReplicaCount": STARTING_REPLICA_COUNT,
    "maxReplicaCount": MAX_REPLICA_COUNT
  }
}

To send your request, choose one of these options:

curl

Save the request body in a file called request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs"

PowerShell

Save the request body in a file called request.json, and execute the following command:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs" | Select-Object -Expand Content
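
The same request can be prepared from Python's standard library. The sketch below only constructs the request object and does not send it. The helper name `build_batch_request` is hypothetical, ACCESS_TOKEN stands in for the output of `gcloud auth print-access-token`, and the machine type and replica counts are example values.

```python
import json
import urllib.request


def build_batch_request(location, project, access_token, body):
    """Prepare (but do not send) the POST request for batchPredictionJobs.create."""
    url = (f"https://{location}-aiplatform.googleapis.com/v1/"
           f"projects/{project}/locations/{location}/batchPredictionJobs")
    return urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json; charset=utf-8",
        },
        method="POST",
    )


body = {
    "displayName": "BATCH_JOB_NAME",
    "model": "MODEL_ID",
    "inputConfig": {
        "instancesFormat": "csv",
        "gcsSource": {"uris": ["gs://bucketName/pathToFileName"]},
    },
    "outputConfig": {
        "predictionsFormat": "csv",
        "gcsDestination": {"outputUriPrefix": "gs://bucketName/pathToOutputDirectory"},
    },
    "dedicatedResources": {
        "machineSpec": {"machineType": "n1-standard-32", "acceleratorCount": "0"},
        "startingReplicaCount": 2,
        "maxReplicaCount": 6,
    },
}
request = build_batch_request("us-central1", "PROJECT_ID", "ACCESS_TOKEN", body)
print(request.get_method(), request.full_url)
# Sending is left to the caller, for example: urllib.request.urlopen(request)
```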

You should receive a JSON response similar to the following:

{
  "name": "projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs/67890",
  "displayName": "batch_job_1 202005291958",
  "model": "projects/12345/locations/us-central1/models/5678",
  "state": "JOB_STATE_PENDING",
  "inputConfig": {
    "instancesFormat": "csv",
    "gcsSource": {
      "uris": [
        "gs://bp_bucket/reg_mode_test"
      ]
    }
  },
  "outputConfig": {
    "predictionsFormat": "csv",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_URI_PREFIX"
    }
  },
  "dedicatedResources": {
    "machineSpec": {
      "machineType": "n1-standard-32",
      "acceleratorCount": "0"
    },
    "startingReplicaCount": 2,
    "maxReplicaCount": 6
  },
  "manualBatchTuningParameters": {
    "batchSize": 4
  },
  "outputInfo": {
    "gcsOutputDirectory": "OUTPUT_URI_PREFIX/prediction-batch_job_1 202005291958-2020-09-30T02:58:44.341643Z"
  },
  "createTime": "2020-09-30T02:58:44.341643Z",
  "updateTime": "2020-09-30T02:58:44.341643Z"
}

Retrieve batch prediction results

Vertex AI sends the output of batch predictions to the destination that you specified, which can be either BigQuery or Cloud Storage.

BigQuery

Output dataset

If you are using BigQuery, the output of batch prediction is stored in an output dataset. If you had provided a dataset to Vertex AI, the name of the dataset (BQ_DATASET_NAME) is the name you had provided earlier. If you did not provide an output dataset, Vertex AI created one for you. You can find its name (BQ_DATASET_NAME) with the following steps:

  1. In the Google Cloud console, go to the Vertex AI Batch predictions page.
  2. Select the prediction you created.
  3. The output dataset is given in Export location. The dataset name is formatted as follows: prediction_MODEL_NAME_TIMESTAMP

Output tables

The output dataset contains one or both of the following output tables:

  • Predictions table

    This table contains a row for every row in your input data where a prediction was requested (i.e. where TARGET_COLUMN_NAME = null).

  • Errors table

    This table contains a row for each non-critical error encountered during batch prediction. Each non-critical error corresponds to a row in the input data for which Vertex AI could not return a prediction.

Predictions table

The name of the table (BQ_PREDICTIONS_TABLE_NAME) is formed by prefixing the timestamp of when the batch prediction job started with `predictions_`: predictions_TIMESTAMP

To retrieve predictions, go to the BigQuery page.

The format of the query depends on your model type:

Classification:

SELECT predicted_TARGET_COLUMN_NAME.classes AS classes,
predicted_TARGET_COLUMN_NAME.scores AS scores
FROM BQ_DATASET_NAME.BQ_PREDICTIONS_TABLE_NAME

"classes" is the list of potential classes, and "scores" are the corresponding confidence scores.

Regression:

SELECT predicted_TARGET_COLUMN_NAME.value
FROM BQ_DATASET_NAME.BQ_PREDICTIONS_TABLE_NAME

You can find feature importance in the predictions table as well. To access importance for a feature BQ_FEATURE_NAME, run the following query:

SELECT predicted_TARGET_COLUMN_NAME.feature_importance.BQ_FEATURE_NAME
FROM BQ_DATASET_NAME.BQ_PREDICTIONS_TABLE_NAME
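
The query shapes above differ only in the placeholder values, so they can be generated from a small template. This is an illustrative sketch: the function names and the sample dataset, table, target, and feature names are hypothetical, and the helpers do no identifier escaping, so they are suitable only for names you control.

```python
def predictions_query(dataset, table, target, model_type):
    """Return the SQL text for reading predictions, following the query shapes above."""
    column = f"predicted_{target}"
    source = f"{dataset}.{table}"
    if model_type == "classification":
        return (f"SELECT {column}.classes AS classes,\n"
                f"{column}.scores AS scores\n"
                f"FROM {source}")
    if model_type == "regression":
        return f"SELECT {column}.value\nFROM {source}"
    raise ValueError(f"unsupported model type: {model_type!r}")


def feature_importance_query(dataset, table, target, feature):
    """Return the SQL text for reading the importance of one feature."""
    return (f"SELECT predicted_{target}.feature_importance.{feature}\n"
            f"FROM {dataset}.{table}")


# Hypothetical names standing in for BQ_DATASET_NAME, BQ_PREDICTIONS_TABLE_NAME, and so on.
print(predictions_query("my_dataset", "predictions_2020", "churn", "classification"))
print(feature_importance_query("my_dataset", "predictions_2020", "churn", "age"))
```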
  

Errors table

The name of the table (BQ_ERRORS_TABLE_NAME) is formed by prefixing the timestamp of when the batch prediction job started with `errors_`: errors_TIMESTAMP

To retrieve the errors table:

  1. In the console, go to the BigQuery page.
  2. Run the following query:
    SELECT * FROM BQ_DATASET_NAME.BQ_ERRORS_TABLE_NAME

The errors are stored in the following columns:
  • errors_TARGET_COLUMN_NAME.code
  • errors_TARGET_COLUMN_NAME.message

Cloud Storage

If you specified Cloud Storage as your output destination, the results of your batch prediction request are returned as CSV objects in a new folder in the bucket you specified. The name of the folder is the name of your model, prepended with "prediction-" and appended with the timestamp of when the batch prediction job started. You can find the Cloud Storage folder name in the Batch predictions tab for your model.

The Cloud Storage folder contains two kinds of objects:
  • Prediction objects

    The prediction objects are named `predictions_1.csv`, `predictions_2.csv`, and so on. They contain a header row with the column names, and a row for every prediction returned. In the prediction objects, Vertex AI returns your prediction data and creates one or more new columns for the prediction results based on your model type:

    • Classification: For each potential value of your target column, a column named TARGET_COLUMN_NAME_VALUE_score is added to the results. This column contains the score, or confidence estimate, for that value.
    • Regression: The predicted value for that row is returned in a column named predicted_TARGET_COLUMN_NAME. The prediction interval is not returned for CSV output.
  • Error objects

    The error objects are named `errors_1.csv`, `errors_2.csv`, and so on. They contain a header row, and a row for every row in your input data for which Vertex AI could not return a prediction (for example, if a non-nullable feature was null).

Note: If the results are large, they are split into multiple objects.

Feature importance is not available for batch prediction results returned in Cloud Storage.

Interpret prediction results

Classification

Classification models return a confidence score.

The confidence score communicates how strongly your model associates each class or label with a test item. The higher the number, the higher the model's confidence that the label should be applied to that item. You decide how high the confidence score must be for you to accept the model's results.
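
As a sketch of applying your own acceptance threshold, the following code reads classification results in the CSV layout described earlier (one TARGET_COLUMN_NAME_VALUE_score column per class) and keeps the top class only when its score meets the threshold. The function name `accepted_labels`, the target column `churn`, and the sample rows are hypothetical.

```python
import csv
import io


def accepted_labels(csv_text, target, threshold):
    """For each row, return the top-scoring class, or None if its score is below threshold."""
    prefix, suffix = f"{target}_", "_score"
    results = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        # Collect {class value: score} from the per-class score columns.
        scores = {
            name[len(prefix):-len(suffix)]: float(value)
            for name, value in row.items()
            if name.startswith(prefix) and name.endswith(suffix)
        }
        label, score = max(scores.items(), key=lambda item: item[1])
        results.append(label if score >= threshold else None)
    return results


sample = (
    '"age","churn_yes_score","churn_no_score"\n'
    '"42","0.91","0.09"\n'
    '"35","0.55","0.45"\n'
)
print(accepted_labels(sample, "churn", 0.8))  # ['yes', None]
```

A higher threshold trades coverage for precision: more rows come back as None, but the accepted labels carry higher confidence.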

Regression

Regression models return a prediction value.

For batch prediction results stored in BigQuery, TabNet provides inherent model interpretability by giving users insight into which features it used to help make its decision. The algorithm utilizes attention, which learns to selectively enhance the influence of some features while diminishing the influence of others through a weighted average. For a particular decision, TabNet decides in a stepwise fashion how much importance to place on each feature. It then combines each of the steps to create a final prediction. The attention is multiplicative, where larger values indicate that the feature played a larger role in the prediction and a value of zero means that the feature played no role in that decision. Because TabNet uses multiple decision steps, the attention placed on the features across all of the steps is linearly combined after appropriate scaling. This linear combination across all of TabNet's decision steps is the total feature importance that TabNet provides you.
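
The linear combination described above can be illustrated with a toy calculation. This is not Vertex AI's or TabNet's actual implementation: the per-step weights, the final normalization, and the function name `aggregate_importance` are assumptions chosen only to show how per-step attention values roll up into one importance value per feature.

```python
def aggregate_importance(step_masks, step_weights):
    """Linearly combine per-step attention masks, then normalize so the result sums to 1."""
    n_features = len(step_masks[0])
    combined = [0.0] * n_features
    for mask, weight in zip(step_masks, step_weights):
        for j, attention in enumerate(mask):
            combined[j] += weight * attention
    total = sum(combined)
    return [value / total for value in combined] if total else combined


masks = [[0.5, 0.5, 0.0], [1.0, 0.0, 0.0]]  # two decision steps, three features
weights = [1.0, 3.0]                         # hypothetical per-step scaling factors
print(aggregate_importance(masks, weights))  # [0.875, 0.125, 0.0]
```

The third feature receives zero attention at every step, so its aggregated importance is zero, matching the statement that a value of zero means the feature played no role.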

What's next