Cloud TPU 性能指南

排查 TPU 性能问题时,第一步是分析模型。如需详细了解如何捕获性能配置文件,请参阅在 Cloud TPU 上剖析模型

TPU 模型性能

本部分介绍可能会降低模型性能的常见问题,以及如何解决这些问题。

  1. 模型受限于输入

    TPU 执行计算的速度非常快。为了确保 TPU 不处于空闲状态,请务必确保将稳定的数据流加载到 TPU。实现这一点的方式取决于您加载和预处理数据集的方式。例如,您可以使用 tf.data.TFRecordset()num_parallel_reads 参数并行读取数据文件。

  2. 由于分片,批次大小太小(跨核心拆分批次)

    TPU 运行时在 TPU 设备(例如 v2-8 或 v3-8)的所有 8 个核心上拆分批量。如果指定的全局批量大小为 128,则每个核心会收到批量大小 16 (128 / 8)。

    为获得最佳内存用量,请使用 TPU 内存中可容纳的最大批量大小。每个 TPU 核心都使用二维 8 X 128 向量寄存器来处理矩阵乘法。通常,您的批次大小应该能被 8 或 128 整除。

XLA 编译器优化

XLA 是一款机器学习编译器,可以为 TPU、CPU、GPU 和其他平台生成二进制文件。虽然 XLA 是标准 TensorFlow 代码库的一部分,但它也可用于 PyTorchJAX 模型。适用于 Cloud TPU 的模型被转换为 XLA 图,然后 XLA 将其编译为 TPU 可执行文件。如需详细了解 XLA,请参阅 XLA:优化机器学习编译器

填充

为了高效地使用 TPU 内存,请设计数据结构以将其平铺为 128 x 8 个数据块。如果矩阵计算的数据填满了 128 x 8 个数据块,则 XLA 编译器会填充张量。填充有两个不足:

  1. 填充的张量不能充分利用 TPU 核心。
  2. 填充增加了张量所需的片上内存存储,并可能导致内存不足错误。

虽然填充是由 XLA 编译器在必要时自动执行的,但您可以使用内存查看器工具来确定填充量。您可以通过选择非常适合 TPU 的张量维度来避免填充。

张量维度

XLA 编译器将存储在 TPU HBM 内存中的张量的大小向上舍入,以便更高效地执行计算。此填充操作在硬件级别以透明方式进行,不会影响结果。但是,在某些情况下,填充可能会导致内存使用量和执行时间显著增加。

TPU 运行时会在内存中存放张量,以便最大限度地提高计算效率并减少填充。为了最大限度地减少内存开销并提高计算效率,必须满足以下条件之一

  1. 总批量大小应为 64 的倍数(每个 TPU 核心 8 个),并且特征维度大小应为 128 的倍数。

  2. 总批量大小应为 1024 的倍数(每个 TPU 核心 128 个),并且特征维度大小应为 8 的倍数。

使用 1024 作为批量大小以及 128 的倍数作为特征维度可以获得最佳效率,但这并非适用于所有模型。

融合

融合是 XLA 编译器用于优化程序的通用技术。融合操作是指合并多个需要联合执行的操作。

例如,请考虑以下系列操作:

    tmp = tf.add(x, y)
    result = tf.multiply(tmp, z)

此代码大致相当于以下伪代码:

    for (i = 0; i < element_count; i++) {
      tmp[i] = x[i] + y[i];
    }

    for (i = 0; i < element_count; i++) {
      result = tmp[i] * z[i];
    }

通过融合,数组访问将同时进行:

    for (i = 0; i < element_count; i++) {
      result = (x[i] + y[i]) * z[i];
    }

在此例中,内存往返次数减少,XLA 无需为“tmp”分配任何空间。

融合是一项关键优化措施,可从多方面对 Cloud TPU 形成有利影响:

  • 无需在主内存中存储中间结果(通常很慢),可减少内存传输。
  • 可以更充分地利用硬件设备,以免这些设备被闲置。
  • 由于同时需要的活动缓冲区减少,可以降低模型的内存利用量。

广播

当两个形状不同但可兼容的张量合并时,系统将以隐式方式进行广播。

例如,tf.add(vector, matrix) 要求将矢量广播到矩阵的形状。操作结果具有与矩阵相同的形状。如需了解详情,请参阅广播数组指南。

虽然广播通常可与其使用方融合,但强制执行广播可能会导致性能低下,同时增加内存使用量。

在以下示例中,矢量与矩阵的相加操作中隐含的广播不能与 argmax 融合,因而导致广播物化:

`tf.argmax(tf.add(vector, zero_matrix), axis=0)`