Make predictions with PyTorch models in ONNX format

Open Neural Network Exchange (ONNX) provides a uniform format designed to represent models from any machine learning framework. BigQuery ML support for ONNX lets you:

  • Train a model using your favorite framework.
  • Convert the model into ONNX model format. For more information, see Converting to ONNX format.
  • Import the ONNX model into BigQuery and make predictions using BigQuery ML.

This tutorial shows you how to import ONNX models trained with PyTorch into a BigQuery dataset and use them to make predictions from a SQL query. You can import ONNX models using these interfaces:

  • The Google Cloud console
  • The bq command-line tool
  • The BigQuery API

For more information about importing ONNX models into BigQuery, including format and storage requirements, see The CREATE MODEL statement for importing ONNX models.


In this tutorial, you will:

  • Create and train models with PyTorch.
  • Convert the models to ONNX format using torch.onnx.
  • Import the ONNX models into BigQuery and make predictions.

Create a PyTorch vision model for image classification

Import a pretrained PyTorch ResNet-18 model and wrap it so that it accepts the decoded image data returned by the BigQuery ML ML.DECODE_IMAGE and ML.RESIZE_IMAGE functions.

import torch
import torch.nn as nn

# Define model input format to match the output format of
# ML.DECODE_IMAGE function: [height, width, channels]
dummy_input = torch.randn(1, 224, 224, 3, device="cpu")

# Load a pretrained PyTorch model for image classification
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

# Reshape input format from [batch_size, height, width, channels]
# to [batch_size, channels, height, width]
class ReshapeLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # reorder dimensions
        return x

# Return the index of the highest-scoring class for each image in the batch
class ArgMaxLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.argmax(x, dim=1)

# Chain the layers: reorder the input, classify it, then pick the top class
final_model = nn.Sequential(
    ReshapeLayer(),
    model,
    ArgMaxLayer(),
)
final_model.eval()  # use inference behavior (for example, in batch normalization)

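The shape bookkeeping in the block above can be checked without downloading any weights. The following NumPy sketch is illustrative only: a random array stands in for a decoded image, and a random linear layer (hypothetical, not ResNet-18) stands in for the classifier. It mirrors what ReshapeLayer and ArgMaxLayer do and shows why inserting a softmax before the argmax would not change the predicted label.

```python
import numpy as np

rng = np.random.default_rng(0)

# One random "decoded image" in ML.DECODE_IMAGE layout: [batch, height, width, channels]
image_nhwc = rng.random((1, 224, 224, 3))

# Same reordering as ReshapeLayer's x.permute(0, 3, 1, 2): [batch, channels, height, width]
image_nchw = np.transpose(image_nhwc, (0, 3, 1, 2))

# Hypothetical stand-in for ResNet-18: average each channel over all pixels, then
# apply a random linear layer that scores 1,000 classes (one per ImageNet class)
weights = rng.standard_normal((3, 1000))
logits = image_nchw.mean(axis=(2, 3)) @ weights  # shape [batch, 1000]

# Softmax preserves the order of the scores, so it never changes the top class
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# Same reduction as ArgMaxLayer's torch.argmax(x, dim=1): one class index per image
class_label = np.argmax(logits, axis=1)
assert (class_label == np.argmax(probs, axis=1)).all()

print(image_nchw.shape, logits.shape, class_label.shape)  # (1, 3, 224, 224) (1, 1000) (1,)
```

Because ArgMaxLayer reduces each image to a single index, the wrapped model returns one integer class label per image instead of 1,000 scores.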
Convert the model into ONNX format and save

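Use torch.onnx to export the PyTorch vision model to an ONNX file named resnet18.onnx. The input and output names set here (input and class_label) are the names that BigQuery ML uses for the model's input and output columns.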

torch.onnx.export(final_model,              # model being exported
                  dummy_input,              # example input used to trace the model
                  "resnet18.onnx",          # where to save the model
                  opset_version=10,         # the ONNX opset version to target
                  input_names=['input'],         # the model's input names
                  output_names=['class_label'])  # the model's output names

Upload the ONNX model to Cloud Storage

Create a Cloud Storage bucket to store the ONNX model file, and then upload the saved ONNX model file to your Cloud Storage bucket. For more information, see Upload objects from a filesystem.
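If you prefer to upload from Python, the following is a minimal sketch using the google-cloud-storage client library. It assumes the library is installed, Application Default Credentials are configured, and the bucket already exists; the helper names and the bucket and object names are hypothetical.

```python
def gcs_uri(bucket_name: str, object_name: str) -> str:
    """Build the gs:// URI that the MODEL_PATH option expects for an uploaded object."""
    return f"gs://{bucket_name}/{object_name.lstrip('/')}"


def upload_onnx_model(local_path: str, bucket_name: str, object_name: str) -> str:
    """Upload a local ONNX file to Cloud Storage and return its gs:// URI."""
    # Imported here so gcs_uri() can be used without the client library installed.
    from google.cloud import storage

    client = storage.Client()  # uses Application Default Credentials
    blob = client.bucket(bucket_name).blob(object_name)
    blob.upload_from_filename(local_path)
    return gcs_uri(bucket_name, object_name)


# Hypothetical names; replace them with your own bucket and object path.
print(gcs_uri("my-onnx-models", "resnet18/resnet18.onnx"))
# gs://my-onnx-models/resnet18/resnet18.onnx
```

The returned URI is the value to pass as MODEL_PATH when you import the model.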

Import the ONNX model into BigQuery

This step assumes you have uploaded the ONNX model to your Cloud Storage bucket. An example model is stored at gs://cloud-samples-data/bigquery/ml/onnx/resnet18.onnx.


  1. In the Google Cloud console, go to the BigQuery page.

    Go to the BigQuery page

  2. In the query editor, enter a CREATE MODEL statement like the following.

     CREATE OR REPLACE MODEL `mydataset.mymodel`
      OPTIONS (MODEL_TYPE='ONNX',
       MODEL_PATH='gs://bucket/path/to/onnx_model/*')

    For example:

     CREATE OR REPLACE MODEL `example_dataset.imported_onnx_model`
      OPTIONS (MODEL_TYPE='ONNX',
       MODEL_PATH='gs://cloud-samples-data/bigquery/ml/onnx/resnet18.onnx')

    The preceding query imports the ONNX model located at gs://cloud-samples-data/bigquery/ml/onnx/resnet18.onnx as a BigQuery model named imported_onnx_model.

  3. Click Run. When the query completes, your new model appears in the Resources panel. As you expand each dataset in a project, its models are listed along with the dataset's other BigQuery resources.

  4. If you select the new model in the Resources panel, information about the model appears below the query editor.



To import an ONNX model from Cloud Storage, run a batch query by entering a command like the following:

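bq query \
--use_legacy_sql=false \
'CREATE OR REPLACE MODEL `mydataset.mymodel`
 OPTIONS (MODEL_TYPE="ONNX",
  MODEL_PATH="gs://bucket/path/to/onnx_model/*")'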

For example:

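bq query --use_legacy_sql=false \
'CREATE OR REPLACE MODEL `example_dataset.imported_onnx_model`
 OPTIONS (MODEL_TYPE="ONNX",
  MODEL_PATH="gs://cloud-samples-data/bigquery/ml/onnx/resnet18.onnx")'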

After you import the model, it appears in the output of bq ls [dataset_name]:

$ bq ls example_dataset

       tableId          Type    Labels   Time Partitioning
 --------------------- ------- -------- -------------------
  imported_onnx_model   MODEL
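The same statement can also be run from Python. This is a minimal sketch assuming the google-cloud-bigquery client library is installed and Application Default Credentials are configured; create_onnx_model_sql and import_onnx_model are hypothetical helper names. Listing the dataset's models afterward roughly mirrors the bq ls check above.

```python
def create_onnx_model_sql(model_id: str, model_path: str) -> str:
    """Build a CREATE OR REPLACE MODEL statement that imports an ONNX model."""
    return (
        f"CREATE OR REPLACE MODEL `{model_id}`\n"
        f"OPTIONS (MODEL_TYPE='ONNX', MODEL_PATH='{model_path}')"
    )


def import_onnx_model(model_id: str, model_path: str, dataset_id: str) -> list:
    """Run the CREATE MODEL statement, then return the model IDs in the dataset."""
    # Imported here so create_onnx_model_sql() works without the library installed.
    from google.cloud import bigquery

    client = bigquery.Client()
    client.query(create_onnx_model_sql(model_id, model_path)).result()  # wait for the job
    return [model.model_id for model in client.list_models(dataset_id)]


sql = create_onnx_model_sql(
    "example_dataset.imported_onnx_model",
    "gs://cloud-samples-data/bigquery/ml/onnx/resnet18.onnx",
)
print(sql)
```

Calling .result() blocks until the CREATE MODEL job finishes, so the model is listed only after the import has completed.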


Insert a new job and populate the jobs#configuration.query property as in the following request body:

  "query": "CREATE MODEL `project_id.mydataset.mymodel` OPTIONS(MODEL_TYPE='ONNX', MODEL_PATH='gs://bucket/path/to/onnx_model/*')",
  "useLegacySql": false

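The request body above shows only the properties of configuration.query. As a sketch, the following Python builds a complete jobs.insert request body (the Job resource nests query settings under configuration.query) and serializes it to JSON. Setting useLegacySql to false matters because jobs.insert otherwise runs the query as legacy SQL, which does not support CREATE MODEL. Sending the request (an authenticated POST to the jobs.insert endpoint) is not shown, and query_job_body is a hypothetical helper name.

```python
import json


def query_job_body(sql: str) -> dict:
    """Build a jobs.insert request body that runs `sql` as a GoogleSQL query."""
    return {
        "configuration": {
            "query": {
                "query": sql,
                # jobs.insert uses legacy SQL unless this is explicitly false
                "useLegacySql": False,
            }
        }
    }


body = query_job_body(
    "CREATE MODEL `project_id.mydataset.mymodel` "
    "OPTIONS(MODEL_TYPE='ONNX', MODEL_PATH='gs://bucket/path/to/onnx_model/*')"
)
print(json.dumps(body, indent=2))
```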
Create an object table in BigQuery to access image data

To access unstructured data in BigQuery, you need to create an object table. See Create object tables for detailed instructions.

Create an object table named goldfish_image_table on a goldfish image stored at gs://mybucket/goldfish.jpg.

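CREATE EXTERNAL TABLE `example_dataset.goldfish_image_table`
-- Placeholder: replace with your own Cloud resource connection
WITH CONNECTION `us.my-connection`
OPTIONS(
  object_metadata = 'SIMPLE',
  uris = ['gs://mybucket/goldfish.jpg'],
  max_staleness = INTERVAL 1 DAY,
  metadata_cache_mode = 'AUTOMATIC');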

Make predictions with the imported ONNX model


  1. In the Google Cloud console, go to the BigQuery page.

    Go to the BigQuery page

  2. In the query editor, enter a query using ML.PREDICT like the following.

       SELECT
         class_label
       FROM
         ML.PREDICT(MODEL example_dataset.imported_onnx_model,
           (
           SELECT
             ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data),
               224,
               224,
               FALSE) AS input
           FROM
             example_dataset.goldfish_image_table))

    The preceding query uses the model named imported_onnx_model in the dataset example_dataset in the current project to make predictions from image data in the input object table goldfish_image_table. The ML.DECODE_IMAGE function is required to decode the image data so that ML.PREDICT can interpret it, and the ML.RESIZE_IMAGE function resizes the image to match the size of the model's input (224 by 224 pixels). The resized image is aliased as input to match the input name set when the model was exported. For more information about running inference on image object tables, see Run inference on image object tables.

    This query outputs the predicted class label of the input image as an index into the ImageNet label dictionary.


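The class_label value is an integer index, not a name. As a sketch, the following maps the index to a readable label, assuming you have a text file that lists the 1,000 ImageNet class names one per line in index order (the file name imagenet_classes.txt is an assumption). For illustration it uses only the first three names; in the standard ImageNet-1k ordering, index 1 is goldfish.

```python
from pathlib import Path


def load_labels(path: str) -> list:
    """Read one class name per line; line i holds the name for class index i."""
    return Path(path).read_text(encoding="utf-8").splitlines()


def label_for(class_label: int, labels: list) -> str:
    """Map the integer class_label returned by ML.PREDICT to a readable name."""
    if not 0 <= class_label < len(labels):
        raise ValueError(f"class_label {class_label} is outside 0..{len(labels) - 1}")
    return labels[class_label]


# First three ImageNet-1k class names; in practice, load all 1,000 with
# load_labels("imagenet_classes.txt") (hypothetical local file).
labels = ["tench", "goldfish", "great white shark"]
print(label_for(1, labels))  # goldfish
```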

To make predictions from input data in the table input_data, enter a command like the following, using the imported ONNX model my_model:

bq query \
--use_legacy_sql=false \
'SELECT *
 FROM ML.PREDICT(
   MODEL `my_project.my_dataset.my_model`,
   (SELECT * FROM input_data))'

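The same prediction query can be run from Python. This is a minimal sketch assuming the google-cloud-bigquery client library is installed and Application Default Credentials are configured; predict_sql and run_predictions are hypothetical helper names, and input_data is the placeholder table from the command above.

```python
def predict_sql(model_id: str, input_query: str) -> str:
    """Build an ML.PREDICT query that runs `model_id` over the rows of `input_query`."""
    return f"SELECT * FROM ML.PREDICT(MODEL `{model_id}`, ({input_query}))"


def run_predictions(model_id: str, input_query: str) -> list:
    """Run ML.PREDICT with the BigQuery client library and return rows as dicts."""
    # Imported here so predict_sql() works without the library installed.
    from google.cloud import bigquery

    client = bigquery.Client()
    rows = client.query(predict_sql(model_id, input_query)).result()
    return [dict(row.items()) for row in rows]


print(predict_sql("my_project.my_dataset.my_model", "SELECT * FROM input_data"))
# SELECT * FROM ML.PREDICT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM input_data))
```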

Insert a new job and populate the jobs#configuration.query property as in the following request body:

  "query": "SELECT * FROM ML.PREDICT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM input_data))",
  "useLegacySql": false

