Using Cloud Dataflow for Batch Predictions with TensorFlow

This article shows you how to use Cloud Dataflow to run batch processing for machine learning predictions. The article uses a machine learning model trained with TensorFlow. The trained model is exported to a Google Cloud Storage bucket before batch processing starts, and it is dynamically restored on the worker nodes of the prediction jobs.

This approach enables you to make predictions against a large dataset, stored in a Cloud Storage bucket or Google BigQuery tables, in a scalable manner, because Cloud Dataflow automatically distributes the prediction tasks to multiple worker nodes.

Cloud Dataflow is a unified programming model and a managed service for developing and executing a wide range of data processing patterns including ETL, batch computation, and continuous computation. Cloud Dataflow frees you from operational tasks such as resource management and performance optimization.

In this solution, predictions are made for the MNIST dataset by using a pre-trained convolutional neural network. MNIST is a dataset of handwritten digit images that is widely used in machine learning as a benchmark dataset for image recognition.

Architecture overview

This article shows two example use cases. The Cloud Dataflow pipelines are coded in Python for both cases. For more about the Python SDK, see Apache Beam - Python SDK.

The first use case, shown in figure 1, uses a text file in a Cloud Storage bucket as the data source. The file contains handwritten digit images in a text format: each line is a key-value pair consisting of a serial number (the key) and the data for a single image, which is 784 float values. Each value represents one pixel of a 28-by-28 pixel grayscale image. The beam.io.ReadFromText class reads the data from the text file, and the beam.io.WriteToText class writes the prediction results to a separate set of text files.

Figure 1. Using Cloud Storage as a data source.
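
The following sketch shows how this variant of the pipeline can be expressed with the Apache Beam Python SDK. It is a minimal illustration, not the repository's actual predict.py: the names build_text_pipeline and predict_fn are hypothetical, and predict_fn stands in for the step that parses a key-value line and runs the model.

  import apache_beam as beam
  from apache_beam.options.pipeline_options import PipelineOptions

  def build_text_pipeline(input_path, output_path, predict_fn, options=None):
      # Sketch only: read key-value lines from Cloud Storage, apply the
      # (hypothetical) predict_fn to each line, and write the results back
      # to Cloud Storage as sharded text files.
      p = beam.Pipeline(options=options or PipelineOptions())
      (p
       | 'ReadImages' >> beam.io.ReadFromText(input_path)
       | 'Predict' >> beam.Map(predict_fn)
       | 'WriteResults' >> beam.io.WriteToText(output_path))
      return p

  # Example usage (paths are placeholders):
  #   build_text_pipeline('gs://my-bucket/input/images.txt',
  #                       'gs://my-bucket/output/predict',
  #                       my_predict_fn).run()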

The second use case, shown in figure 2, uses a BigQuery table that contains the same handwritten digit images, stored in a string format, as the data source. The beam.io.BigQuerySource class reads the data from the table, and the beam.io.BigQuerySink class stores the prediction results in another table.

Figure 2. Using BigQuery as a data source.
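
A corresponding sketch for the BigQuery variant differs only in the source and sink transforms. Again, build_bigquery_pipeline and predict_fn are illustrative names, and the output schema shown here is abbreviated; the real output table has one FLOAT column per digit (pred0 through pred9).

  import apache_beam as beam
  from apache_beam.options.pipeline_options import PipelineOptions

  def build_bigquery_pipeline(input_table, output_table, predict_fn, options=None):
      # Sketch only: read rows ({'key': ..., 'image': ...}) from a BigQuery
      # table, apply the (hypothetical) predict_fn to each row, and write
      # the resulting rows to another BigQuery table.
      p = beam.Pipeline(options=options or PipelineOptions())
      (p
       | 'ReadImages' >> beam.io.Read(beam.io.BigQuerySource(input_table))
       | 'Predict' >> beam.Map(predict_fn)
       | 'WriteResults' >> beam.io.Write(beam.io.BigQuerySink(
             output_table,
             schema='key:INTEGER,pred0:FLOAT,pred1:FLOAT',  # abbreviated
             create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
             write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))
      return p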

The main logic of the pipeline code is the same in both cases, which shows the flexibility of Cloud Dataflow: the same pipeline code can be used with different data sources. In both cases, the trained TensorFlow model is restored from the model binaries stored in a Cloud Storage bucket.

This enables you to change the model simply by replacing the model binaries in your bucket. You can also modify the pipeline code to mix Cloud Storage and BigQuery as the data source and the data sink. For example, you can use text files in a Cloud Storage bucket as the data source and store the prediction results in BigQuery tables.
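
One way to implement that restore step, sketched below under the assumption of a TensorFlow 1.x-style checkpoint (the export.meta and export files in this example), is a DoFn that loads the graph once per worker before processing elements. PredictDoFn, the pixel parsing, and the tensor names 'input:0' and 'softmax:0' are hypothetical; the real names depend on how the model was exported.

  import apache_beam as beam

  class PredictDoFn(beam.DoFn):
      # Hypothetical DoFn that restores the exported model from Cloud Storage.

      def __init__(self, model_dir):
          self._model_dir = model_dir  # for example, gs://<bucket>/model
          self._session = None

      def start_bundle(self):
          # Import TensorFlow lazily so that it is only required on the
          # worker nodes, where setup.py has installed it.
          import tensorflow as tf
          if self._session is None:
              graph = tf.Graph()
              self._session = tf.Session(graph=graph)
              with graph.as_default():
                  saver = tf.train.import_meta_graph(self._model_dir + '/export.meta')
                  saver.restore(self._session, self._model_dir + '/export')

      def process(self, element):
          # Assumption: the image field is a comma-separated string of
          # 784 pixel values; the feed and fetch tensor names are also
          # assumptions.
          key, image = element['key'], element['image']
          pixels = [float(v) for v in image.split(',')]
          probabilities = self._session.run('softmax:0',
                                            feed_dict={'input:0': [pixels]})
          yield {'key': key, 'probabilities': probabilities[0].tolist()}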

Objectives

  • Use Cloud Dataflow to distribute prediction tasks.
  • Run a prediction using Cloud Storage as a data source.
  • Run a prediction using BigQuery as a data source.

Costs

This tutorial uses billable components of Cloud Platform, including:

  • Cloud Dataflow
  • Cloud Storage
  • BigQuery

Use the Pricing Calculator to generate a cost estimate based on your projected usage.

New Cloud Platform users might be eligible for a free trial.

Before you begin

  1. Sign in to your Google account.

    If you don't already have one, sign up for a new account.

  2. Select or create a Cloud Platform project.

    Go to the Manage resources page

  3. Enable billing for your project.

    Enable billing

  4. Enable the Google Cloud Dataflow API.

    Enable the API

Cloning the sample code to Cloud Shell

  1. Launch the Cloud Shell.

  2. Set your default compute zone. In Cloud Shell, enter the following command:

     gcloud config set compute/zone us-east1-d
    
  3. Install the Cloud Dataflow Python SDK:

    pip install --upgrade google-cloud-dataflow --user
    
  4. Clone the lab repository:

     git clone https://github.com/GoogleCloudPlatform/dataflow-prediction-example
    
  5. Enter the new project directory:

     cd dataflow-prediction-example
    

The files and their functions are listed in the table below:

Directory            File          Content
data                 export.meta   Trained model binary (network metagraph).
data                 export        Trained model binary (variable values).
data                 images.txt    Prediction data source file.
prediction           setup.py      Setup script that installs TensorFlow on the worker nodes.
prediction           run.py        Bootstrap script for the worker nodes.
prediction/modules   predict.py    Pipeline execution script.

Creating a bucket and uploading files

  1. Create a Cloud Storage bucket to store work files:

    PROJECT=$(gcloud config list project --format "value(core.project)")
    BUCKET=gs://$PROJECT-dataflow
    gsutil mb $BUCKET
    
  2. Upload the trained model binaries and the prediction data source file to the bucket.

    gsutil cp data/export* $BUCKET/model/
    gzip -kdf data/images.txt.gz
    gsutil cp data/images.txt $BUCKET/input/
    

Making predictions

Using Cloud Storage for input and output

  1. Use the following script to submit a prediction job. The --source cs flag indicates that the prediction data source and the prediction results are in the Cloud Storage bucket, as shown in figure 1. You specify the data source and output filenames by using the --input and --output parameters.

    python prediction/run.py \
      --runner DataflowRunner \
      --project $PROJECT \
      --staging_location $BUCKET/staging \
      --temp_location $BUCKET/temp \
      --job_name $PROJECT-prediction-cs \
      --setup_file prediction/setup.py \
      --model $BUCKET/model \
      --source cs \
      --input $BUCKET/input/images.txt \
      --output $BUCKET/output/predict
    

Note: The autoscaling feature automatically adjusts the number of worker nodes. If you want to use a fixed number of nodes, specify it by using the --num_workers parameter, as illustrated in the sketch that follows.
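
For reference, the sketch below shows how such flags might be expressed as Beam pipeline options inside a Python program. The values are placeholders, and the --autoscaling_algorithm=NONE option is assumed here as the way to disable autoscaling so that the fixed worker count is respected.

  from apache_beam.options.pipeline_options import PipelineOptions

  # Illustrative only: pin the job to ten workers instead of autoscaling.
  options = PipelineOptions([
      '--runner=DataflowRunner',
      '--project=your-project-id',
      '--temp_location=gs://your-bucket/temp',
      '--num_workers=10',
      '--autoscaling_algorithm=NONE',
  ])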

  2. Open the Cloud Dataflow page in the Google Cloud Platform Console to find the running job. Look for the data flow graph to confirm that the job has succeeded.

    OPEN CLOUD DATAFLOW

    When the job finishes successfully, the prediction results are stored in the Cloud Storage bucket.

  3. Open the Storage Browser in the Cloud Platform Console:

    OPEN STORAGE BROWSER

  4. Go to the output directory in your storage bucket, named [$PROJECT]-dataflow, to find the output files.

Using BigQuery for input and output

You can run the same prediction job using BigQuery as the data source and the output source.

  1. Create a BigQuery table and upload the prediction data source.

    bq mk mnist
    bq load --source_format=CSV -F":" mnist.images data/images.txt.gz \
      "key:integer,image:string"
    
  2. Submit a prediction job again by using the following script. The --source bq flag indicates that the prediction data source and the prediction results are BigQuery tables, as shown in figure 2.

    python prediction/run.py \
      --runner DataflowRunner \
      --project $PROJECT \
      --staging_location $BUCKET/staging \
      --temp_location $BUCKET/temp \
      --job_name $PROJECT-prediction-bq \
      --setup_file prediction/setup.py \
      --model $BUCKET/model \
      --source bq \
      --input $PROJECT:mnist.images \
      --output $PROJECT:mnist.predict
    
  3. Open the Cloud Dataflow page in the Google Cloud Platform Console to find the running job. Look for the data flow graph to confirm that the job has succeeded.

    OPEN CLOUD DATAFLOW

When the job finishes, the prediction results are stored in the BigQuery table.

  4. Click the BigQuery menu in the Cloud Platform Console.
  5. Click Compose query and enter the following query:

    SELECT * FROM mnist.predict WHERE key < 10 ORDER BY key;
    
  6. Click Run query to see your prediction results for the first 10 images in a tabular format.

  7. Look at the columns pred0 through pred9, which indicate the probability of the corresponding labels (digit 0 through digit 9).

At this point, you can download the data as a CSV file or save it to Google Sheets by using the buttons above the table. You can also use the results for additional analysis in BigQuery. For example, you can count the number of images that are highly likely to be the digit 0 with the following query:

  SELECT COUNT(*) FROM mnist.predict WHERE pred0 >= 0.9;

Cleaning up

The data that you produce during this tutorial is stored in Cloud Storage buckets and BigQuery tables. You can keep it in your project as a reference if you want to, but be aware that you'll incur small ongoing charges as described in the price list.

To entirely stop ongoing charges, you can delete the project.

  1. In the Cloud Platform Console, go to the Projects page.

    Go to the Projects page

  2. In the project list, select the checkbox next to the project that you want to delete, and then click Delete project.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

What's next

  • Read What is BigQuery?
  • Read about Cloud Machine Learning to learn how to use it to train a model for making predictions on your own dataset.

  • Try out other Google Cloud Platform features for yourself. Have a look at our tutorials.
