使用 JAX 在 Cloud TPU 虚拟机上运行计算

本文档简要介绍了如何搭配使用 JAX 和 Cloud TPU。

准备工作

在运行本文档中的命令之前,您必须创建一个 Google Cloud账号,安装 Google Cloud CLI,并配置 gcloud 命令。如需了解详情,请参阅设置 Cloud TPU 环境

使用 gcloud 创建 Cloud TPU 虚拟机

  1. 定义一些环境变量,以便更轻松地使用命令。

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    环境变量说明

    PROJECT_ID
    您的 Google Cloud 项目 ID。
    ACCELERATOR_TYPE
    加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    ZONE
    拟在其中创建 Cloud TPU 的可用区
    RUNTIME_VERSION
    Cloud TPU 运行时版本。如需了解详情,请参阅 TPU 虚拟机映像
    TPU_NAME
    用户为 Cloud TPU 分配的名称。
  2. 从 Cloud Shell 或安装了 Google Cloud CLI 的计算机终端运行以下命令,以创建 TPU 虚拟机。

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION

连接到 Cloud TPU 虚拟机

使用以下命令通过 SSH 连接到 TPU 虚拟机:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
--project=$PROJECT_ID \
--zone=$ZONE

在 Cloud TPU 虚拟机上安装 JAX

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

系统检查

验证 JAX 是否可以访问 TPU 并运行基本操作:

  1. 启动 Python 3 解释器:

    (vm)$ python3
    >>> import jax
  2. 显示可用的 TPU 核心数:

    >>> jax.device_count()

系统会显示 TPU 核心数。显示的核心数取决于您使用的 TPU 版本。如需了解详情,请参阅 TPU 版本

执行计算:

>>> jax.numpy.add(1, 1)

将显示 numpy add 的结果:

命令的输出为:

Array(2, dtype=int32, weak_type=true)

退出 Python 解释器:

>>> exit()

在 TPU 虚拟机上运行 JAX 代码

现在,您可以运行任何 JAX 代码。要运行 JAX 中的标准机器学习模型,Flax 示例是一个很好的起点。例如,如需训练基本 MNIST 卷积网络,请执行以下操作:

  1. 安装 Flax 示例依赖项

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. 安装 FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. 运行 FLAX MNIST 训练脚本

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

该脚本会下载数据集并开始训练。脚本输出应如下所示:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

清理

为避免因本页中使用的资源导致您的 Google Cloud 账号产生费用,请按照以下步骤操作。

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit
  2. 删除您的 Cloud TPU。

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE
  3. 通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。

    $ gcloud compute tpus tpu-vm list \
      --zone=$ZONE

性能说明

以下是对于在 JAX 中使用 TPU 尤其重要的一些细节。

填充

TPU 性能下降的最常见原因之一是引入了意外填充:

  • Cloud TPU 中会平铺数组。这需要将其中一个维度填充为 8 的倍数,并将另一个维度填充为 128 的倍数。
  • 矩阵乘法单元在用于大型矩阵对时性能最佳,大型矩阵对可最大限度地减少填充需求。

bfloat16 dtype

默认情况下,TPU 上 JAX 中的矩阵乘法使用 bfloat16 和 float32 累积。这可以通过相关 jax.numpy 函数调用(matmul、dot、einsum 等)中的精度参数来控制。具体而言:

  • precision=jax.lax.Precision.DEFAULT:使用混合 bfloat16 精度(最快)
  • precision=jax.lax.Precision.HIGH:使用多个 MXU 通道来实现更高的精度
  • precision=jax.lax.Precision.HIGHEST:使用更多 MXU 通道来实现完整的 float32 精度

JAX 还添加了 bfloat16 dtype,可用于将数组明确转换为 bfloat16,例如 jax.numpy.array(x, dtype=jax.numpy.bfloat16)

后续步骤

如需详细了解 Cloud TPU,请参阅: