使用 TF_CONFIG 取得分散式訓練詳細資料

AI Platform 會在每個訓練執行個體上設定一個名為 TF_CONFIG 的環境變數。這項服務和您的應用程式可在執行時存取 TF_CONFIG 中訓練工作的詳細資料。

本指南說明如何以 TF_CONFIG 環境變數存取訓練工作的詳細資料。這種技術很適合用於分散式訓練工作和超參數調整工作,因為這兩種工作都需要在訓練應用程式和服務之間進行額外的通訊。

TensorFlow 和 TF_CONFIG

TensorFlow 的 Estimator API 會剖析 TF_CONFIG 環境變數 (如果有的話),然後使用 TF_CONFIG 中的相關詳細資料來建構分散式訓練的屬性,包括叢集規格、工作 ID 及其他屬性。

如果您的應用程式使用 tf.estimator 進行分散式訓練,系統會自動啟動傳播屬性到叢集規格的作業,因為 AI Platform 已為您設定了 TF_CONFIG

同樣地,如果您在 AI Platform 上使用自訂容器執行分散式訓練應用程式,則 AI Platform 會設定 TF_CONFIG 並在每台機器上填入環境變數 CLUSTER_SPEC

TF_CONFIG 的格式

TF_CONFIG 環境變數是採用下列格式的 JSON 字串:

索引鍵 說明
"cluster" TensorFlow 叢集說明。這個物件會格式化為 TensorFlow 叢集規格,且可傳送至 tf.train.ClusterSpec 的建構函式。
"task" 說明執行程式碼的特定節點的工作。您可以使用此資訊,為分散式工作中的特定工作站編寫程式碼。這個項目是含有下列索引鍵的字典:
"type" 此節點執行的工作類型,可能的值為 masterworkerps
"index" 工作的索引,從零開始。大部分分散式訓練工作都有一個主要工作、一或多個參數伺服器,以及一或多個工作站。
"trial" 目前正在執行的超參數調整試驗 ID。為工作設定超參數調整時,您會設定要訓練的多個試驗。這個值讓您可以在程式碼中區別正在執行的不同試驗。ID 是含有試驗編號的字串值,從 1 開始。
"job" 啟動工作時使用的工作參數。在大部分情況下,您可以忽略這個項目,因為它會透過指令列引數複製傳送至應用程式的資料。

取得 TF_CONFIG 及設定分散式叢集規格

以下範例說明如何在應用程式中取得 TF_CONFIG 的內容。

本範例還示範如何在分散式訓練期間使用 TF_CONFIG 設定 tf.train.ClusterSpec附註:如果您的程式碼使用 TensorFlow Core API,您只需要建構 tf.train.ClusterSpec (依據 TF_CONFIG)。如果您使用 tf.estimator,TensorFlow 會自動剖析變數,並為您建構叢集規格。請參閱本頁面的 TensorFlow 和 TF_CONFIG 一節。

def train_and_evaluate(args):
  """Parse TF_CONFIG to cluster_spec and call run() method.

  TF_CONFIG environment variable is available when running using
  gcloud either locally or on cloud. It has all the information required
  to create a ClusterSpec which is important for running distributed code.

  Args:
    args (args): Input arguments.
  """

  tf_config = os.environ.get('TF_CONFIG')
  # If TF_CONFIG is not available run local.
  if not tf_config:
    return run(target='', cluster_spec=None, is_chief=True, args=args)

  tf_config_json = json.loads(tf_config)
  cluster = tf_config_json.get('cluster')
  job_name = tf_config_json.get('task', {}).get('type')
  task_index = tf_config_json.get('task', {}).get('index')

  # If cluster information is empty run local.
  if job_name is None or task_index is None:
    return run(target='', cluster_spec=None, is_chief=True, args=args)

  cluster_spec = tf.train.ClusterSpec(cluster)
  server = tf.train.Server(cluster_spec,
                           job_name=job_name,
                           task_index=task_index)

  # Wait for incoming connections forever.
  # Worker ships the graph to the ps server.
  # The ps server manages the parameters of the model.
  #
  # See a detailed video on distributed TensorFlow
  # https://www.youtube.com/watch?v=la_M6bCV91M
  if job_name == 'ps':
    server.join()
    return
  elif job_name in ['master', 'worker']:
    return run(server.target, cluster_spec, is_chief=(job_name == 'master'),
               args=args)

設定 device_filters 以進行分散式訓練

在大規模分散式訓練中,您必須確保機器之間的通訊可靠,這樣訓練應用程式才能在機器無法運作時繼續執行。

如要完成這項作業,您必須在訓練應用程式中設定裝置篩選器,確保不需啟動所有工作站,主伺服器也能運作。設定 device_filters 後,可以確保主伺服器和工作站機器只能與參數伺服器通訊。

以下範例說明如何在使用 tf.estimator 程式庫時設定 device_filters

def _get_session_config_from_env_var():
    """Returns a tf.ConfigProto instance that has appropriate device_filters
    set."""

    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))

    if (tf_config and 'task' in tf_config and 'type' in tf_config['task'] and
            'index' in tf_config['task']):
        # Master should only communicate with itself and ps
        if tf_config['task']['type'] == 'master':
            return tf.ConfigProto(device_filters=['/job:ps', '/job:master'])
        # Worker should only communicate with itself and ps
        elif tf_config['task']['type'] == 'worker':
            return tf.ConfigProto(device_filters=[
                '/job:ps',
                '/job:worker/task:%d' % tf_config['task']['index']
            ])
    return None

接下來,將這個 session_config 傳送至 tf.estimator.RunConfig,如下所示:

config = tf.estimator.RunConfig(session_config=_get_session_config_from_env_var())

在 Tensorflow 1.10 及之後的版本中,這是 tf.estimator 程式庫中的預設設定。

後續步驟

本頁內容對您是否有任何幫助?請提供意見:

傳送您對下列選項的寶貴意見...

這個網頁
TensorFlow 適用的 AI Platform