Converting an image classification dataset for use with Cloud TPU

This tutorial describes how to use the image classification data converter sample script to convert a raw image classification dataset into the TFRecord format used by Cloud TPU TensorFlow models. The image classification repository on GitHub contains the converter script and a sample implementation that you can copy and modify to do your own data conversion.

The image classification data converter sample defines two classes, ImageClassificationConfig and ImageClassificationBuilder. These classes are defined in tpu/tools/data_converter/.

ImageClassificationConfig is an abstract base class. You subclass ImageClassificationConfig to define the configuration needed to instantiate an ImageClassificationBuilder.

ImageClassificationBuilder is a TensorFlow dataset builder for image classification datasets. It is a subclass of tfds.core.GeneratorBasedBuilder. It retrieves examples from your dataset and converts them to TFRecords, which are written to the path specified by the data_dir parameter of the ImageClassificationBuilder __init__ method.

In the sample implementation, SimpleDatasetConfig subclasses ImageClassificationConfig and implements properties that define the supported modes and the number of image classes, as well as an example generator that yields a dictionary containing the image data and the image class for each example in the dataset.

The main() function creates a dataset of randomly generated image data and instantiates a SimpleDatasetConfig object specifying the number of classes and the path to the dataset on disk. Next, main() instantiates an ImageClassificationBuilder object, passing in the SimpleDatasetConfig instance. Finally, main() calls download_and_prepare(). When this method is called, the ImageClassificationBuilder instance uses the data example generator implemented by SimpleDatasetConfig to load each example and saves them to a series of TFRecord files.
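The flow that main() follows can be sketched in plain Python. This is a hypothetical, minimal stand-in rather than the sample's code: SketchConfig and SketchBuilder are invented names, the dictionary keys image and label are assumptions, and the builder keeps examples in memory where the real ImageClassificationBuilder writes TFRecord files.

```python
import random
from typing import Dict, Iterator, List


class SketchConfig:
    """Hypothetical stand-in for SimpleDatasetConfig: serves fake in-memory data."""

    def __init__(self, num_classes: int, examples_per_class: int):
        self.num_classes = num_classes
        self.examples_per_class = examples_per_class

    @property
    def supported_modes(self) -> List[str]:
        return ["train", "validation"]

    def example_generator(self, mode: str) -> Iterator[Dict]:
        # Yield one dictionary per example: fake image bytes plus an integer class.
        rng = random.Random(mode)  # deterministic fake data for each mode
        for label in range(self.num_classes):
            for _ in range(self.examples_per_class):
                image = bytes(rng.randrange(256) for _ in range(4))
                yield {"image": image, "label": label}


class SketchBuilder:
    """Hypothetical stand-in for ImageClassificationBuilder."""

    def __init__(self, config: SketchConfig):
        self.config = config
        self.splits: Dict[str, List[Dict]] = {}

    def download_and_prepare(self) -> None:
        # Drain the config's generator once per supported mode. The real builder
        # serializes each example into TFRecord files under data_dir instead.
        for mode in self.config.supported_modes:
            self.splits[mode] = list(self.config.example_generator(mode))


def main() -> SketchBuilder:
    config = SketchConfig(num_classes=3, examples_per_class=2)
    builder = SketchBuilder(config)
    builder.download_and_prepare()
    return builder


if __name__ == "__main__":
    for mode, examples in main().splits.items():
        print(mode, len(examples))
```

With these settings, each mode collects 3 × 2 = 6 examples. The point of the split is that the config only knows how to produce examples, while the builder decides how to store them.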

For a more detailed explanation, please see the Classification Converter Notebook.

Modifying the data conversion sample to load your dataset

To convert your dataset into TFRecord format, subclass the ImageClassificationConfig class and define the following properties:

  • num_labels - returns the number of image classes
  • supported_modes - returns a list of modes supported by your dataset (for example: test, train, and validate)
  • text_label_map - returns a dictionary that models the mapping between a text class label and an integer class label (SimpleDatasetConfig does not use this property, because it does not require a mapping)
  • download_path - returns the path from which to download your dataset (SimpleDatasetConfig does not use this property, because its example_generator loads the data from disk)

Implement the example_generator() generator method, which must yield a dictionary containing the image data and the image class name for each example. ImageClassificationBuilder calls example_generator() to retrieve each example and writes the examples to disk in TFRecord format.
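A subclass that follows this recipe might look like the sketch below. Because the real ImageClassificationConfig is not reproduced here, ImageConfigBase is a hypothetical abc-based stand-in, MyDatasetConfig is an invented example, and the yielded keys image and label are assumptions; check the sample's source for the exact base class contract before adapting it.

```python
import abc
from typing import Dict, Iterator, List, Optional, Tuple


class ImageConfigBase(abc.ABC):
    """Hypothetical stand-in for ImageClassificationConfig, not the real class."""

    @property
    @abc.abstractmethod
    def num_labels(self) -> int:
        """Number of image classes."""

    @property
    @abc.abstractmethod
    def supported_modes(self) -> List[str]:
        """Modes (splits) the dataset provides."""

    @abc.abstractmethod
    def example_generator(self, mode: str) -> Iterator[Dict]:
        """Yield one dictionary per example for the given mode."""


class MyDatasetConfig(ImageConfigBase):
    """Serves (image_bytes, text_label) pairs held in memory, keyed by mode."""

    def __init__(self, records: Dict[str, List[Tuple[bytes, str]]]):
        self._records = records
        # Assign integer labels to the text labels in sorted order.
        names = sorted({label for pairs in records.values() for _, label in pairs})
        self._label_map = {name: index for index, name in enumerate(names)}

    @property
    def num_labels(self) -> int:
        return len(self._label_map)

    @property
    def supported_modes(self) -> List[str]:
        return list(self._records)

    @property
    def text_label_map(self) -> Dict[str, int]:
        return self._label_map

    @property
    def download_path(self) -> Optional[str]:
        # The data is already in memory, so there is nothing to download.
        return None

    def example_generator(self, mode: str) -> Iterator[Dict]:
        # Yield the image data and the text class name for each example.
        for image_bytes, text_label in self._records[mode]:
            yield {"image": image_bytes, "label": text_label}


if __name__ == "__main__":
    config = MyDatasetConfig({"train": [(b"\x00", "cat"), (b"\x01", "dog")],
                              "validation": [(b"\x02", "cat")]})
    print(config.num_labels, config.text_label_map)
```

Declaring the required members as abstract means a subclass that forgets one of them fails at instantiation time with a TypeError, rather than partway through a conversion run.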

Running the data conversion sample

  1. Create a Cloud Storage bucket using the following command:

    gsutil mb -p ${PROJECT_ID} -c standard -l us-central1 -b on gs://bucket-name
  2. Create the VM using the ctpu command:

    ctpu up --vm-only \
      --zone=us-central1-b \
      --name=img-class-converter

    Then connect to the VM:

    gcloud compute ssh img-class-converter --zone=us-central1-b

    From this point on, a prefix of (vm)$ means you should run the command on the Compute Engine VM instance.

  3. Install the required packages.

    (vm)$ pip3 install opencv-python-headless pillow
  4. Create the following environment variables used by the script.

    (vm)$ export STORAGE_BUCKET=gs://bucket-name
    (vm)$ export CONVERTED_DIR=$HOME/tfrecords
    (vm)$ export GENERATED_DATA=$HOME/data
    (vm)$ export GCS_CONVERTED=$STORAGE_BUCKET/data_converter/image_classification/tfrecords
    (vm)$ export GCS_RAW=$STORAGE_BUCKET/image_classification/raw
    (vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
  5. Change to the data_converter directory.

    (vm)$ cd /usr/share/tpu/tools/data_converter

Running the data converter on a fake dataset

The script is located in the image_classification folder of the data converter sample. Running the script with the following parameters generates a set of fake images and converts them into TFRecords.

(vm)$ python3 image_classification/ \
  --num_classes=1000 \
  --data_path=$GENERATED_DATA \
  --generate=True \
  --num_examples_per_class_low=10 \
  --num_examples_per_class_high=11 \

Running the data converter on one of our raw datasets

  1. Create an environment variable for the location of the raw data.

    (vm)$ export GCS_RAW=gs://cloud-tpu-test-datasets/data_converter/raw_image_classification
  2. Run the script.

    (vm)$ python3 image_classification/ \
    --num_classes=1000 \
    --data_path=$GCS_RAW \
    --generate=False \

The script takes the following parameters:

  • num_classes is the number of classes in the dataset. This example uses 1000 to match the ImageNet format.
  • generate determines whether or not to generate the raw data.
  • data_path refers to the path where the data should be generated if generate=True or the path where the raw data is stored if generate=False.
  • num_examples_per_class_low and num_examples_per_class_high determine how many examples per class to generate. The script generates a random number of examples in this range.
  • save_dir is the directory where the converted TFRecords are written. It can be a Cloud Storage path or a local path on the VM; however, to train a model on Cloud TPU, the data must be stored in Cloud Storage.
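The effect of num_examples_per_class_low and num_examples_per_class_high on dataset size can be sketched as follows. Whether the script treats the upper bound as inclusive is not stated here, so the hypothetical helper examples_per_class makes that an explicit parameter; with an exclusive bound, low=10 and high=11 always give 10 examples per class.

```python
import random
from typing import List


def examples_per_class(num_classes: int, low: int, high: int,
                       inclusive_high: bool = False, seed: int = 0) -> List[int]:
    """Draw a random example count per class from [low, high) or [low, high].

    Hypothetical helper: the real script's bound semantics are an assumption.
    With an exclusive bound, high must be greater than low.
    """
    rng = random.Random(seed)
    upper = high if inclusive_high else high - 1
    return [rng.randint(low, upper) for _ in range(num_classes)]


counts = examples_per_class(num_classes=1000, low=10, high=11)
print(sum(counts))  # 10000 with an exclusive upper bound
```

With an inclusive upper bound, the same flags would produce between 10,000 and 11,000 fake images in total.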

Renaming and moving the TFRecords to Cloud Storage

The following example uses the converted data with the TF 1.x ResNet model.

  1. Rename the TFRecords to the same format as ImageNet TFRecords:

    (vm)$ cd $CONVERTED_DIR/image_classification_builder/Simple/0.1.0/
    (vm)$ rename -v 's/image_classification_builder-(\w+)\.tfrecord/$1/g' *
  2. Copy the TFRecords to Cloud Storage:

    (vm)$ gsutil -m cp train* $GCS_CONVERTED
    (vm)$ gsutil -m cp validation* $GCS_CONVERTED
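To see what the rename step does, the same Perl-style substitution can be reproduced with Python's re module (Perl's $1 becomes \1). The sample filenames below are assumptions based on the TensorFlow Datasets convention of name-split.tfrecord-NNNNN-of-NNNNN; list the directory to confirm the actual names before renaming.

```python
import re

# Same pattern as the rename command: capture the split name after the prefix.
PATTERN = re.compile(r"image_classification_builder-(\w+)\.tfrecord")


def imagenet_style_name(filename: str) -> str:
    """Drop the builder prefix and the .tfrecord suffix, keeping the shard suffix."""
    return PATTERN.sub(r"\1", filename)


# Hypothetical filenames for illustration only.
for name in ["image_classification_builder-train.tfrecord-00000-of-00004",
             "image_classification_builder-validation.tfrecord-00001-of-00002"]:
    print(name, "->", imagenet_style_name(name))
```

Because \w does not match a period, the capture group stops at the split name, so image_classification_builder-train.tfrecord-00000-of-00004 becomes train-00000-of-00004, which is why the train* and validation* globs in the copy step match the renamed files.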

Run ResNet on the generated dataset

  1. Create the TPU:

    (vm)$ export TPU_NAME=imageclassificationconverter
    (vm)$ ctpu up --tpu-only --name=${TPU_NAME} --zone=us-central1-b

  2. Create an environment variable for the Cloud Storage directory where the model output is stored.

    (vm)$ export MODEL_BUCKET=${STORAGE_BUCKET}/image_classification_converter
  3. Run the model:

    (vm)$ cd /usr/share/tpu/models/official/resnet
    (vm)$ python3 --tpu=${TPU_NAME}