TensorFlow Trainer Development Considerations

Cloud Machine Learning Engine can run an existing TensorFlow training application with little or no alteration. Developing a trainer is a complex process that is largely outside of the scope of this documentation. You can start learning TensorFlow by working through its getting started guide.

Once you understand some TensorFlow fundamentals, the best way to learn what goes into a training application is to study a good example. The samples page of this documentation includes information about TensorFlow samples that have been developed specifically to work well with Cloud ML Engine. We suggest the two samples that use U.S. census data to predict income level as primary sources for TensorFlow best practices when working with Cloud ML Engine. The two trainers have nearly identical functionality, but one uses the lower-level APIs of TensorFlow core and the other uses the higher-level Estimator-based APIs. You can learn a lot by reading these two samples and comparing the ways they accomplish the same tasks.

Here are some important things to look for as you study the samples:

  • How they work with command-line arguments to get important information that may change from one training job to another. Cloud ML Engine passes the arguments that you specify when you start a training job to each replica of your trainer that it runs in the cloud. Command-line arguments are the primary mechanism for communicating with your trainer at execution time (see the argument-parsing sketch after this list).

  • How they use the TF_CONFIG environment variable to set up a distributed processing cluster. This is the method by which Cloud ML Engine communicates job information to the individual replicas of your trainer that run on the allocated training instances.

  • How they manage distributed processing to account for the different task types (master, parameter server, and worker) in one application.

  • How they accommodate different stages in the training process (notably training, evaluation, and export) by using checkpoints and variations of the computation graph.

  • How they define input and output data, both for training the model and for exporting it.
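
As a concrete illustration of the first point, here is a minimal sketch of how a trainer might read its command-line arguments with argparse. The flag names used here (--train-files, --job-dir, --train-steps) are illustrative examples, not a fixed contract with Cloud ML Engine; your trainer defines whatever arguments it needs.

import argparse

parser = argparse.ArgumentParser()
# Illustrative flags only; replace with the arguments your trainer expects.
parser.add_argument('--train-files', nargs='+',
                    help='Paths to the training data')
parser.add_argument('--job-dir',
                    help='Location to write checkpoints and exported models')
parser.add_argument('--train-steps', type=int, default=1000,
                    help='Number of training steps to run')
# parse_known_args ignores any extra flags the service may append.
args, _ = parser.parse_known_args()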

Use the TF_CONFIG environment variable

Distributed training and hyperparameter tuning both require extra communication between your trainer and the service. Cloud Machine Learning Engine sets an environment variable called TF_CONFIG on each training instance so that your trainer can access it while running.

TF_CONFIG is a JSON string, following the structure shown here:

Key         Description
"cluster"   The TensorFlow cluster description. This object is formatted as a TensorFlow cluster specification, and can be passed to the constructor of tf.train.ClusterSpec.
"task"      Describes the task of the particular node on which your code is running. Using this information enables you to write code for specific workers in a distributed job. This entry is a dictionary with the following keys:
  "type"    The type of task performed by this node. Will be master, worker, or ps.
  "index"   The zero-based index of the task. Most distributed training jobs have a single master task, one or more parameter servers, and one or more workers.
  "trial"   The identifier of the trial being run. When you configure hyperparameter tuning for your job, you set a number of trials to train. This value gives you a way to differentiate between trials in your code. The identifier is a string containing the trial number, starting at 1.
"job"       The job parameters you used when you initiated the job. In most cases, you can ignore this entry, as it replicates the data passed to your trainer application through its command-line arguments.

Getting TF_CONFIG

The following example shows how to get the contents of TF_CONFIG in your trainer code. The code parses the JSON string into a Python dictionary, extracts the cluster and task information, and uses them to decide whether to run locally or as part of a distributed cluster.

import json
import os

import tensorflow as tf


def dispatch(*args, **kwargs):
  """Parses TF_CONFIG to a cluster spec and calls the run() method.

  The TF_CONFIG environment variable is available when running with
  gcloud, either locally or in the cloud. It contains all the information
  required to create a ClusterSpec, which is needed for running
  distributed code.
  """

  tf_config = os.environ.get('TF_CONFIG')

  # If TF_CONFIG is not available, run locally.
  if not tf_config:
    return run(target='', cluster_spec=None, is_chief=True, *args, **kwargs)

  tf_config_json = json.loads(tf_config)

  cluster = tf_config_json.get('cluster')
  job_name = tf_config_json.get('task', {}).get('type')
  task_index = tf_config_json.get('task', {}).get('index')

  # If cluster information is empty, run locally.
  if job_name is None or task_index is None:
    return run(target='', cluster_spec=None, is_chief=True, *args, **kwargs)

  cluster_spec = tf.train.ClusterSpec(cluster)
  server = tf.train.Server(cluster_spec,
                           job_name=job_name,
                           task_index=task_index)

  # Wait for incoming connections forever.
  # The worker ships the graph to the ps server,
  # which manages the parameters of the model.
  #
  # See a detailed video on distributed TensorFlow:
  # https://www.youtube.com/watch?v=la_M6bCV91M
  if job_name == 'ps':
    server.join()
    return
  elif job_name in ['master', 'worker']:
    return run(server.target, cluster_spec, is_chief=(job_name == 'master'),
               *args, **kwargs)
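
The run() function called by dispatch() is defined elsewhere in the sample. As a rough, hedged sketch of the shape such a function can take (not the sample's actual implementation), it typically places variables on the parameter servers, builds the graph, and opens a session against server.target, with the master replica acting as chief:

import tensorflow as tf

def run(target, cluster_spec, is_chief, *args, **kwargs):
  """Illustrative skeleton only; replace the toy loss with a real model."""
  # Assign variables to parameter servers and ops to the local device.
  device_fn = (tf.train.replica_device_setter(cluster=cluster_spec)
               if cluster_spec else None)
  with tf.Graph().as_default(), tf.device(device_fn):
    global_step = tf.train.get_or_create_global_step()
    loss = tf.Variable(1.0, name='loss')  # placeholder for a real loss tensor
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)
    hooks = [tf.train.StopAtStepHook(last_step=1000)]
    # The chief (master) coordinates initialization and checkpointing.
    with tf.train.MonitoredTrainingSession(master=target,
                                           is_chief=is_chief,
                                           hooks=hooks) as session:
      while not session.should_stop():
        session.run(train_op)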
