Cloud TPU 虚拟机 JAX 快速入门

本文档简要介绍如何搭配使用 JAX 和 Cloud TPU。

登录您的 Google 帐号。 如果您还没有 Google 帐号,请注册新帐号。在 Google Cloud Console 中,从项目选择器页面选择或创建 Cloud 项目。确保您的项目已启用结算功能

安装 Google Cloud SDK

Google Cloud SDK 包含用于与 Google Cloud 产品和服务交互的工具和库。如需了解详情,请参阅安装 Google Cloud SDK

配置 gcloud 命令

运行以下命令配置 gcloud,以使用 GCP 项目并安装 TPU 虚拟机预览版所需的组件。

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

启用 Cloud TPU API

  1. Cloud Shell 中使用以下 gcloud 命令启用 Cloud TPU API。(您也可以从 Google Cloud Console 启用)。

    $ gcloud services enable tpu.googleapis.com
    
  2. 运行以下命令以创建服务身份。

    $ gcloud beta services identity create --service tpu.googleapis.com
    

使用 gcloud 创建 Cloud TPU 虚拟机

使用 Cloud TPU 虚拟机时,模型和代码直接在 TPU 主机上运行。您通过 SSH 直接连接到 TPU 主机。您可以直接在 TPU 主机上运行任意代码、安装软件包、查看日志和调试代码。

  1. 从 GCP Cloud Shell 或安装了 Google Cloud SDK 的计算机终端运行以下命令,以创建 TPU 虚拟机。

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    必需的字段

    zone
    拟在其中创建 Cloud TPU 的地区
    accelerator-type
    要创建的 Cloud TPU 的类型
    version
    Cloud TPU 运行时版本。在单 TPU 设备、Pod 切片或整个 Pod 上使用 JAX 时,请将此字段设置为“v2-alpha”。

连接到 Cloud TPU 虚拟机

使用以下命令通过 SSH 连接到 TPU 虚拟机:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

必需的字段

tpu_name
要连接的 TPU 虚拟机的名称。
zone
您创建 Cloud TPU 的可用区

在 Cloud TPU 虚拟机上安装 JAX

(vm)$ pip3 install --upgrade jax jaxlib

系统检查

通过检查 JAX 看到 Cloud TPU 核心以及可以运行基本操作来测试已正确安装所有组件:

启动 Python 3 解释器:

(vm)$ python3
>>> import jax

显示可用的 TPU 核心数:

>>> jax.device_count()

显示 TPU 核心数,应为 8

执行简单的计算:

>>> jax.numpy.add(1, 1)

将显示 numpy add 的结果:

命令的输出为:

DeviceArray(2, dtype=int32)

退出 Python 解释器:

>>> exit()

在 TPU 虚拟机上运行 JAX 代码

现在,您可以运行任何 JAX 代码。要运行 JAX 中的标准机器学习模型,Flax 示例是一个很好的起点。例如,如需训练基本 MNIST 卷积网络,请执行以下操作:

  1. 安装 Tensorflow 数据集

    (vm)$ pip install --upgrade clu
    
  2. 安装 FLAX。

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. 运行 FLAX MNIST 训练脚本

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    脚本输出应如下所示:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit
    
  2. 删除您的 Cloud TPU。

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. 通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。

性能说明

以下是对于在 JAX 中使用 TPU 尤其重要的一些细节。

填充

TPU 性能下降的最常见原因之一是引入了意外填充:

  • Cloud TPU 中会平铺数组。这需要将其中一个维度填充为 8 的倍数,并将另一个维度填充为 128 的倍数。
  • 矩阵乘法单元在用于大型矩阵对时性能最佳,大型矩阵对可最大限度地减少填充需求。

bfloat16 dtype

默认情况下,TPU 上 JAX 中的矩阵乘法使用 bfloat16 和 float32 累积。这可以通过相关 jax.numpy 函数调用(matmul、dot、einsum 等)中的精度参数来控制。特别是:

  • precision=jax.lax.Precision.DEFAULT:使用混合 bfloat16 精度(最快)
  • precision=jax.lax.Precision.HIGH:使用多个 MXU 通道来实现更高的精度
  • precision=jax.lax.Precision.HIGHEST:使用更多 MXU 通道来实现完整的 float32 精度

JAX 还添加了 bfloat16 dtype,可用于将数组明确转换为 bfloat16,例如:jax.numpy.array(x, dtype=jax.numpy.bfloat16)

在 Colab 中运行 JAX

当您在 Colab 笔记本中运行 JAX 代码时,Colab 会自动创建旧版 TPU 节点。TPU 节点具有不同的架构。如需了解详情,请参阅系统架构