This tutorial shows you how to train the Inception model on Cloud TPU.
Disclaimer
This tutorial uses a third-party dataset. Google provides no representation, warranty, or other guarantees about the validity, or any other aspects of, this dataset.
Model description
Inception v3 is an image recognition model that can attain significant accuracy. The model is the culmination of many ideas developed by multiple researchers over the years. It is based on the original paper "Rethinking the Inception Architecture for Computer Vision" by Szegedy et al.
The model has a mixture of symmetric and asymmetric building blocks, including:
- convolutions
- average pooling
- max pooling
- concatenations
- dropout
- fully connected layers
Loss is computed using Softmax.
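The following is a minimal, illustrative sketch of such a module in Keras. It is not the tutorial's implementation (the branch sizes here are invented); it only shows how symmetric (1x1, 3x3) and asymmetric (1x7, 7x1) convolutions, pooling, and a channel-wise concatenation combine:
import tensorflow as tf
from tensorflow.keras import layers

def inception_block(x):
    b1 = layers.Conv2D(64, (1, 1), padding="same", activation="relu")(x)
    b2 = layers.Conv2D(48, (1, 1), padding="same", activation="relu")(x)
    b2 = layers.Conv2D(96, (3, 3), padding="same", activation="relu")(b2)
    b3 = layers.Conv2D(64, (1, 1), padding="same", activation="relu")(x)
    b3 = layers.Conv2D(64, (1, 7), padding="same", activation="relu")(b3)
    b3 = layers.Conv2D(64, (7, 1), padding="same", activation="relu")(b3)
    b4 = layers.AveragePooling2D((3, 3), strides=1, padding="same")(x)
    b4 = layers.Conv2D(32, (1, 1), padding="same", activation="relu")(b4)
    return layers.Concatenate()([b1, b2, b3, b4])  # 64+96+64+32 = 256 channels

inputs = tf.keras.Input(shape=(299, 299, 3))
outputs = inception_block(inputs)  # shape: (None, 299, 299, 256)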
You can find more information about the model at GitHub.
The model is built using the Estimator API.
The API greatly simplifies model creation by encapsulating most low-level functions, allowing you to focus on model development rather than on the inner workings of the underlying hardware.
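For example, a bare-bones Estimator setup looks like the following schematic sketch (my_network and my_input_fn are placeholders for a real model graph and input pipeline; this is not the tutorial's code):
import tensorflow as tf

def model_fn(features, labels, mode, params):
    logits = my_network(features)  # placeholder: build your model graph here
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# The Estimator owns the training loop, checkpointing, and session management.
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/model")
estimator.train(input_fn=my_input_fn, steps=1000)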
Objectives
- Create a Cloud Storage bucket to hold your dataset and model output.
- Run the training job.
- Verify the output results.
Costs
This tutorial uses the following billable components of Google Cloud:
- Compute Engine
- Cloud TPU
- Cloud Storage
To generate a cost estimate based on your projected usage,
use the pricing calculator.
Before you begin
Before starting this tutorial, check that your Google Cloud project is correctly set up.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
- Make sure that billing is enabled for your Cloud project. Learn how to check if billing is enabled on a project.
This walkthrough uses billable components of Google Cloud. To estimate your costs, see the Cloud TPU pricing page. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.
Set up your resources
This section provides information on setting up the Cloud Storage bucket, Compute Engine VM, and Cloud TPU resources for this tutorial.
Open a Cloud Shell window.
Create a variable for your project's ID.
export PROJECT_ID=project-id
Configure Google Cloud CLI to use the project in which you want to create your Cloud TPU.
gcloud config set project ${PROJECT_ID}
The first time you run this command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make Google Cloud API calls with your credentials.
Create a Service Account for the Cloud TPU project.
gcloud beta services identity create --service tpu.googleapis.com --project $PROJECT_ID
The command returns a Cloud TPU Service Account in the following format:
service-PROJECT_NUMBER@cloud-tpu.iam.gserviceaccount.com
Create a Cloud Storage bucket using the following command. Replace bucket-name with a name for your bucket.
gsutil mb -p ${PROJECT_ID} -c standard -l us-central1 -b on gs://bucket-name
The Cloud Storage bucket stores the data you use to train your model and the training results. The ctpu up tool sets up default permissions for the Cloud TPU Service Account. If you want finer-grain permissions, review the access level permissions.
The bucket location must be in the same region as your virtual machine (VM) and your TPU node. VMs and TPU nodes are located in specific zones, which are subdivisions within a region.
Launch the Compute Engine resources using the ctpu up command.
ctpu up --project=${PROJECT_ID} \
  --zone=us-central1-b \
  --vm-only \
  --machine-type=n1-standard-8 \
  --tf-version=1.15.5 \
  --name=inception-tutorial
Command flag descriptions
project
- Your Google Cloud project ID.
zone
- The zone where you plan to create your Cloud TPU.
vm-only
- Creates the VM without creating a Cloud TPU. By default, the ctpu up command creates a VM and a Cloud TPU.
machine-type
- The machine type of the Compute Engine VM to create.
tf-version
- The version of TensorFlow that ctpu installs on the VM.
name
- The name of the Cloud TPU to create.
For more information on the CTPU utility, see CTPU Reference.
When prompted, press y to create your Cloud TPU resources.
To verify you are logged into your Compute Engine VM, your shell prompt should have changed from username@projectname to username@vm-name. If you are not connected to the Compute Engine instance, you can connect by running the following command:
gcloud compute ssh inception-tutorial --zone=us-central1-b
From this point on, a prefix of (vm)$ means you should run the command on the Compute Engine VM instance.
Create an environment variable for the storage bucket. Replace bucket-name with the name of your Cloud Storage bucket.
(vm)$ export STORAGE_BUCKET=gs://bucket-name
Create an environment variable for the TPU name.
(vm)$ export TPU_NAME=inception-tutorial
Training dataset
The training application expects your training data to be accessible in Cloud Storage. The training application also uses your Cloud Storage bucket to store checkpoints during training.
ImageNet is an image database. The images in the database are organized into a hierarchy, with each node of the hierarchy depicted by hundreds or thousands of images.
This tutorial uses a demonstration version of the full ImageNet dataset, referred to as the fake_imagenet dataset. This dataset lets you test the tutorial without the storage or time required to download and run a model against the full ImageNet database. Alternatively, you can use the full ImageNet dataset.
A DATA_DIR environment variable specifies the dataset on which to train.
The fake_imagenet dataset is only useful for understanding how to use a Cloud TPU. The accuracy numbers and saved model are not meaningful.
The fake_imagenet dataset is at the following location on Cloud Storage:
gs://cloud-tpu-test-datasets/fake_imagenet
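To confirm that the dataset is reachable from your VM (an optional check, not a tutorial step), you can list its shards with TensorFlow's gfile module, available in the TensorFlow 1.15 installation on the VM:
import tensorflow as tf

shards = tf.io.gfile.glob("gs://cloud-tpu-test-datasets/fake_imagenet/train-*")
print("found %d training shards" % len(shards))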
(Optional) Set up TensorBoard
TensorBoard offers a suite of tools designed to present TensorFlow data visually. When used for monitoring, TensorBoard can help identify bottlenecks in processing and suggest ways to improve performance.
If you don't need to monitor the model's output, you can skip the TensorBoard setup steps.
If you want to monitor the model's output and performance, follow the guide to setting up TensorBoard.
Run the model
You are now ready to train and evaluate the Inception v3 model using ImageNet data.
The Inception v3 model is pre-installed on your Compute Engine VM, in
the /usr/share/tpu/models/experimental/inception/
directory.
In the following steps, a prefix of (vm)$
means you should run the command on
your Compute Engine VM:
Set up a DATA_DIR environment variable containing one of the following values:
If you are using the fake_imagenet dataset:
(vm)$ export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
If you have uploaded a set of training data to your Cloud Storage bucket:
(vm)$ export DATA_DIR=${STORAGE_BUCKET}/data
Run the Inception v3 model:
(vm)$ python /usr/share/tpu/models/experimental/inception/inception_v3.py \
  --tpu=$TPU_NAME \
  --learning_rate=0.165 \
  --train_steps=250000 \
  --iterations=500 \
  --use_tpu=True \
  --use_data=real \
  --mode=train_and_eval \
  --train_steps_per_eval=2000 \
  --data_dir=${DATA_DIR} \
  --model_dir=${STORAGE_BUCKET}/inception
- --tpu specifies the name of the Cloud TPU. ctpu passes this name to the Compute Engine VM as an environment variable (TPU_NAME).
- --use_data specifies which type of data the program must use during training, either fake or real. The default value is fake.
- --data_dir specifies the Cloud Storage path for training input. The application ignores this parameter when you're using fake_imagenet data.
- --model_dir specifies the directory where checkpoints and summaries are stored during model training. If the folder is missing, the program creates one. When using a Cloud TPU, the model_dir must be a Cloud Storage path (gs://...). You can reuse a folder to load current checkpoint data and to store additional checkpoints. You must use the same TensorFlow version to write and load checkpoints.
What to expect
Inception v3 operates on 299x299 images. The default training batch size is 1024, which means that each iteration operates on 1024 of those images.
You can use the --mode flag to select one of three modes of operation: train, eval, and train_and_eval:
- --mode=train or --mode=eval specifies either a training-only or an evaluation-only job.
- --mode=train_and_eval specifies a hybrid job that does both training and evaluation.
Train-only jobs run for the specified number of steps defined in train_steps
and can go through the entire training set, if desired.
Train_and_eval jobs cycle through training and evaluation segments. Each training
cycle runs for train_steps_per_eval
and is followed by an evaluation job
(using the weights that have been trained up to that point).
You can calculate the number of training cycles as the floor of train_steps divided by train_steps_per_eval:
floor(train_steps / train_steps_per_eval)
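For example, with the flag values used in this tutorial's training command, the calculation works out as follows:
train_steps = 250000          # --train_steps
train_steps_per_eval = 2000   # --train_steps_per_eval
num_cycles = train_steps // train_steps_per_eval  # floor division
print(num_cycles)  # 125 training/evaluation cycles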
By default, Estimator API-based models report loss values at regular step intervals. The reporting format looks like:
step = 15440, loss = 12.6237
Discussion: TPU-specific modifications to the model
The specific modifications required to get Estimator API-based models ready for
TPUs are surprisingly minimal. The program imports the following libraries:
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
The CrossShardOptimizer function wraps the optimizer, as in:
if FLAGS.use_tpu:
optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
The function that defines the model returns an Estimator specification using:
return tpu_estimator.TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)
The main function defines an Estimator-compatible configuration using:
run_config = tpu_config.RunConfig(
master=tpu_grpc_url,
evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir,
save_checkpoints_secs=FLAGS.save_checkpoints_secs,
save_summary_steps=FLAGS.save_summary_steps,
session_config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=FLAGS.log_device_placement),
tpu_config=tpu_config.TPUConfig(
iterations_per_loop=iterations,
num_shards=FLAGS.num_shards,
per_host_input_for_training=per_host_input_for_training))
The program uses this defined configuration and a model definition function to
create an Estimator object:
inception_classifier = tpu_estimator.TPUEstimator(
model_fn=inception_model_fn,
use_tpu=FLAGS.use_tpu,
config=run_config,
params=params,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=eval_batch_size,
batch_axis=(batch_axis, 0))
Train-only jobs need only to call the train function:
inception_classifier.train(
input_fn=imagenet_train.input_fn, steps=FLAGS.train_steps)
Evaluation-only jobs get their data from available checkpoints and wait until a
new one becomes available:
for checkpoint in get_next_checkpoint():
eval_results = inception_classifier.evaluate(
input_fn=imagenet_eval.input_fn,
steps=eval_steps,
hooks=eval_hooks,
checkpoint_path=checkpoint)
When you choose the option train_and_eval, the training job and the evaluation
job alternate. During evaluation, trainable variables are loaded from the
latest available checkpoint. Training and evaluation cycles repeat as you
specify in the flags:
for cycle in range(FLAGS.train_steps // FLAGS.train_steps_per_eval):
inception_classifier.train(
input_fn=imagenet_train.input_fn, steps=FLAGS.train_steps_per_eval)
eval_results = inception_classifier.evaluate(
input_fn=imagenet_eval.input_fn, steps=eval_steps, hooks=eval_hooks)
If you used the fake\_imagenet dataset to train the model, proceed to
[clean up](#clean-up).
## Using the full Imagenet dataset {: #full-dataset }
The ImageNet dataset consists of three parts: training data, validation data,
and image labels.
The training data contains 1000 categories and 1.2 million images, packaged for
easy downloading. The validation and test data are not contained in the ImageNet
training data (duplicates have been removed).
The validation and test data consist of 150,000 photographs, collected from
[Flickr](https://www.flickr.com/) and other search engines, hand labeled with
the presence or absence of 1000 object categories. The 1000 object categories
contain both internal nodes and leaf nodes of ImageNet, but do not overlap with
each other. A random subset of 50,000 of the images with labels has been
released as validation data along with a list of the 1000 categories. The
remaining images are used for evaluation and have been released without labels.
### Steps to pre-process the full ImageNet dataset
There are five steps to preparing the full ImageNet dataset for use by a machine
learning model:
1. Verify that you have space on the download target.
1. Set up the target directories.
1. Register on the ImageNet site and request download permission.
1. Download the dataset to local disk or Compute Engine VM.
Note: Downloading the ImageNet dataset to a Compute Engine VM takes
considerably longer than downloading to your local machine (approximately 40
hours versus 7 hours). If you download the dataset to your local
machine, you must copy the files to a Compute Engine VM to pre-process them.
You must then upload the files to Cloud Storage before using them to train
your model. Copying the training and validation files from your local machine to the VM
takes about 13 hours. The recommended approach is to download the dataset to
a VM.
1. Run the pre-processing and upload script.
### Verify space requirements
Whether you download the dataset to your local machine or to a Compute Engine
VM, you need about 300 GB of space available on the download target. On a VM, you
can check your available storage with the `df -ha` command.
Note: If you use `gcloud compute` to set up your VM, it will allocate 250 GB by
default.
You can increase the size of the VM disk using one of the following methods:
* Specify the `--disk-size` flag on the `gcloud compute` command line with the
size, in GB, that you want allocated.
* Follow the Compute Engine guide to [add a disk][add-disk] to your
VM.
* Set **When deleting instance** to **Delete disk** to ensure that the
disk is removed when you remove the VM.
* Make a note of the path to your new disk. For example: `/mnt/disks/mnt-dir`.
### Set up the target directories
On your local machine or Compute Engine VM, set up the directory structure to
store the downloaded data.
* Create and export a home directory for the ImageNet dataset.
Create a directory, for example, `imagenet` under your home directory on
your local machine or VM. Under this directory, create two subdirectories:
`train` and `validation`. Export the home directory as IMAGENET_HOME:
<pre class="prettyprint lang-sh tat-dataset">
export IMAGENET_HOME=~/imagenet
</pre>
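If you prefer to script this setup, the following sketch (an optional equivalent
of the manual steps above, not part of the tutorial) creates the same layout:
<pre class="prettyprint lang-py">
import os

# Create ~/imagenet with train/ and validation/ subdirectories.
imagenet_home = os.path.expanduser("~/imagenet")
for sub in ("train", "validation"):
    os.makedirs(os.path.join(imagenet_home, sub), exist_ok=True)
</pre>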
### Register and request permission to download the dataset
* Register on the [ImageNet website](http://image-net.org/). You cannot
download the dataset until ImageNet confirms your registration and sends you
a confirmation email. If you do not get the confirmation email within a
couple of days, contact [ImageNet support](mailto:support@image-net.org) to
see why your registration has not been confirmed. Once your registration is
confirmed, you can download the dataset. The Cloud TPU tutorials that use the
ImageNet dataset use the images from the ImageNet Large Scale Visual
Recognition Challenge 2012 (ILSVRC2012).
### Download the ImageNet dataset
1. From the [LSVRC 2012 download site](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php),
go to the Images section on the page and right-click
"Training images (Task 1 & 2)" to get the URL for the largest part of the
training set. Save the URL.
Right-click "Training images (Task 3)" to get the URL for the second
training set. Save the URL.
Right-click "Validation images (all tasks)" to get the URL for the
validation dataset. Save the URL.
If you download the ImageNet files to your local machine, you need to copy
the directories on your local machine to the corresponding `$IMAGENET_HOME`
directory on your Compute Engine VM. Copying the ImageNet dataset from
local host to your VM takes approximately 13 hours.
The following command copies the files under
$IMAGENET_HOME on your local machine to <var>~/imagenet</var> on your VM (<var>username@vm-name</var>):
<pre class="prettyprint lang-sh tat-dataset">
gcloud compute scp --recurse $IMAGENET_HOME <var>username@vm-name</var>:~/imagenet
</pre>
1. From $IMAGENET_HOME, use `wget` to download the training and validation files
using the saved URLs.
The "Training images (Task 1 & 2)" file is the large training set. It is
138 GB and if you are downloading to a Compute Engine VM using the Cloud
Shell, the download takes approximately 40 hours. If the Cloud Shell loses its
connection to the VM, you can prepend `nohup` to the command or use
[screen](https://linuxize.com/post/how-to-use-linux-screen/).
<pre class="prettyprint lang-sh tat-dataset">
cd $IMAGENET_HOME
nohup wget http://image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_img_train.tar
</pre>
This command downloads a large tar file: ILSVRC2012_img_train.tar.
From $IMAGENET_HOME on the VM, extract the individual training directories
into the `$IMAGENET_HOME/train` directory using the following command. The
extraction takes between 1 and 3 hours.
<pre class="prettyprint lang-sh tat-dataset">
tar xf ILSVRC2012_img_train.tar
</pre>
Extract the individual training tar files located in the $IMAGENET_HOME/train
directory, as shown in the following script:
<pre class="prettyprint lang-sh tat-dataset">
cd $IMAGENET_HOME/train
# Extract each per-class archive (for example, n03062245.tar) into its own
# subdirectory named after the WordNet ID.
for f in *.tar; do
  d=$(basename "$f" .tar)
  mkdir "$d"
  tar xf "$f" -C "$d"
done
</pre>
Delete the tar files after you have extracted them to free up disk space.
The "Training images (Task 3)" file is 728 MB and takes just a few minutes
to download so you do not need to take precautions against losing the Cloud
Shell connection.
After the download completes, extract the file as you did for the main
training archive; the individual training directories are placed into the
existing `$IMAGENET_HOME/train` directory.
<pre class="prettyprint lang-sh tat-dataset">
wget http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_img_train_t3.tar
</pre>
When downloading the "Validation images (all tasks)" file, your Cloud Shell may disconnect.
You can use `nohup` or [screen](https://linuxize.com/post/how-to-use-linux-screen/) to
prevent Cloud Shell from disconnecting.
<pre class="prettyprint lang-sh tat-dataset">
wget http://www.image-net.org/challenges/LSVRC/2012/dd31405981ef5f776aa17412e1f0c112/ILSVRC2012_img_val.tar
</pre>
This download takes about 30 minutes. After the download completes, extract
the validation images into the `$IMAGENET_HOME/validation` directory, for
example with `tar xf ILSVRC2012_img_val.tar -C $IMAGENET_HOME/validation`.
If you downloaded the validation files to your local machine, you need to
copy the `$IMAGENET_HOME/validation` directory on your local machine to the
`$IMAGENET_HOME/validation` directory on your Compute Engine VM. This copy
operation takes about 30 minutes.
Download the labels file.
<pre class="prettyprint lang-sh tat-dataset">
wget -O $IMAGENET_HOME/synset_labels.txt \
https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_2012_validation_synset_labels.txt
</pre>
If you downloaded the labels file to your local machine, you need to copy it
from the `$IMAGENET_HOME` directory on your local machine to `$IMAGENET_HOME`
on your Compute Engine VM. This copy operation takes a few seconds.
The training subdirectory names (for example, n03062245) are "WordNet IDs"
(wnid). The [ImageNet API](https://image-net.org/download-attributes.php)
shows the mapping of WordNet IDs to their associated validation labels in the
`synset_labels.txt` file. A synset in this context is a visually similar
group of images.
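As an illustration of this mapping (a sketch, not a tutorial step; it assumes
the standard ILSVRC2012 convention in which validation files are numbered
sequentially and line i of `synset_labels.txt` labels validation image i):
<pre class="prettyprint lang-py">
import os

labels_path = os.path.join(os.environ["IMAGENET_HOME"], "synset_labels.txt")
with open(labels_path) as f:
    wnids = [line.strip() for line in f]

# Print the WordNet ID label for the first five validation images.
for i, wnid in enumerate(wnids[:5], start=1):
    print("ILSVRC2012_val_%08d.JPEG -> %s" % (i, wnid))
</pre>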
### Process the ImageNet dataset and, optionally, upload to Cloud Storage
1. Download the `imagenet_to_gcs.py` script from GitHub:
<pre class="prettyprint lang-sh tat-dataset">
wget https://raw.githubusercontent.com/tensorflow/tpu/master/tools/datasets/imagenet_to_gcs.py
</pre>
1. If you are uploading the dataset to Cloud Storage, specify the storage
bucket location to upload the ImageNet dataset:
<pre class="lang-sh prettyprint tat-client-exports">
export STORAGE_BUCKET=gs://<var>bucket-name</var>
</pre>
1. If you are uploading the dataset to your local machine or VM, specify a data
directory to hold the dataset:
<pre class="lang-sh prettyprint tat-client-exports">
<span class="no-select">(vm)$ </span>export DATA_DIR=$IMAGENET_HOME/<var>dataset-directory</var>
</pre>
1. Run the script to pre-process the raw dataset as TFRecords and upload it to
Cloud Storage using the following command:
Note: If you don't want to upload to Cloud Storage, specify `--nogcs_upload`
as another parameter and leave off the `--project` and `--gcs_output_path`
parameters.
<pre class="prettypring lang-sh tat-dataset">
python3 imagenet_to_gcs.py \
  --project=${PROJECT_ID} \
--gcs_output_path=$STORAGE_BUCKET \
--raw_data_dir=$IMAGENET_HOME \
--local_scratch_dir=$IMAGENET_HOME/tf_records
</pre>
Note: Downloading and preprocessing the data can take 10 or more hours,
depending on your network and computer speed. Do not interrupt the script.
The script generates a set of directories (for both training and validation) of
the form:
${DATA_DIR}/train-00000-of-01024
${DATA_DIR}/train-00001-of-01024
...
${DATA_DIR}/train-01023-of-01024
and
${DATA_DIR}/validation-00000-of-00128
${DATA_DIR}/validation-00001-of-00128
...
${DATA_DIR}/validation-00127-of-00128
After the data has been uploaded to your Cloud bucket, run your model and set
`--data_dir=${DATA_DIR}`.
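Before starting a long training run, you can optionally sanity-check that the
generated shards parse as TFRecords. A minimal sketch (not a tutorial step),
using the record iterator available in TensorFlow 1.15:
<pre class="prettyprint lang-py">
import os
import tensorflow as tf

# Count the records in the first training shard.
path = os.path.join(os.environ["DATA_DIR"], "train-00000-of-01024")
count = sum(1 for _ in tf.io.tf_record_iterator(path))
print("records in first shard: %d" % count)
</pre>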
## Clean up {: #clean-up }
To avoid incurring charges to your GCP account for the resources used
in this topic:
1. Disconnect from the Compute Engine VM:
<pre class="lang-sh prettyprint tat-skip">
<span class="no-select">(vm)$ </span>exit
</pre>
Your prompt should now be `username@projectname`, showing you are in the
Cloud Shell.
1. In your Cloud Shell, run `ctpu delete` with the `--zone` flag you used when
   you set up the Cloud TPU to delete your
   Compute Engine VM and your Cloud TPU:
<pre class="lang-sh prettyprint tat-resource-setup">
<span class="no-select">$ </span>ctpu delete [optional: --zone]
</pre>
Important: If you set a name for your TPU resources when you ran `ctpu up`, you must
specify that name with the `--name` flag when you run `ctpu delete` to
shut down your TPU resources.
1. Run `ctpu status` to make sure you have no instances allocated to avoid
unnecessary charges for TPU usage. The deletion might take several minutes.
A response like the one below indicates there are no more allocated
instances:
<pre class="lang-sh prettyprint tat-skip">
<span class="no-select">$ </span>ctpu status --zone=europe-west4-a
</pre>
<pre class="lang-sh prettyprint tat-skip">
2018/04/28 16:16:23 WARNING: Setting zone to "--zone=europe-west4-a"
No instances currently exist.
Compute Engine VM: --
Cloud TPU: --
</pre>
1. Run `gsutil` as shown, replacing <var>bucket-name</var> with the name of the
Cloud Storage bucket you created for this tutorial:
<pre class="lang-sh prettyprint tat-resource-setup">
<span class="no-select">$ </span>gsutil rm -r gs://<var>bucket-name</var>
</pre>
Note: For free storage limits and other pricing information, see the
[Cloud Storage pricing guide](/storage/pricing).
## Inception v4
The Inception v4 model is a deep neural network model that uses Inception v3
building blocks to achieve higher accuracy than Inception v3. It is described in
the paper "Inception-v4, Inception-ResNet and the Impact of Residual Connections
on Learning" by Szegedy et. al.
The Inception v4 model is pre-installed on your Compute Engine VM, in
the `/usr/share/tpu/models/experimental/inception/` directory.
In the following steps, a prefix of `(vm)$` means you should run the command on
your Compute Engine VM:
1. If you have TensorBoard running in your Cloud Shell tab, you need another tab
to work in. Open another tab in your Cloud Shell, and use `ctpu` in the new
shell to connect to your Compute Engine VM:
<pre class="lang-sh prettyprint">
<span class="no-select">$ </span>ctpu up --project=${PROJECT_ID} </pre>
1. Set up a `DATA_DIR` environment variable containing one of the following
values:
* If you are using the fake\_imagenet dataset:
<pre class="prettyprint lang-sh">
<span class="no-select">(vm)$ </span>export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
</pre>
* If you have uploaded a set of training data to your Cloud Storage
bucket:
<pre class="prettyprint lang-sh">
<span class="no-select">(vm)$ </span>export DATA_DIR=${STORAGE_BUCKET}/data
</pre>
1. Run the Inception v4 model:
<pre class="lang-sh prettyprint">
<span class="no-select">(vm)$ </span>python /usr/share/tpu/models/experimental/inception/inception_v4.py \
--tpu=$TPU_NAME \
--learning_rate=0.36 \
--train_steps=1000000 \
--iterations=500 \
--use_tpu=True \
--use_data=real \
--train_batch_size=256 \
--mode=train_and_eval \
--train_steps_per_eval=2000 \
--data_dir=${DATA_DIR} \
--model_dir=${STORAGE_BUCKET}/inception
</pre>
* `--tpu` specifies the name of the Cloud TPU. `ctpu`
passes this name to the Compute Engine VM as an environment
variable (`TPU_NAME`).
* `--use_data` specifies which type of data the program must use during
training, either fake or real. The default value is fake.
* `--train_batch_size` specifies a train batch size of 256. Because the
     Inception v4 model is larger than Inception v3, it must be run at a
     smaller batch size per TPU core; for example, on an 8-core TPU a global
     batch of 256 works out to 32 images per core, versus 128 per core for
     Inception v3's default batch size of 1024.
* `--data_dir` specifies the Cloud Storage path for training input.
The application ignores this parameter when you're using fake\_imagenet
data.
* `--model_dir` specifies the directory where checkpoints and
summaries are stored during model training. If the folder is missing, the
program creates one. When using a Cloud TPU, the `model_dir`
must be a Cloud Storage path (`gs://...`). You can reuse an existing
folder to load current checkpoint data and to store additional
checkpoints, as long as the previous checkpoints were created using a TPU of
     the same size and the same TensorFlow version.