Get batch predictions with Cloud Storage input

Creates a batch prediction job with a Cloud Storage file as input.

Code sample

Java

To authenticate to AutoML Tables, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.automl.v1beta1.BatchPredictInputConfig;
import com.google.cloud.automl.v1beta1.BatchPredictOutputConfig;
import com.google.cloud.automl.v1beta1.BatchPredictRequest;
import com.google.cloud.automl.v1beta1.BatchPredictResult;
import com.google.cloud.automl.v1beta1.GcsDestination;
import com.google.cloud.automl.v1beta1.GcsSource;
import com.google.cloud.automl.v1beta1.ModelName;
import com.google.cloud.automl.v1beta1.OperationMetadata;
import com.google.cloud.automl.v1beta1.PredictionServiceClient;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

abstract class BatchPredict {

  static void batchPredict() throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    String inputUri = "gs://YOUR_BUCKET_ID/path_to_your_input_csv_or_jsonl";
    String outputUri = "gs://YOUR_BUCKET_ID/path_to_save_results/";
    batchPredict(projectId, modelId, inputUri, outputUri);
  }

  static void batchPredict(String projectId, String modelId, String inputUri, String outputUri)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, "us-central1", modelId);

      // Configure the input source: a file in a GCS bucket
      GcsSource gcsSource = GcsSource.newBuilder().addInputUris(inputUri).build();
      BatchPredictInputConfig inputConfig =
          BatchPredictInputConfig.newBuilder().setGcsSource(gcsSource).build();

      // Configure where to store the output in a GCS bucket
      GcsDestination gcsDestination =
          GcsDestination.newBuilder().setOutputUriPrefix(outputUri).build();
      BatchPredictOutputConfig outputConfig =
          BatchPredictOutputConfig.newBuilder().setGcsDestination(gcsDestination).build();

      // Build the request that will be sent to the API
      BatchPredictRequest request =
          BatchPredictRequest.newBuilder()
              .setName(name.toString())
              .setInputConfig(inputConfig)
              .setOutputConfig(outputConfig)
              .build();

      // Start an asynchronous request
      OperationFuture<BatchPredictResult, OperationMetadata> future =
          client.batchPredictAsync(request);

      System.out.println("Waiting for operation to complete...");
      future.get();
      System.out.println("Batch Prediction results saved to specified Cloud Storage bucket.");
    }
  }
}

Node.js

To authenticate to AutoML Tables, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.


/**
 * Demonstrates using the AutoML client to request batch prediction from
 * AutoML Tables using Cloud Storage.
 * TODO(developer): Uncomment the following lines before running the sample.
 */
// const projectId = '[PROJECT_ID]' e.g., "my-gcloud-project";
// const computeRegion = '[REGION_NAME]' e.g., "us-central1";
// const modelId = '[MODEL_ID]' e.g., "TBL4704590352927948800";
// const inputUri = '[GCS_PATH]' e.g., "gs://<bucket-name>/<csv file>",
// `The Google Cloud Storage URI containing the inputs`;
// const outputUriPrefix = '[GCS_PATH]'
// e.g., "gs://<bucket-name>/<folder-name>",
// `The destination Google Cloud Storage URI for storing outputs`;

const automl = require('@google-cloud/automl');

// Create client for prediction service.
const automlClient = new automl.v1beta1.PredictionServiceClient();

// Get the full path of the model.
const modelFullId = automlClient.modelPath(projectId, computeRegion, modelId);

async function batchPredict() {
  // Construct request
  const inputConfig = {
    gcsSource: {
      inputUris: [inputUri],
    },
  };

  // Get the Google Cloud Storage output URI.
  const outputConfig = {
    gcsDestination: {
      outputUriPrefix: outputUriPrefix,
    },
  };

  const [, operation] = await automlClient.batchPredict({
    name: modelFullId,
    inputConfig: inputConfig,
    outputConfig: outputConfig,
  });

  // Log the name of the long-running operation; results are written to the
  // Cloud Storage destination when it completes.
  console.log(`Operation name: ${operation.name}`);
  return operation;
}

batchPredict();

Python

To authenticate to AutoML Tables, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.
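
Before running the sample, you can confirm that Application Default Credentials resolve on your machine. This is a minimal sanity check, separate from the sample below, assuming only the google-auth package (already a dependency of the AutoML client library):

import google.auth

# ADC is read from GOOGLE_APPLICATION_CREDENTIALS or the gcloud-managed ADC
# file; the TablesClient in the sample below performs the same lookup.
credentials, detected_project = google.auth.default()
print(f"Application Default Credentials resolved for project: {detected_project}")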

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_display_name = 'MODEL_DISPLAY_NAME_HERE'
# gcs_input_uri = 'gs://YOUR_BUCKET_ID/path_to_your_input_csv'
# gcs_output_uri = 'gs://YOUR_BUCKET_ID/path_to_save_results/'
# params = {}

from google.cloud import automl_v1beta1 as automl

client = automl.TablesClient(project=project_id, region=compute_region)

# Query model
response = client.batch_predict(
    gcs_input_uris=gcs_input_uri,
    gcs_output_uri_prefix=gcs_output_uri,
    model_display_name=model_display_name,
    params=params,
)
print("Making batch prediction... ")
# `response` is an async operation descriptor;
# you can register a callback to run when the operation completes via `add_done_callback`:
# def callback(operation_future):
#   result = operation_future.result()
# response.add_done_callback(callback)
#
# or block the thread polling for the operation's results:
response.result()

print(f"Batch prediction complete.\n{response.metadata}")

What's next

To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.