Cloud TPU 上の TPUEstimator API

このドキュメントでは、Cloud TPU で TPUEstimator API を使用する方法について説明します。TPUEstimator は、多数の低レベルのハードウェア固有の詳細を処理することにより、Cloud TPU でのモデルの実行を簡素化します。

TPUEstimator を使用して作成されたモデルは、CPU、GPU、単一の TPU デバイス、TPU Pod で動作します。通常、コードの変更は必要ありません。また、TPUEstimator を使用すると、一部の最適化が自動的に行われるため、パフォーマンスを最大限に高めることが容易になります。

TPU ハードウェア上での機械学習ワークロードの動作に関する全般的な情報については、システム アーキテクチャのドキュメントをご覧ください。

標準の TensorFlow Estimator API

標準の TensorFlow Estimator API は以下を提供します。

  • Estimator.train() - 一定数のステップで指定された入力についてモデルをトレーニングします。
  • Estimator.evaluate() - テストセットでモデルを評価します。
  • Estimator.predict() - トレーニングされたモデルを使用して推論を実行します。
  • Estimator.export_savedmodel() - 提供するモデルをエクスポートします。

また、Estimator には、チェックポイントの保存と復元、TensorBoard のサマリーの作成など、トレーニング ジョブに共通するデフォルトの動作が含まれています。

Estimator では、TensorFlow グラフのモデルおよび入力部分に対応する model_fninput_fn を記述する必要があります。

TPUEstimator プログラミング モデル

TPUEstimator は計算(model_fn)をラップして、使用可能なすべての Cloud TPU コアに分散します。学習率はバッチサイズで調整する必要があります。

  • input_fn 関数は、リモートホスト CPU で実行する入力パイプラインをモデル化します。tf.data は、プログラマー向けガイドで説明されているように、入力演算のプログラミングに使用されます。各呼び出しは、1 台のデバイスへのグローバル バッチの入力を処理します。シャードのバッチサイズは params['batch_size'] から取得されます。ワンポイント: 最適なパフォーマンスを得るには、テンソルの代わりにデータセットを返します。

  • model_fn 関数は、複製して TPU に分散する計算をモデル化します。この計算には、Cloud TPU でサポートされている演算のみが含まれます。使用可能な演算のリストについては、TensorFlow 演算をご覧ください。

TPUEstimator を使用したトレーニング例

次のコードは、TPUEstimator を使用した MNIST のトレーニングを示しています。

def model_fn(features, labels, mode, params):
  """A simple CNN."""
  del params  # unused

  input_layer = tf.reshape(features, [-1, 28, 28, 1])
  conv1 = tf.layers.conv2d(
      inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same",
      activation=tf.nn.relu)
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
  conv2 = tf.layers.conv2d(
      inputs=pool1, filters=64, kernel_size=[5, 5],
      padding="same", activation=tf.nn.relu)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
  dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu)
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
  logits = tf.layers.dense(inputs=dropout, units=10)
  onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)

  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels, logits=logits)

  learning_rate = tf.train.exponential_decay(
      FLAGS.learning_rate, tf.train.get_global_step(), 100000, 0.96)

  optimizer = tpu_optimizer.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=learning_rate))

  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def make_input_fn(filename):
  """Returns an `input_fn` for train and eval."""

  def input_fn(params):
    """An input_fn to parse 28x28 images from filename using tf.data."""
    batch_size = params["batch_size"]

    def parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = tf.parse_single_example(
          serialized_example,
          features={
              "image_raw": tf.FixedLenFeature([], tf.string),
              "label": tf.FixedLenFeature([], tf.int64),
          })
      image = tf.decode_raw(features["image_raw"], tf.uint8)
      image.set_shape([28 * 28])
      # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
      label = tf.cast(features["label"], tf.int32)
      return image, label

    dataset = tf.contrib.data.TFRecordDataset(
        filename, buffer_size=FLAGS.dataset_reader_buffer_size)
    dataset = dataset.repeat()
    dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
         parser, batch_size=batch_size,
         num_parallel_batches=8,
         drop_remainder=True))
    return dataset

  return input_fn

