Training ResNet on Cloud TPU

This tutorial shows you how to train the TensorFlow ResNet-50 model on Cloud TPU. You can apply the same pattern to other TPU-optimized image classification models that use TensorFlow and the ImageNet dataset.


This tutorial uses a third-party dataset. Google provides no representation, warranty, or other guarantees about the validity, or any other aspects of, this dataset.

Model description

The model in this tutorial is based on Deep Residual Learning for Image Recognition, which first introduced the residual network (ResNet) architecture. This tutorial uses the 50-layer variant, known as ResNet-50.

This tutorial uses tf.estimator to train the model. tf.estimator is a high-level TensorFlow API and is the recommended way to build and run a machine learning model on Cloud TPU. The API simplifies the model development process by hiding most of the low-level implementation, making it easier to switch between TPU and other platforms such as GPU or CPU.

Before you begin

Before starting this tutorial, check that your Cloud project is correctly set up, and create a Compute Engine VM and a TPU resource.

This section is identical to the Quickstart guide. If you already completed the Quickstart without deleting your VM and TPU resource, you can skip directly to getting the data.

Set up your Cloud project

  1. Sign in to your Google Account.

    If you don't already have one, sign up for a new account.

  2. Select or create a GCP project.

    Go to the Manage resources page

  3. Make sure that billing is enabled for your project.

    Learn how to enable billing

  4. This walkthrough uses billable components of Google Cloud Platform. Check the Cloud TPU pricing page to estimate your costs, and follow the instructions to clean up resources when you've finished with them.

  5. Visit the TPU quota page to see your project's current Cloud TPU quota. By default, you have quota for at least 2 Cloud TPUs (16 cores).

    Go to TPU quota page

    If you need additional quota, submit a quota request form to describe your needs.

  6. Install the gcloud command-line tool via the Cloud SDK.

    Install the Cloud SDK

  7. Install the gcloud beta components, which include the commands necessary to create Cloud TPU resources.

    $ gcloud components install beta

Create a Compute Engine VM and a TPU resource

This section shows you how to create a Compute Engine VM and create a TPU resource from your local machine.

  1. Use the gcloud command-line tool to specify your Cloud Platform project:

    $ gcloud config set project your-cloud-project

    where your-cloud-project is the name of your Cloud Platform project with access to TPU quota.

  2. Specify the zone where you plan to create your VM and TPU resource. For this tutorial, use the us-central1-c zone:

    $ gcloud config set compute/zone us-central1-c

    For reference, Cloud TPU is available in the following zones:

    • us-central1-b
    • us-central1-c
  3. Create the Compute Engine VM instance:

    $ gcloud compute instances create tpu-demo-vm \
      --machine-type=n1-standard-4 \
      --image-project=ml-images \
      --image-family=tf-1-8 \
      --scopes=cloud-platform

    • tpu-demo-vm is a name for identifying the VM instance that you're creating.

    • --machine-type=n1-standard-4 is a standard machine type with 4 virtual CPUs and 15 GB of memory. See Machine Types for more machine types.

    • --image-project=ml-images is a shared collection of images that makes the tf-1-8 image available for your use.

    • --image-family=tf-1-8 is an image with the required pip package for TensorFlow.

    • --scopes=cloud-platform allows the VM to access Cloud Platform APIs.

  4. Create a new Cloud TPU resource. For this example, name the resource demo-tpu. Keep in mind that billing begins as soon as the TPU is created, until the time it is deleted. (Check the Cloud TPU pricing page to estimate your costs.) If you are using a dataset that requires a substantial download and processing phase, hold off on running this command until you are ready to use the TPU:

    $ gcloud beta compute tpus create demo-tpu \
      --range=10.240.1.0/29 --version=1.8

    --range specifies the address range of the created TPU resource and can be any value in 10.240.*.*/29. For this example, use 10.240.1.0/29.
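Before creating the TPU, you can sanity-check a candidate --range value locally. This sketch uses Python's standard ipaddress module, with 10.240.1.0/29 as a hypothetical choice; it only checks that the block is a /29 inside 10.240.0.0/16, as described above:

```python
import ipaddress

def is_valid_tpu_range(cidr: str) -> bool:
    """Check that cidr is a /29 block inside 10.240.0.0/16."""
    net = ipaddress.ip_network(cidr)
    return net.prefixlen == 29 and net.subnet_of(
        ipaddress.ip_network("10.240.0.0/16"))

print(is_valid_tpu_range("10.240.1.0/29"))   # True
print(is_valid_tpu_range("10.250.1.0/29"))   # False: outside 10.240.*.*
```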

Get the data

Below are the instructions for using a randomly generated fake dataset to test the model. Alternatively, you can use the full ImageNet dataset.

Use the fake dataset for testing purposes

The fake dataset is at this location on Cloud Storage:

gs://cloud-tpu-test-datasets/fake_imagenet
Note that the fake dataset is only useful for understanding how to use a Cloud TPU, and validating end-to-end performance. The accuracy numbers and saved model will not be meaningful.

Grant storage access to the TPU

You need to give your TPU read/write access to Cloud Storage objects. To do that, you must grant the required access to the service account used by the TPU. Follow these steps to find the TPU service account and grant the necessary access:

  1. List your TPUs to find their names:

    $ gcloud beta compute tpus list
  2. Use the describe command to find the service account of your TPU, where demo-tpu is the name of your TPU resource:

    $ gcloud beta compute tpus describe demo-tpu
  3. Copy the name of the TPU service account from the output of the describe command. The name has the format of an email address.

  4. Log in to the Google Cloud Platform Console and choose the project in which you’re using the TPU.

  5. Choose IAM & Admin > IAM.

  6. Click the +Add button to add a member to the project.

  7. Enter the name of the TPU service account in the Members text box.

  8. Click the Roles dropdown list.

  9. Enable the following roles:

    • Project > Viewer
    • Logging > Logs Writer
    • Storage > Storage Object Admin
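If you prefer scripting over the console, the same grants can be made with gcloud projects add-iam-policy-binding; the console roles above correspond to the IAM role IDs roles/viewer, roles/logging.logWriter, and roles/storage.objectAdmin. This sketch only builds the command strings; the project name and service-account address below are placeholders for the values from your own project and the describe output:

```python
# IAM role IDs matching the console roles listed above.
ROLES = [
    "roles/viewer",               # Project > Viewer
    "roles/logging.logWriter",    # Logging > Logs Writer
    "roles/storage.objectAdmin",  # Storage > Storage Object Admin
]

def iam_binding_commands(project: str, service_account: str) -> list:
    """Build one gcloud binding command per role."""
    return [
        "gcloud projects add-iam-policy-binding {} "
        "--member=serviceAccount:{} --role={}".format(
            project, service_account, role)
        for role in ROLES
    ]

for cmd in iam_binding_commands("your-cloud-project",
                                "tpu-service-account@example.com"):
    print(cmd)
```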

Run the ResNet-50 model

You are now ready to train and evaluate the ResNet-50 model.

  1. Remotely connect to the created Compute Engine VM:

    $ gcloud compute ssh tpu-demo-vm -- -L 6006:localhost:6006
    The -L 6006:localhost:6006 flag forwards the TensorBoard port from the VM to your local machine.

  2. Run the following command, where demo-tpu is the name of the TPU resource you created earlier:

    (vm)$ export TPU_NAME=demo-tpu

  3. Run the following command, replacing [DATA_DIR] with gs://cloud-tpu-test-datasets/fake_imagenet if you are using the fake dataset, or with the path to the Cloud Storage bucket containing your training data:

    (vm)$ export DATA_DIR=[DATA_DIR]

  4. Create a Cloud Storage bucket to store the trained model and training logs. Cloud Storage bucket names must be unique:

    (vm)$ export MODEL_DIR=gs://your_resnet_bucket
    (vm)$ gsutil mb $MODEL_DIR
    This bucket stores checkpoints, so reusing an existing bucket resumes training from the most recent checkpoint.

  5. The ResNet-50 model is pre-installed on the tf-1-8 VM image you are using for this tutorial. Navigate to the directory:

    (vm)$ cd /usr/share/tpu/models/official/resnet/

  6. Run the training script:

    (vm)$ python resnet_main.py \
      --tpu_name=$TPU_NAME \
      --data_dir=$DATA_DIR \
      --model_dir=$MODEL_DIR
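Cloud Storage bucket names must be globally unique and follow the service's naming rules. Before creating the model bucket in step 4, you can sanity-check a candidate name locally; this is a simplified sketch of the basic rules (3-63 characters; lowercase letters, digits, dashes, underscores, and dots; starting and ending with a letter or digit), not the full specification:

```python
import re

# Simplified Cloud Storage bucket-name rules: 3-63 chars drawn from
# lowercase letters, digits, '-', '_', '.', with an alphanumeric
# first and last character.
_BUCKET_RE = re.compile(r"^[a-z0-9][a-z0-9._-]{1,61}[a-z0-9]$")

def looks_like_valid_bucket_name(name: str) -> bool:
    return bool(_BUCKET_RE.match(name))

print(looks_like_valid_bucket_name("your_resnet_bucket"))  # True
print(looks_like_valid_bucket_name("Bad-Name"))            # False: uppercase
```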

What to expect

The procedure above trains the ResNet-50 model for 100 epochs and evaluates it at a fixed interval of steps. With the default flags, the model should train to above 75% accuracy.
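For a rough sense of what 100 epochs means in training steps: the full ImageNet training set contains 1,281,167 images, so at a hypothetical global batch size of 1,024 (the actual batch size depends on the flags passed to the training script), the run works out to roughly 125,000 steps:

```python
import math

NUM_TRAIN_IMAGES = 1281167   # full ImageNet training set
BATCH_SIZE = 1024            # hypothetical global batch size
EPOCHS = 100

steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
print(steps_per_epoch)  # 1252
print(total_steps)      # 125200
```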

TPU-specific modifications to the ResNet-50 model

The ResNet code in this tutorial uses TPUEstimator, which is based on the high-level Estimator API. A few code changes are required to convert an Estimator-based model to a TPUEstimator-based model for training.

Import the following libraries:

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

Use the CrossShardOptimizer function to wrap the optimizer, such as:

if FLAGS.use_tpu:
  optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

Define the model_fn and return a TPUEstimator specification, for example:

return tpu_estimator.TPUEstimatorSpec(
    mode=mode, loss=loss, train_op=train_op)

To run the model on TPU, you need the TPU gRPC address. Define an Estimator-compatible configuration that uses it (here tpu_grpc_address is assumed to hold that address), for example:

config = tpu_config.RunConfig(
    master=tpu_grpc_address,
    model_dir=FLAGS.model_dir,
    tpu_config=tpu_config.TPUConfig(
        iterations_per_loop=FLAGS.iterations,
        num_shards=FLAGS.num_shards))

Create the Estimator object using the configuration and model data, for example:

estimator = tpu_estimator.TPUEstimator(
    model_fn=model_fn,
    config=config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.batch_size)

Finally, the Python program runs the estimator.train function for the number of iterations defined in the configuration:

estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)

Clean up

When you've finished with the tutorial, clean up the VM and TPU resource to avoid incurring extra charges to your Google Cloud Platform account.

If you haven't set the project and zone for this session, do so before running the cleanup procedure.

  1. Use the gcloud command-line tool to delete your Cloud TPU resource:

    (vm)$ gcloud beta compute tpus delete demo-tpu
  2. Disconnect from the Compute Engine VM instance:

    (vm)$ exit
  3. Use the gcloud command-line tool to delete your Compute Engine instance:

    $ gcloud compute instances delete tpu-demo-vm
  4. Go to the VPC Networking page in the Google Cloud Platform Console.

    Go to the VPC Networking page
  5. Select the VPC network that Google automatically created as part of the Cloud TPU setup. The peering entry starts with cp-to-tp-peering in the ID.

  6. At the top of the VPC Networking page, click Delete to delete the selected VPC network.

  7. Go to the Network Routes page in the Google Cloud Platform Console.

    Go to the Network Routes page
  8. Select the route that Google automatically created as part of the Cloud TPU setup. The peering entry starts with peering-route in the ID.

  9. At the top of the Network Routes page, click Delete to delete the selected route.

When you've finished examining the data, use the gsutil command to delete any Cloud Storage buckets you created during this tutorial. (See the Cloud Storage pricing guide for free storage limits and other pricing information.) Replace your-bucket-name with the name of your Cloud Storage bucket:

$ gsutil rm -r gs://your-bucket-name

Using the full ImageNet dataset

You need about 300 GB of space available on your local machine or VM to run the script used in this section.

If you decide to process the data on your Compute Engine VM, follow these steps to add disk space to the VM:

  • Follow the Compute Engine guide to add a disk to your VM.
  • Set the disk size to 300 GB or more.
  • Set When deleting instance to Delete disk to ensure that the disk is removed when you remove the VM.
  • Make a note of the path to your new disk. For example: /mnt/disks/mnt-dir.

Download and convert the ImageNet data:

  1. Sign up for an ImageNet account. Remember the username and password you used to create the account.

  2. Create a Cloud Storage bucket for the dataset. Cloud Storage bucket names must be unique:

    $ export DATA_DIR=gs://your_bucket_name
    $ gsutil mb $DATA_DIR
    Alternatively, you can create the bucket via the Google Cloud Platform Console.

  3. Download the script from GitHub:

    $ wget

  4. Set a SCRATCH_DIR variable to contain the script's working files. The variable must specify a location on your local machine or on your Compute Engine VM. For example, on your local machine:

    $ SCRATCH_DIR=./imagenet_tmp_files

    Or if you're processing the data on the VM:

    (vm)$ SCRATCH_DIR=/mnt/disks/mnt-dir/imagenet_tmp_files

  5. Run the script to download, format, and upload the ImageNet data to the bucket. Replace [USERNAME] and [PASSWORD] with the username and password you used to create your ImageNet account.

    $ pip install google-cloud-storage
    $ python imagenet_to_gcs.py \
      --project=$PROJECT \
      --gcs_output_path=$DATA_DIR \
      --local_scratch_dir=$SCRATCH_DIR \
      --imagenet_username=[USERNAME] \
      --imagenet_access_key=[PASSWORD]
Note: Downloading and preprocessing the data can take up to half a day, depending on your network and computer speed. Do not interrupt the script.

When the script finishes processing, a message like the following appears:

2018-02-17 14:30:17.287989: Finished writing all 1281167 images in data set.
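If you capture the script's log, you can check programmatically that the reported count matches the full ImageNet training set; a small sketch using a regular expression (the log line format is taken from the example above):

```python
import re

log_line = ("2018-02-17 14:30:17.287989: "
            "Finished writing all 1281167 images in data set.")

# Extract the image count from the completion message.
match = re.search(r"Finished writing all (\d+) images", log_line)
assert match, "completion message not found"
num_images = int(match.group(1))
print(num_images == 1281167)  # True
```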

The script produces a series of directories (for both training and validation) of the form:



