Using TF_CONFIG for Distributed Training Details

AI Platform sets an environment variable called TF_CONFIG on each training instance. Your application can access the details of your training job in TF_CONFIG while the job runs.

This guide shows you how to access details of your training jobs in the TF_CONFIG environment variable. This technique is useful for distributed training jobs and for hyperparameter tuning jobs, as they both require extra communication between your training application and the service.

TensorFlow and TF_CONFIG

TensorFlow's Estimator API parses the TF_CONFIG environment variable, if it is present, and uses its details to construct properties for distributed training, such as the cluster spec, the task type, and the task index.

If your application uses tf.estimator for distributed training, this happens automatically: AI Platform sets TF_CONFIG for you, and the Estimator API builds the cluster spec from it.

Similarly, if you run your distributed training application on AI Platform with a custom container, then AI Platform sets TF_CONFIG and populates an environment variable, CLUSTER_SPEC, on each machine.

The format of TF_CONFIG

The TF_CONFIG environment variable is a JSON string with the following format:

"cluster": The TensorFlow cluster description. This object is formatted as a TensorFlow cluster specification, and can be passed to the constructor of tf.train.ClusterSpec.

"task": Describes the task of the particular node on which your code is running. You can use this information to write code for specific workers in a distributed job. This entry is a dictionary with the following keys:

  "type": The type of task performed by this node. Possible values are master, worker, and ps.

  "index": The zero-based index of the task. Most distributed training jobs have a single master task, one or more parameter servers, and one or more workers.

  "trial": The identifier of the hyperparameter tuning trial currently running. When you configure hyperparameter tuning for your job, you set a number of trials to train. This value gives you a way to differentiate in your code between trials that are running. The identifier is a string value containing the trial number, starting at 1.

"job": The job parameters you used when you initiated the job. In most cases, you can ignore this entry, as it replicates the data passed to your application through its command-line arguments.

Getting TF_CONFIG and setting the distributed cluster spec

The example below shows how to get the contents of TF_CONFIG in your application.

The example also shows how to use TF_CONFIG to build a tf.train.ClusterSpec during distributed training.

Note: You only need to build a tf.train.ClusterSpec from TF_CONFIG if your code uses the TensorFlow core APIs. If you use tf.estimator, TensorFlow parses the variable and builds the cluster spec for you. See the section about TensorFlow and TF_CONFIG on this page.

import json
import os

import tensorflow as tf


def train_and_evaluate(args):
  """Parse TF_CONFIG to cluster_spec and call run() method.

  The TF_CONFIG environment variable is available when running with gcloud,
  either locally or on AI Platform. It contains all the information required
  to create a ClusterSpec, which is needed to run distributed code.

  Args:
    args (args): Input arguments.
  """

  tf_config = os.environ.get('TF_CONFIG')
  # If TF_CONFIG is not available, run locally.
  if not tf_config:
    return run(target='', cluster_spec=None, is_chief=True, args=args)

  tf_config_json = json.loads(tf_config)
  cluster = tf_config_json.get('cluster')
  job_name = tf_config_json.get('task', {}).get('type')
  task_index = tf_config_json.get('task', {}).get('index')

  # If cluster information is empty, run locally.
  if job_name is None or task_index is None:
    return run(target='', cluster_spec=None, is_chief=True, args=args)

  cluster_spec = tf.train.ClusterSpec(cluster)
  server = tf.train.Server(cluster_spec,
                           job_name=job_name,
                           task_index=task_index)

  # Wait for incoming connections forever.
  # Worker ships the graph to the ps server.
  # The ps server manages the parameters of the model.
  #
  # See a detailed video on distributed TensorFlow
  # https://www.youtube.com/watch?v=la_M6bCV91M
  if job_name == 'ps':
    server.join()
    return
  elif job_name in ['master', 'worker']:
    return run(server.target, cluster_spec, is_chief=(job_name == 'master'),
               args=args)
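
As a usage sketch, you might call train_and_evaluate from a standard entry point that parses command-line arguments. The flags below are placeholders only, and run() is the training function your own application defines; replace them with whatever your trainer actually uses.

import argparse

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Placeholder flags; replace these with the arguments your trainer uses.
  parser.add_argument('--train-files', required=True)
  parser.add_argument('--job-dir', required=True)
  args = parser.parse_args()

  train_and_evaluate(args)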

Setting device_filters for distributed training

In large-scale distributed training, you need to ensure that communication between machines is reliable, so that your training application is resilient to machine failures.

To accomplish this, set device filters in your training application so that the master does not rely on all of the workers being active. By setting device_filters, you ensure that the master and the workers communicate only with the parameter servers and with themselves.

The example below shows how to set device_filters when using the tf.estimator library:

import json
import os

import tensorflow as tf


def _get_session_config_from_env_var():
    """Returns a tf.ConfigProto instance that has appropriate device_filters
    set."""

    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))

    if (tf_config and 'task' in tf_config and 'type' in tf_config['task'] and
            'index' in tf_config['task']):
        # Master should only communicate with itself and ps
        if tf_config['task']['type'] == 'master':
            return tf.ConfigProto(device_filters=['/job:ps', '/job:master'])
        # Worker should only communicate with itself and ps
        elif tf_config['task']['type'] == 'worker':
            return tf.ConfigProto(device_filters=[
                '/job:ps',
                '/job:worker/task:%d' % tf_config['task']['index']
            ])
    return None

Next, pass this session_config to tf.estimator.RunConfig as follows:

config = tf.estimator.RunConfig(session_config=_get_session_config_from_env_var())

These device filters are set by default in the tf.estimator library from TensorFlow 1.10 onwards.
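
As a sketch of how the pieces fit together, the session config can be passed through a RunConfig into tf.estimator.train_and_evaluate. The model_fn, train_input_fn, and eval_input_fn names below are placeholders for functions defined elsewhere in your trainer, not part of the sample above.

# Sketch only: model_fn, train_input_fn, and eval_input_fn are assumed to be
# defined elsewhere in your trainer.
run_config = tf.estimator.RunConfig(
    session_config=_get_session_config_from_env_var())

estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)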