def main(unused_argv):

  tf.logging.set_verbosity(tf.logging.INFO)

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config.TPUConfig(FLAGS.iterations))

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      config=run_config)

  estimator.train(input_fn=make_input_fn(FLAGS.train_file),
                  max_steps=FLAGS.train_steps)

次のセクションでは、Cloud TPU を効率的に使用するために、上記のサンプルで紹介した新しいコンセプトについて説明します。

TPUEstimator のコンセプト

TPUEstimator は、TensorFlow プログラムを実行するためにグラフ内レプリケーション手法を使用します。グラフ内(シングルセッション)レプリケーションは、分散 TensorFlow で通常使用されるグラフ間(マルチセッション)レプリケーション トレーニングとは異なります。主な違いは次のとおりです。

  1. TPUEstimator では、TensorFlow セッション イニシエーターはローカルではありません。Python プログラムは、Cloud TPU のすべてのコアに複製される単一のグラフを作成します。一般的な構成では、TensorFlow セッション イニシエーターを最初のワーカーにするように設定します。

  2. 入力パイプラインはリモートホスト(ローカルではなく)に配置され、トレーニング サンプルができるだけ速く Cloud TPU に提供されるようにします。データセット(tf.data)が必要です。

  3. Cloud TPU ワーカーは同期的に動作し、各ワーカーが同じステップを同時に実行します。

TensorFlow Estimator から TPUEstimator への変換

最初に小規模なモデルを移植して、その動作をテストすることをおすすめします。そうすることで、TPUEstimator の基本的なコンセプトをしっかり理解できます。モデルが動作したら、徐々に機能を追加します。

サンプルモデルのセットと、これらのモデルを Cloud TPU で実行する手順については、チュートリアルをご覧ください。その他のモデルは GitHub で入手できます。

tf.estimator.Estimator からコードを変換して tf.contrib.tpu.TPUEstimator を使用するには、次のように変更します。

  • tf.estimator.RunConfigtf.contrib.tpu.RunConfig に変更します。
  • iterations_per_loop を指定するように TPUConfigtf.contrib.tpu.RunConfig の一部)を設定します。iterations_per_loop は、1 回の session.run の呼び出し(トレーニング ループごと)で、Cloud TPU で実行するイテレーションの数です。

Cloud TPU は、そのトレーニング ループに指定された反復回数を実行してからホストに返します。Cloud TPU の反復がすべて実行されるまで、チェックポイントまたは概要は保存されません。

  • model_fn で、tf.contrib.tpu.CrossShardOptimizer を使用してオプティマイザーをラップします。次に例を示します。

     optimizer = tf.contrib.tpu.CrossShardOptimizer(
          tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
    
  • tf.estimator.Estimatortf.contrib.tpu.TPUEstimator に変更します。

デフォルトの RunConfig は、100 ステップごとに TensorBoard のサマリーを保存して、10 分ごとにチェックポイントを書き込みます。

よくある質問

入力パイプラインで tf.data が必要なのはなぜですか?

その理由は 2 つあります。

  1. アプリケーション コードはクライアントで実行されますが、TPU の計算は worker で実行されます。パフォーマンスを向上させるには、入力パイプライン演算をワーカーで実行する必要があります。tf.data はワーカーで演算を実行します。

  2. TPU の起動コストを平均化するために、モデルのトレーニング ステップは tf.while_loop にラップされています。1 つの Session.run により、1 回のトレーニング ループで複数回のイテレーションが実行されます。現在、tf.while_loop でラップできるのは tf.data のみです。

モデル トレーニングのパフォーマンスのプロファイルを作成するにはどうすればよいですか?

TensorBoard に提供されたプロファイラを使用して、モデル トレーニングのパフォーマンスのプロファイルを作成できます。