Using Distributed TensorFlow with Cloud ML Engine and Cloud Datalab

This tutorial shows you how to use a distributed configuration of TensorFlow code in Python on Google Cloud Machine Learning Engine to train a convolutional neural network model by using the MNIST dataset. You use TensorBoard to visualize the training process and Google Cloud Datalab to test the predictions.

TensorFlow is Google's open source library for machine learning, developed by researchers and engineers in Google's Machine Intelligence organization, which is part of Research at Google. TensorFlow is designed to run on multiple computers to distribute the training workloads, and Cloud Machine Learning Engine provides a managed service where you can run TensorFlow code in a distributed manner by using service APIs.

In this tutorial, the term node refers to an application container that runs parallel computations during training.

About the data

The MNIST dataset enables handwritten digit recognition, and is widely used in machine learning as a training set for image recognition.

The dataset contains a large number of images of handwritten digits in the range 0 to 9, as well as the labels identifying the digit in each image.

This tutorial trains a machine learning model to classify images based on the MNIST dataset. After training, the model classifies incoming images into 10 categories (0 to 9) based on what it's learned about handwritten images from the MNIST dataset. You can then send the model an image that it hasn't seen before, and the model identifies the digit in the image based on what the model has learned during training.

The MNIST dataset has been split into three parts:

  • 55,000 examples of training data
  • 10,000 examples of test data
  • 5,000 examples of validation data

You can find more information about the dataset at the MNIST database site.
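As a rough illustration of how the three parts relate, the following sketch assumes the standard MNIST distribution of 60,000 training images and 10,000 test images, and holds out part of the training images for validation. The function name `split_indices` and the choice to hold out the first indices are assumptions made for this example, not code from the tutorial.

```python
def split_indices(num_source_train=60000, num_validation=5000):
    """Split the source training indices into training and validation parts.

    Assumption: the validation examples are taken from the front of the
    source training set, which leaves the rest for training.
    """
    indices = range(num_source_train)
    validation = indices[:num_validation]
    train = indices[num_validation:]
    return train, validation


train_idx, validation_idx = split_indices()
print(len(train_idx), len(validation_idx))  # 55000 5000
```

The 10,000 test examples are kept separate and are used only to evaluate the trained model.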

Understanding neural networks

In computer programming, humans instruct a computer to solve a problem by specifying each step using many lines of code. With machine learning and neural networks, you instead get the computer to solve the problem through examples.

A neural network is a mathematical function that can learn the expected output for a given input from training datasets. The following figure illustrates a neural network that has been trained to output "cat" from a cat image.

A neural network model is a function that can be trained through examples.

You can see that a neural network model consists of multiple layers of calculation units, in which each layer has configurable parameters. The goal of training the model is to optimize the parameters to get results with the highest accuracy. The training algorithm makes adjustments as it processes batches of training datasets through the model. If you distribute the training process to multiple computational nodes, you need a way to keep track of the changing parameters to be shared by all nodes.
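To make the idea of adjusting parameters batch by batch concrete, here is a minimal sketch of mini-batch gradient descent. It fits a one-variable linear model instead of a neural network, and the function name, learning rate, and synthetic data are all assumptions chosen for illustration.

```python
import random


def train_linear(examples, learning_rate=0.5, batch_size=4, steps=2000, seed=0):
    """Fit y = w * x + b by mini-batch gradient descent on squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0  # the configurable parameters of this tiny model
    for _ in range(steps):
        batch = rng.sample(examples, batch_size)
        # Gradient of the mean squared error over the batch.
        grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / batch_size
        grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / batch_size
        # Adjust the parameters against the gradient.
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


# Synthetic, noise-free data generated from y = 3x + 1.
examples = [(i / 10.0, 3.0 * (i / 10.0) + 1.0) for i in range(10)]
w, b = train_linear(examples)
print(round(w, 3), round(b, 3))
```

A neural network has many more parameters, but the loop has the same shape: take a batch, compute gradients, update the parameters.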

Architecture of the distributed training

There are three basic strategies to train a model with multiple nodes:

  • Data-parallel training with synchronous updates.
  • Data-parallel training with asynchronous updates.
  • Model-parallel training.

The example code in this tutorial uses data-parallel training with asynchronous updates on Cloud ML Engine. In this case, a training job is executed using the following types of nodes:

  • Parameter server node. Updates parameters with gradient vectors from the worker and chief worker nodes.
  • Worker node. Calculates a gradient vector from the training dataset.
  • Chief worker node. Coordinates the operations of multiple workers, in addition to working as one of the worker nodes.

Because you can use the data-parallel strategy regardless of the model structure, it is a good starting point for applying distributed training to your custom model. In data-parallel training, the whole model is shared with all worker nodes. Each node calculates gradient vectors independently from some part of the training dataset, in the same manner as mini-batch processing. The calculated gradient vectors are collected by the parameter server node, and the model parameters are updated with the sum of the gradient vectors. If you distribute 10,000 batches among 10 worker nodes, each node works on roughly 1,000 batches.
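The arithmetic in the preceding paragraph can be sketched in a few lines. The round-robin assignment and the helper names `shard_batches` and `sum_gradients` are assumptions for this example. In a real job, workers pull batches as they become free, so the split is only roughly even.

```python
def shard_batches(num_batches, num_workers):
    """Assign batch indices to workers round-robin (an illustrative policy)."""
    return [list(range(worker, num_batches, num_workers))
            for worker in range(num_workers)]


def sum_gradients(gradients):
    """Element-wise sum of equally sized gradient vectors."""
    return [sum(components) for components in zip(*gradients)]


shards = shard_batches(10000, 10)
print(len(shards), len(shards[0]))  # 10 1000

# Two workers each report a gradient vector for the same two parameters.
total = sum_gradients([[1.0, -2.0], [0.5, 4.0]])
print(total)  # [1.5, 2.0]
```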

Data-parallel training can be done with either synchronous or asynchronous updates. When using asynchronous updates, the parameter server applies each gradient vector independently, right after receiving it from one of the worker nodes, as shown in the following diagram.

Data-parallel training with asynchronous updates.
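The following toy simulation sketches the asynchronous pattern with Python threads standing in for worker nodes. The `ParameterServer` class, the single scalar parameter, and the quadratic loss are assumptions made to keep the example small. This is not how TensorFlow implements parameter servers.

```python
import threading


class ParameterServer:
    """Holds one parameter and applies each gradient as soon as it arrives."""

    def __init__(self, initial_value, learning_rate):
        self.value = initial_value
        self.learning_rate = learning_rate
        self.updates_applied = 0
        self._lock = threading.Lock()

    def pull(self):
        with self._lock:
            return self.value

    def push(self, gradient):
        # No waiting for other workers: the update is applied immediately.
        with self._lock:
            self.value -= self.learning_rate * gradient
            self.updates_applied += 1


def worker(server, target, steps):
    for _ in range(steps):
        w = server.pull()             # may be stale by the time we push
        gradient = 2.0 * (w - target)  # gradient of (w - target) ** 2
        server.push(gradient)


server = ParameterServer(initial_value=0.0, learning_rate=0.05)
threads = [threading.Thread(target=worker, args=(server, 3.0, 200))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(server.updates_applied)  # 800
```

Because a worker may compute its gradient from a value that another worker has since changed, individual updates can be slightly out of date. With a small learning rate the parameter still settles near the target.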

In a typical deployment, there are a few parameter server nodes, a single chief worker node, and several worker nodes. When you submit a training job through the service API, these nodes are automatically deployed in your project.
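Each deployed node can discover its role from the `TF_CONFIG` environment variable, which the service sets to a JSON document describing the cluster and the current task. The sketch below parses a hand-written example of that document. The host names, node counts, and the helper name `describe_node` are hypothetical.

```python
import json
import os


def describe_node(environ=os.environ):
    """Summarize this node's role from the TF_CONFIG environment variable."""
    tf_config = json.loads(environ.get("TF_CONFIG", "{}"))
    cluster = tf_config.get("cluster", {})
    task = tf_config.get("task", {})
    return {
        "task_type": task.get("type"),
        "task_index": task.get("index"),
        "node_counts": {name: len(hosts) for name, hosts in cluster.items()},
    }


# Hypothetical TF_CONFIG value for the first worker of a small cluster.
example_environ = {
    "TF_CONFIG": json.dumps({
        "cluster": {
            "master": ["master-0:2222"],
            "ps": ["ps-0:2222", "ps-1:2222"],
            "worker": ["worker-0:2222", "worker-1:2222", "worker-2:2222"],
        },
        "task": {"type": "worker", "index": 0},
    })
}
info = describe_node(example_environ)
print(info["task_type"], info["node_counts"])
```

In this variable, the chief worker node appears under the name `master`.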

The following diagram describes the architecture for running a distributed training job on Cloud ML Engine and using Cloud Datalab to execute predictions with your trained model.

Architecture used by the tutorial.


Objectives

  • Run the distributed TensorFlow sample code on Cloud ML Engine.
  • Deploy the trained model to Cloud ML Engine to create a custom API for predictions.
  • Visualize the training process with TensorBoard.
  • Use Cloud Datalab to test the predictions.


Costs

This tutorial uses billable components of Cloud Platform, including:

  • Cloud ML Engine
  • Google Cloud Storage
  • Google Compute Engine
  • Compute Engine Persistent Disk

The estimated price to run this tutorial, assuming you use every resource for an entire day, is approximately $1.20 based on this pricing calculator.

Before you begin

  1. Sign in to your Google Account.

    If you don't already have one, sign up for a new account.

  2. Select or create a Google Cloud Platform project.

    Go to the Manage resources page

  3. Make sure that billing is enabled for your Google Cloud Platform project.

    Learn how to enable billing

  4. Enable the AI Platform ("Cloud Machine Learning Engine") and Compute Engine APIs.

    Enable the API

Verifying the Google Cloud SDK components

  1. Go to Cloud Shell.

    Open Cloud Shell

  2. List the models:

    gcloud ai-platform models list

    Verify that the command returns an empty list:

    Listed 0 items.

If you've already worked with Cloud ML Engine, you'll get a list of all of the models associated with your account.

Downloading example files

Download the example files and set your current directory.

git clone
cd cloudml-dist-mnist-example

Creating a Cloud Storage bucket for MNIST files

  1. Create a regional Cloud Storage bucket to hold the MNIST data files that are used to train the model.

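    PROJECT_ID=$(gcloud config list project --format "value(core.project)")
    # Assumed bucket name derived from the project ID; any globally unique name works.
    BUCKET="${PROJECT_ID}-ml"
    gsutil mb -c regional -l us-central1 gs://${BUCKET}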
  2. Use the following script to download the MNIST data files and copy them to the bucket.

    gsutil cp /tmp/data/train.tfrecords gs://${BUCKET}/data/
    gsutil cp /tmp/data/test.tfrecords gs://${BUCKET}/data/

Training the model on Cloud Machine Learning Engine

  1. Submit a training job to Cloud ML Engine.

    JOB_NAME="job_$(date +%Y%m%d_%H%M%S)"
    gcloud ai-platform jobs submit training ${JOB_NAME} \
        --package-path trainer \
        --module-name trainer.task \
        --staging-bucket gs://${BUCKET} \
        --job-dir gs://${BUCKET}/${JOB_NAME} \
        --runtime-version 1.2 \
        --region us-central1 \
        --config config/config.yaml \
        -- \
        --data_dir gs://${BUCKET}/data \
        --output_dir gs://${BUCKET}/${JOB_NAME} \
        --train_steps 10000

    The --train_steps option specifies the total number of training batches.

    You can control the amount of resources allocated for the training job by specifying a scale tier in the configuration file config/config.yaml. When the job starts running on multiple nodes, the same Python code in the trainer directory, which is specified with the --package-path parameter, is deployed on all nodes. The files in the trainer directory and their functions are listed in the table below:

    File Description
    `` Setup script to install additional modules on nodes.
    `` TensorFlow code to define the convolutional neural network model.
    `` TensorFlow code to run the training task. In this example, Experiment API is used to run the training loop in a distributed manner.
  2. Open the ML Engine page in the Google Cloud Platform Console to find the running job.


  3. Click the job ID to find a link to the log viewer. The example code logs its progress during training. For example, each worker node periodically logs a training loss value, which is the total loss for a single training batch. In addition, the chief worker node logs the loss and accuracy for the test set. At the end of training, the final evaluation against the test set is shown. In this example, training achieved 99.3% accuracy on the test set.

    Saving dict for global step 10008: accuracy = 0.9931, global_step = 10008, loss = 0.0315906
  4. After training, the trained model is exported to the storage bucket. You can find the storage path for the directory that contains the model binary by using the following command.

    gsutil ls gs://${BUCKET}/${JOB_NAME}/export/Servo | tail -1

    The output should look like this:


Visualizing the training process with TensorBoard

After training, the summary data is stored in gs://${BUCKET}/${JOB_NAME}, and you can visualize it with TensorBoard.

  1. Run the following command in Cloud Shell to start TensorBoard.

    tensorboard --port 8080 --logdir gs://${BUCKET}/${JOB_NAME}
  2. To open a new browser window, select Preview on port 8080 from the Web preview menu in the top-right corner of the Cloud Shell toolbar.

  3. In the new window, you can use TensorBoard to see the training summary and the visualized network graph. Press Control+C to stop TensorBoard in the Cloud Shell.

    TensorBoard shows training summary and network graph.

Deploying the trained model for predictions

You deploy the trained model for predictions using the model binary.

  1. Deploy the model and set the default version.

    # Example names (assumed here); choose your own.
    MODEL_NAME=MNIST
    VERSION_NAME=v1
    gcloud ai-platform models create --regions us-central1 ${MODEL_NAME}
    ORIGIN=$(gsutil ls gs://${BUCKET}/${JOB_NAME}/export/Servo | tail -1)
    gcloud ai-platform versions create \
        --origin ${ORIGIN} \
        --model ${MODEL_NAME} \
        ${VERSION_NAME}
    gcloud ai-platform versions set-default --model ${MODEL_NAME} ${VERSION_NAME}

    MODEL_NAME and VERSION_NAME can be arbitrary, but you can't reuse a name that already exists. The last command is not necessary for the first version because it automatically becomes the default, but it's a good practice to set the default explicitly.

    It might take a few minutes for the deployed model to become ready. Until then, prediction requests return an HTTP 503 error.

  2. Test the prediction API by using a sample request file.


    This script creates a JSON file, named request.json, containing 10 test images for predictions.

  3. Submit an online prediction request.

    gcloud ai-platform predict --model ${MODEL_NAME} --json-instances request.json

    You should get a response like this:

    CLASSES  PROBABILITIES
    7        [3.437006127094938e-21, 5.562060376991084e-16, 2.5538862785511466e-19, 7.567420805782991e-17, 2.891652426709158e-16, 2.2750016241705544e-20, 1.837758172149778e-24, 1.0, 6.893573298530907e-19, 8.065571390565747e-15]
    2        [1.2471907477623206e-23, 2.291396136267388e-25, 1.0, 1.294716955176118e-32, 3.952643278911311e-25, 3.526924652059716e-36, 3.607279481567486e-25, 1.8093850397574458e-30, 7.008172489249426e-26, 2.6986217649454554e-29]
    9        [5.124952379488745e-22, 1.917571388490136e-20, 2.02434602684524e-21, 2.1246177460406675e-18, 1.8790316524963657e-11, 2.7904309518969085e-14, 7.973171243464317e-26, 6.233734909559877e-14, 9.224547341257772e-12, 1.0]

In each row of the response, the first column (CLASSES) is the most probable digit for the given image, and the second column (PROBABILITIES) lists the probability of each digit from 0 to 9.
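The relationship between the two columns can be checked with a short sketch: the class is the index of the largest probability. The helper name `most_probable_digit` is an assumption for this example, and the list is the first row of the response shown above.

```python
def most_probable_digit(probabilities):
    """Return the index (digit) with the highest probability."""
    return max(range(len(probabilities)), key=lambda digit: probabilities[digit])


first_row = [
    3.437006127094938e-21, 5.562060376991084e-16, 2.5538862785511466e-19,
    7.567420805782991e-17, 2.891652426709158e-16, 2.2750016241705544e-20,
    1.837758172149778e-24, 1.0, 6.893573298530907e-19, 8.065571390565747e-15,
]
print(most_probable_digit(first_row))  # 7
```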

Executing predictions with Cloud Datalab

To test your predictions, create a Cloud Datalab instance, which uses interactive Jupyter Notebooks to execute code.

  1. In Cloud Shell, enter the following command to create a Cloud Datalab instance.

    datalab create mnist-datalab --zone us-central1-a
  2. From Cloud Shell you can launch the Cloud Datalab notebook listing page by clicking on Cloud Shell Web preview (Square icon in the top right).

  3. Select Change port and select Port 8081 to launch a new tab in your browser.

  4. In the Datalab application, create a new notebook by clicking on the +Notebook icon in the upper left.

  5. Paste the following text into the first cell of the new notebook.

    cat Online\ prediction\ example.ipynb > Untitled\ Notebook.ipynb
  6. Click the Run command at the top of the page to download the Online prediction example.ipynb notebook and copy its contents into the current notebook.

  7. Refresh the browser page to load the new notebook content. Then select the first cell containing the JavaScript code and click the Run command to execute it.

  8. Scroll down the page until you see the number drawing panel, and draw a number with your cursor.

    Number 3 drawn by hand.

  9. Click in the next cell to activate it and then click on the down arrow next to the Run button at the top and select Run from this Cell.

  10. The output of the prediction returns a class label and a list of probabilities. The class label indicates a prediction of the number you entered. In the list of probabilities, each index from 0-9 contains a number. The closer the number is to 1, the more likely that index matches the number you entered. In the following example, you can see that the number 3 slot highlighted in the list is very close to 1. Correspondingly, the prediction is 3.

  11. The last cell in the notebook draws a bar chart so you can see that it predicted your number.

    Bar chart shows bar at number 3.
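If you want to inspect a list of probabilities outside the notebook, a plain-text chart is enough to see which digit dominates. This sketch is a stand-in for the notebook's chart, not the notebook's own code. The function name `text_bar_chart` and the sample probabilities are made up for illustration.

```python
def text_bar_chart(probabilities, width=40):
    """Render one line per digit, with a bar proportional to its probability."""
    lines = []
    for digit, probability in enumerate(probabilities):
        bar = "#" * int(round(probability * width))
        lines.append("{} | {}".format(digit, bar))
    return "\n".join(lines)


# Hypothetical probabilities for a hand-drawn 3.
sample = [0.0, 0.0, 0.01, 0.97, 0.0, 0.01, 0.0, 0.0, 0.01, 0.0]
chart = text_bar_chart(sample)
print(chart)
```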

Cleaning up

To avoid incurring charges to your Google Cloud Platform account for the resources used in this tutorial, delete the resources when you finish. The easiest way to delete all resources is to delete the project you created for this tutorial.

  1. In the GCP Console, go to the Projects page.

    Go to the Projects page

  2. In the project list, select the project you want to delete and click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.
