Inception v3 高级指南
本文档讨论 Inception 模型的多个方面,以及说明这些方面如何彼此成就,使该模型可在 Cloud TPU 上高效运行。它是在 Cloud TPU 上运行 Inception v3 的高级指南,同时详细讨论促使模型发生了显著改进的具体变化。本文档是对 Inception v3 教程的补充说明。
Inception v3 TPU 训练将运行采用类似配置的 GPU 作业生成的匹配准确率曲线。该模型已成功通过 v2-8、v2-128 和 v2-512 配置完成训练。该模型约用了 170 个周期达到了 78.1% 以上的准确率。
本文档中显示的代码示例旨在简要说明实际的实现情况。 工作代码可在 GitHub 中找到。
简介
Inception v3 是一个图像识别模型,已被证实在 ImageNet 数据集上的准确率超过 78.1%。该模型是数年来多位研究人员提出的诸多想法积淀的成果。它以 Szegedy 等人发表的《Rethinking the Inception Architecture for Computer Vision》原创性论文为理论依据。
模型本身由对称和非对称构建块组成,包括卷积、平均池化、最大池化、串联、丢弃和全连接层。批量归一化也在模型中广泛应用,同时用于激活输入。损失是通过 Softmax 计算的。
以下是该模型的简要图示:
Inception 模型自述文件中介绍了有关 Inception 架构的更多信息。
Estimator API
Inception v3 的 TPU 版本是采用 TPUEstimator 编写的,该 API 旨在简化开发工作,可让您专注于模型本身而不是底层硬件的细节。该 API 在后台执行在 TPU 上运行模型所需的大部分低级杂项工作,同时自动执行一些常用功能(例如保存和恢复检查点)。
Estimator API 强制分离模型和代码的输入部分。
您可以根据模型定义和输入流水线来定义 model_fn
和 input_fn
函数。以下代码展示了这些函数的声明:
def model_fn(features, labels, mode, params):
…
return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def input_fn(params):
def parser(serialized_example):
…
return image, label
…
images, labels = dataset.make_one_shot_iterator().get_next()
return images, labels
API 提供了 train()
和 evaluate()
这两个关键函数,用于训练和评估,如以下代码所示:
def main(unused_argv):
…
run_config = tpu_config.RunConfig(
master=FLAGS.master,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tpu_config.TPUConfig(FLAGS.iterations, FLAGS.num_shards),)
estimator = tpu_estimator.TPUEstimator(
model_fn=model_fn,
use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size,
eval_batch_size=FLAGS.batch_size,
config=run_config)
estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
eval_results = inception_classifier.evaluate(
input_fn=imagenet_eval.input_fn, steps=eval_steps)
ImageNet 数据集
必须先使用大量已加标签的图片对模型进行训练,然后才能使用模型识别图片。ImageNet 是一种常用的数据集。
ImageNet 拥有超过一千万张带标签的图片的网址。其中 100 万张图片还带有边界框,用于为加标签的对象指定更精确的位置。
对于该模型,ImageNet 数据集包括 1331167 张图片组成,其中训练数据集包含 1281167 张;评估数据集包含 50000 张。
训练数据集与评估数据集是有意互相隔离的。只有训练数据集的图片会用于训练模型,并且只有评估数据集中的图片会用于评估模型准确率。
模型要求将图片存储为 TFRecord 格式。要将图片从原始 JPEG 文件转换为 TFRecord 文件,请使用开源批处理脚本:download_and_preprocess_imagenet.sh
。该脚本会生成一系列以下形式的文件(用于训练和验证):
${DATA_DIR}/train-00000-of-01024 ${DATA_DIR}/train-00001-of-01024 ... ${DATA_DIR}/train-01023-of-01024 and ${DATA_DIR}/validation-00000-of-00128 S{DATA_DIR}/validation-00001-of-00128 ... ${DATA_DIR}/validation-00127-of-00128
其中 DATA_DIR 是数据集所在的位置,例如:DATA_DIR = $ HOME / imagenet-data
Inception 模型自述文件的使用入门部分中详细说明了如何构建和运行脚本。
输入流水线
每个 Cloud TPU 设备都具有 8 个内核并且连接到主机 (CPU)。 较大的切片拥有多个主机。其他较大的配置与多个主机交互。例如,v2-256 与 16 个主机通信。
主机从文件系统或本地内存中检索数据,执行数据预处理所需的任何操作,然后将预处理后的数据传输到 TPU 核心。 我们认为主机会单独完成三个数据处理阶段,这三个阶段分别称为:1) 存储、2) 预处理、3) 传输。下图概要显示了该图:
为实现良好性能,系统应处于平衡状态。如果主机 CPU 完成三个数据处理阶段的时间长于 TPU,则执行操作将受限于主机。这两种情况如下图所示:
Inception v3 的当前实现处于受限于输入的边缘。需要从文件系统检索、解码图片,然后进行预处理。有多种不同类型(从中等到复杂)的预处理阶段可供使用。如果使用最复杂的预处理阶段,则训练流水线将受限于预处理。使用中等复杂的预处理阶段,可以使模型受 TPU 的制约,实现超过 78.1% 的准确率。
该模型使用 tf.data.Dataset 来处理输入流水线处理。如需详细了解如何优化输入流水线,请参阅数据集性能指南。
虽然您可以定义函数并将其传递给 Estimator API,但 InputPipeline
类会封装所有必需功能。
Estimator API 可让您直接使用此类。只需将函数传递给函数 train()
和 evaluate()
的 input_fn
参数,如下面的代码段所示:
def main(unused_argv):
…
inception_classifier = tpu_estimator.TPUEstimator(
model_fn=inception_model_fn,
use_tpu=FLAGS.use_tpu,
config=run_config,
params=params,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=eval_batch_size,
batch_axis=(batch_axis, 0))
…
for cycle in range(FLAGS.train_steps // FLAGS.train_steps_per_eval):
tf.logging.info('Starting training cycle %d.' % cycle)
inception_classifier.train(
input_fn=InputPipeline(True), steps=FLAGS.train_steps_per_eval)
tf.logging.info('Starting evaluation cycle %d .' % cycle)
eval_results = inception_classifier.evaluate(
input_fn=InputPipeline(False), steps=eval_steps, hooks=eval_hooks)
tf.logging.info('Evaluation results: %s' % eval_results)
以下代码段显示了 InputPipeline
的主要元素。
class InputPipeline(object):
def __init__(self, is_training):
self.is_training = is_training
def __call__(self, params):
# Storage
file_pattern = os.path.join(
FLAGS.data_dir, 'train-*' if self.is_training else 'validation-*')
dataset = tf.data.Dataset.list_files(file_pattern)
if self.is_training and FLAGS.initial_shuffle_buffer_size > 0:
dataset = dataset.shuffle(
buffer_size=FLAGS.initial_shuffle_buffer_size)
if self.is_training:
dataset = dataset.repeat()
def prefetch_dataset(filename):
dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
return dataset
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
prefetch_dataset,
cycle_length=FLAGS.num_files_infeed,
sloppy=True))
if FLAGS.followup_shuffle_buffer_size > 0:
dataset = dataset.shuffle(
buffer_size=FLAGS.followup_shuffle_buffer_size)
# Preprocessing
dataset = dataset.map(
self.dataset_parser,
num_parallel_calls=FLAGS.num_parallel_calls)
dataset = dataset.prefetch(batch_size)
dataset = dataset.apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.prefetch(2) # Prefetch overlaps in-feed with training
images, labels = dataset.make_one_shot_iterator().get_next()
# Transfer
return images, labels
存储部分始于创建数据集,并包括从存储中读取 TFRecord(使用 tf.data.TFRecordDataset
)。根据需要使用了特殊目的函数 repeat()
和 shuffle()
。函数 tf.contrib.data.parallel_interleave()
将在输入中映射函数 prefetch_dataset()
以生成嵌套数据集,并以交错输出其元素。它将从 cycle_length
嵌套数据集中并行获取元素,从而提高吞吐量。sloppy
参数放宽了以确定顺序产生输出的要求,允许实现过程跳过在请求时元素尚未就绪的嵌套数据集。
预处理部分将调用 dataset.map(parser)
,而后者又会在预处理图片时调用解析器函数。下一部分中将详细讨论预处理阶段。
传输部分(函数结尾处)包括 return images, labels
行。TPUEstimator 获取返回值并自动将它们传输到设备中。
下图显示了 Inception v3 的 Cloud TPU 性能跟踪示例。TPU 计算时间(忽略任何馈入暂停)约为 815 毫秒。
主机存储已写入跟踪记录,如以下屏幕截图所示:
主机预处理中包括图片解码和一系列图片失真函数,如以下屏幕截图所示:
主机/TPU 传输如以下屏幕截图所示:
预处理阶段
图片预处理是系统的关键部分,可能会影响模型在训练期间获得的最高准确率。至少需要对图片进行解码并调整大小以适合模型需求。对于 Inception,图片必须为 299x299x3 像素。
但是,仅仅进行解码和调整大小不足以实现良好的准确率。ImageNet 训练数据集中包含 1281167 张图片。利用训练图片集中的图片训练一遍称为一个周期。在训练期间,模型需要利用训练数据集中的图片训练多次,以提高图片识别能力。如需训练 Inception v3 以使其达到足够的准确率,请使用 140 到 200 个周期,具体取决于全局批次大小。
在将图像馈送到模型之前不断更改图像,使特定图像在每个周期都略有不同,这会很有用。如何对图片进行最佳预处理既是科学,也是艺术。精心设计的预处理阶段可以显著提高模型的识别能力。如果预处理阶段过于简单,可能会人为地限制同一模型可在训练期间达到的准确率顶点。
Inception v3 提供了多种预处理阶段选项,从相对简单且计算开销较小到相当复杂且计算开销很大的选项均有涵盖。文件 vgg_preprocessing.py 和 inception_preprocessing.py 中分别介绍了两种不同类型的选项。
文件 vgg_preprocessing.py 定义了一个预处理阶段,resnet
训练中使用该预处理阶段成功达到了 75% 的准确率,Inception v3 使用该预处理阶段时却结果欠佳。
文件 inception_preprocessing.py 包含一个预处理阶段,用于训练 Inception v3 在 TPU 上运行时准确率在 78.1 到 78.5% 之间。
根据模型是在接受训练还是用于推理/评估,预处理操作有所差异。
在评估时,预处理很简单:剪裁图片的中心区域,然后将其调整为默认的 299x299 大小。以下代码段展示了一个预处理实现:
def preprocess_for_eval(image, height, width, central_fraction=0.875):
with tf.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.central_crop(image, central_fraction=central_fraction)
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, [height, width], align_corners=False)
image = tf.squeeze(image, [0])
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
image.set_shape([height, width, 3])
return image
用于训练时,剪裁区域将随机选择:随机选择边界框以选择图片区域,然后调整其大小。已调整大小的图片将视情况翻转并且其颜色将失真。以下代码段展示了这些操作的实现:
def preprocess_for_train(image, height, width, bbox, fast_mode=True, scope=None):
with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
if bbox is None:
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
distorted_image.set_shape([None, None, 3])
num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method),
num_cases=num_resize_cases)
distorted_image = tf.image.random_flip_left_right(distorted_image)
if FLAGS.use_fast_color_distort:
distorted_image = distort_color_fast(distorted_image)
else:
num_distort_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=num_distort_cases)
distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image
函数 distort_color
负责更改颜色。它提供了一种仅修改亮度和饱和度的快速模式。完整模式会按随机顺序修改亮度、饱和度和色调。
def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
with tf.name_scope(scope, 'distort_color', [image]):
if fast_mode:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
else:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
else:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
elif color_ordering == 1:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
elif color_ordering == 2:
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
elif color_ordering == 3:
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
return tf.clip_by_value(image, 0.0, 1.0)
函数 distort_color
的计算开销很大,这在一定程度上是由于要获得色调和饱和度所需的非线性 RGB 至 HSV 以及 HSV 至 RGB 转换。快速模式和完整模式都需要进行这些转换。尽管快速模式的计算开销较小,但如果启用,仍会将模型推到受限于 CPU 计算的范围。
您还可以在选项列表中添加新函数 distort_color_fast
。此函数会使用 JPEG 转换方案将图片从 RGB 映射为 YCrCb,并在映射回 RGB 之前随机改变亮度和 Cr/Cb 色度。以下代码段展示了此函数的实现:
def distort_color_fast(image, scope=None):
with tf.name_scope(scope, 'distort_color', [image]):
br_delta = random_ops.random_uniform([], -32./255., 32./255., seed=None)
cb_factor = random_ops.random_uniform(
[], -FLAGS.cb_distortion_range, FLAGS.cb_distortion_range, seed=None)
cr_factor = random_ops.random_uniform(
[], -FLAGS.cr_distortion_range, FLAGS.cr_distortion_range, seed=None)
channels = tf.split(axis=2, num_or_size_splits=3, value=image)
red_offset = 1.402 * cr_factor + br_delta
green_offset = -0.344136 * cb_factor - 0.714136 * cr_factor + br_delta
blue_offset = 1.772 * cb_factor + br_delta
channels[0] += red_offset
channels[1] += green_offset
channels[2] += blue_offset
image = tf.concat(axis=2, values=channels)
image = tf.clip_by_value(image, 0., 1.)
return image
这是一个经过预处理的示例图片。系统已随机选择了该图片的个区域,并使用 distort_color_fast
函数更改其颜色。
函数 distort_color_fast
的计算效率很高,且仍会使训练受限于 TPU 执行时间。此外,该模型已用于训练 Inception v3 模型,其批次大小介于 1024-16384 之间,并且该模型的准确率超过 78.1%。
优化器
当前模型展示了三种风格的优化器:SGD、动量和 RMSProp。
Stochastic gradient descent (SGD)
是最简单的更新:向负梯度方向微移权重。尽管它十分简单,但仍然可在某些模型中获得良好结果。更新动态可写为:
动量是一种主流优化器,通常相比 SGD 收 faster 更快。此优化器更新权重的方式与 SGD 非常相似,但会在上次更新的方向上添加一个组件。以下等式描述了动量优化器执行的更新:
可写为:
最后一项是上次更新方向的组件。
对于动量 \({\beta}\),我们使用值 0.9。
RMSprop 是由 Geoff Hinton 在一次讲座中首次提出的常用优化器。 以下等式描述了优化器的工作原理:
对于 Inception v3,测试显示 RMSProp 取得了最高准确率并且用时最短,动量的表现紧随其后。因此,RMSprop 被设置为默认优化器。所用参数为:decay \({\alpha}\) = 0.9、momentum \({\beta}\) = 0.9,且 \({\epsilon}\) = 1.0。
以下代码段展示了如何设置这些参数:
if FLAGS.optimizer == 'sgd':
tf.logging.info('Using SGD optimizer')
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=learning_rate)
elif FLAGS.optimizer == 'momentum':
tf.logging.info('Using Momentum optimizer')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=0.9)
elif FLAGS.optimizer == 'RMS':
tf.logging.info('Using RMS optimizer')
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
RMSPROP_DECAY,
momentum=RMSPROP_MOMENTUM,
epsilon=RMSPROP_EPSILON)
else:
tf.logging.fatal('Unknown optimizer:', FLAGS.optimizer)
当在 TPU 上运行并使用 Estimator API 时,优化器需要封装在 CrossShardOptimizer
函数中,以确保副本之间的同步(以及任何必要的交叉通信)。以下代码段展示了 Inception v3 模型如何封装优化器:
if FLAGS.use_tpu:
optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step=global_step)
指数平均数 (EMA)
在训练时,系统会根据优化器的更新规则在反向传播期间更新可训练参数。上一部分中已经讨论过说明这些规则的方程式,为方便起见,在此再介绍一遍:
指数平均数(也称为指数平滑)是一个应用于更新后权重的可选后期处理操作步骤,有时可以显显著提高性能。TensorFlow 提供函数 tf.train.ExponentialMovingAverage,该函数使用以下公式计算权重 \({\theta}\) 的 ema \({\hat{\theta}}\):
其中 \({\alpha}\) 是一个衰减因数(接近于 1.0)。在 Inception v3 模型中,\({\alpha}\) 设置为 0.995。
虽然这种计算是无限脉冲响应 (IIR) 滤波器,但衰减因数会建立大部分能量(或相关样本)的有效窗口,如下图所示:
我们可以重写过滤器方程,如下所示:
其中我们使用 \({\hat\theta_{-1}}=0\)。
\({\alpha}^k\) 值随着 k 的增加而衰减,因此只有一部分样本会对 \(\hat{\theta}_{t+T+1}\) 产生较大影响。衰减因数值的经验法则是:\(\frac {1} {1-\alpha}\),对应于 \({\alpha}\) = 200 =0.995。
我们首先获得一系列可训练变量,然后使用 apply()
方法为各个训练后的变量创建影子变量。以下代码段展示了 Inception v3 模型实现:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, global_step=global_step)
if FLAGS.moving_average:
ema = tf.train.ExponentialMovingAverage(
decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
variables_to_average = (tf.trainable_variables() +
tf.moving_average_variables())
with tf.control_dependencies([train_op]), tf.name_scope('moving_average'):
train_op = ema.apply(variables_to_average)
我们希望在评估期间使用 EMA 变量。我们定义了 LoadEMAHook
类,用于将 variables_to_restore()
方法应用于检查点文件,以使用影子变量名称进行评估:
class LoadEMAHook(tf.train.SessionRunHook):
def __init__(self, model_dir):
super(LoadEMAHook, self).__init__()
self._model_dir = model_dir
def begin(self):
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
variables_to_restore = ema.variables_to_restore()
self._load_ema = tf.contrib.framework.assign_from_checkpoint_fn(
tf.train.latest_checkpoint(self._model_dir), variables_to_restore)
def after_create_session(self, sess, coord):
tf.logging.info('Reloading EMA...')
self._load_ema(sess)
hooks
函数会被传递到 evaluate()
,如以下代码段所示:
if FLAGS.moving_average:
eval_hooks = [LoadEMAHook(FLAGS.model_dir)]
else:
eval_hooks = []
…
eval_results = inception_classifier.evaluate(
input_fn=InputPipeline(False), steps=eval_steps, hooks=eval_hooks)
批量归一化
批量归一化是一种广泛用于对模型的输入特征进行归一化处理的技术,从而显著缩短收敛时间。它是近年来机器学习领域中使用较为广泛的有效算法改进之一,可在多种模型(包括 Inception v3)中使用。
激活输入通过减去平均值并除以标准差进行归一化。为了在有反向传播的情况下平衡各个因素,系统将在每一层中引入两个可训练参数。归一化输出 \({\hat{x}}\) 会进行后续运算 \({\gamma\hat{x}}+\beta\),其中 \({\gamma}\) 和 \({\beta}\) 是一种标准差和模型本身学到的平均值。
这篇论文中包含完整方程集,为方便起见,在此再介绍一遍:
输入:小批量中的 X 值:\(\Phi=\{ {x_{1..m}\} }\) 要学习的参数:\({\gamma}\)、\({\beta}\)
输出:{ \({y_i}=BN_{\gamma,\beta}{(x_i)}\) }
\[{\mu_\phi} \leftarrow {\frac{1}{m}}{\sum_{i=1}^m}x_i \qquad \mathsf(mini-batch\ mean)\]
\[{\sigma_\phi}^2 \leftarrow {\frac{1}{m}}{\sum_{i=1}^m} {(x_i - {\mu_\phi})^2} \qquad \mathbf(mini-batch\ variance)\]
\[{\hat{x_i}} \leftarrow {\frac{x_i-{\mu_\phi}}{\sqrt {\sigma^2_\phi}+{\epsilon}}}\qquad \mathbf(normalize)\]
\[{y_i}\leftarrow {\gamma \hat{x_i}} + \beta \equiv BN_{\gamma,\beta}{(x_i)}\qquad \mathbf(scale \ and \ shift)\]
归一化发生在训练期间,但在评估时,我们希望模型以确定性的方式表现:图片的分类结果应仅取决于输入图片,而不是馈送给模型的一组图片。因此,我们需要修正 \({\mu}\) 和 \({\sigma}^2\),并使用表示图片填充统计信息的值。
模型计算小批次均值和方差的移动平均值:
\[{\hat\mu_i} = {\alpha \hat\mu_{t-1}}+{(1-\alpha)\mu_t}\]
\[{\hat\sigma_t}^2 = {\alpha{\hat\sigma^2_{t-1}}} + {(1-\alpha) {\sigma_t}^2}\]
具体对于 Inception v3 而言,即已获得(使用超参数微调)适当衰减因数以用于 GPU。我们也想要在 TPU 上使用该值,为此需要进行一些调整。
批量归一化移动均值和方差均通过低通滤波器计算,如以下方程式所示(此处,\({y_t}\) 表示移动均值或方差):
\[{y_t}={\alpha y_{t-1}}+{(1-\alpha)}{x_t} \]
(1)
在 8x1 GPU (同步)作业中,每个副本都将读取并更新当前移动平均值。当前副本必须写入新的移动变量,然后下一个副本才能读取该变量。
存在 8 个副本时,集成学习更新的操作集如下所示:
\[{y_t}={\alpha y_{t-1}}+{(1-\alpha)}{x_t} \]
\[{y_{t+1}}={\alpha y_{t}}+{(1-\alpha)}{x_{t+1}} \]
\[{y_{t+2}}={\alpha y_{t+1}}+{(1-\alpha)}{x_{t+2}} \]
\[{y_{t+3}}={\alpha y_{t+2}}+{(1-\alpha)}{x_{t+3}} \]
\[{y_{t+4}}={\alpha y_{t+3}}+{(1-\alpha)}{x_{t+4}} \]
\[{y_{t+5}}={\alpha y_{t+4}}+{(1-\alpha)}{x_{t+5}} \]
\[{y_{t+6}}={\alpha y_{t+5}}+{(1-\alpha)}{x_{t+6}} \]
\[{y_{t+7}}={\alpha y_{t+6}}+{(1-\alpha)}{x_{t+7}} \]
这组 8 个顺序更新可以写为:
\[{y_{t+7}}={\alpha^8y_{t-1}}+(1-\alpha){\sum_{k=0}^7} {\alpha^{7-k}}{x_{t+k}}\]
(2)
在 TPU 上的当前移动力矩计算实现中,每个分片将会独立执行计算,并且不存在跨分片通信。系统会向每个分片分发批量,每个分片处理总批量的 1/8(存在 8 个分片时)。
虽然每个分片都会计算移动时刻(即平均值和方差),但只有分片 0 的结果会传回主机 CPU。因此,实际上只有一个副本会执行移动均值/方差更新:
\[{z_t}={\beta {z_{t-1}}}+{(1-\beta)u_t}\]
(3)
这种更新的发生速率是其顺序对应项的 1/8。为了比较 GPU 和 TPU 更新方程式,我们需要对齐各自的时间尺度。具体而言,GPU 上构成一组 8 个顺序更新的操作集应与 TPU 上的单个更新进行比较,如下图所示:
以下为已修改时间索引的方程式:
\[{y_t}={\alpha^8y_{t-1}}+(1-\alpha){\sum_{k=0}^7} {\alpha^{7-k}}{x_{t-k/8}} \qquad \mathsf(GPU)\]
\[{z_t}={\beta {z_{t-1}}}+{(1-\beta)u_t}\qquad \mathsf(TPU) \]
如果我们假设 8 个小批次(针对所有相关维度进行归一化)在 GPU 8 小批次顺序更新中产生相似的值,就可以将这些方程进行近似处理,如下所示:
\[{y_t}={\alpha^8y_{t-1}}+(1-\alpha){\sum_{k=0}^7} {\alpha^{7-k}}{\hat{x_t}}={\alpha^8y_{t-1}+(1-\alpha^8){\hat{x_t}}} \qquad \mathsf(GPU)\]
\[{z_t}={\beta {z_{t-1}}}+{(1-\beta)u_t}\qquad \mathsf(TPU) \]
为了匹配给定衰减因数对 GPU 的影响,我们相应修改 TPU 上的衰减因数。具体来说,我们设置 \({\beta}\)=\({\alpha}^8\)。
对于 Inception v3,GPU 中使用的衰减值为 \({\alpha}\)=0.9997,可转换为 TPU 上的衰减值 \({\beta}\)=0.9976。
学习速率自适应
随着批量大小增加,训练难度加大。人们不断提出不同方法,希望能够高效地对大批量进行训练(例如,请参阅此处、此处和此处)。
其中一种技术是逐步提高学习速率(也称为渐进式提升)。使用渐进式提升,可在使用 4096 到 16384 范围内的批量大小对模型进行训练时成功实现 78.1% 以上的准确率。对于 Inception v3,首先将学习速率设置为常规起始学习速率的约 10%。在指定(少量)数量的“冷周期”中,学习速率保持此恒定低值,然后在指定数量的“热身周期”中开始线性增加。在“热身周期”的最后,学习速率会与使用了常规指数衰减时的学习速率相交。下图对此进行了说明。
以下代码段展示了如何执行此操作:
initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
if FLAGS.use_learning_rate_warmup:
warmup_decay = FLAGS.learning_rate_decay**(
(FLAGS.warmup_epochs + FLAGS.cold_epochs) /
FLAGS.learning_rate_decay_epochs)
adj_initial_learning_rate = initial_learning_rate * warmup_decay
final_learning_rate = 0.0001 * initial_learning_rate
train_op = None
if training_active:
batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
global_step = tf.train.get_or_create_global_step()
current_epoch = tf.cast(
(tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)
learning_rate = tf.train.exponential_decay(
learning_rate=initial_learning_rate,
global_step=global_step,
decay_steps=int(FLAGS.learning_rate_decay_epochs * batches_per_epoch),
decay_rate=FLAGS.learning_rate_decay,
staircase=True)
if FLAGS.use_learning_rate_warmup:
wlr = 0.1 * adj_initial_learning_rate
wlr_height = tf.cast(
0.9 * adj_initial_learning_rate /
(FLAGS.warmup_epochs + FLAGS.learning_rate_decay_epochs - 1),
tf.float32)
epoch_offset = tf.cast(FLAGS.cold_epochs - 1, tf.int32)
exp_decay_start = (FLAGS.warmup_epochs + FLAGS.cold_epochs +
FLAGS.learning_rate_decay_epochs)
lin_inc_lr = tf.add(
wlr, tf.multiply(
tf.cast(tf.subtract(current_epoch, epoch_offset), tf.float32),
wlr_height))
learning_rate = tf.where(
tf.greater_equal(current_epoch, FLAGS.cold_epochs),
(tf.where(tf.greater_equal(current_epoch, exp_decay_start),
learning_rate, lin_inc_lr)),
wlr)
# Set a minimum boundary for the learning rate.
learning_rate = tf.maximum(
learning_rate, final_learning_rate, name='learning_rate')