使用 JAX 在 Cloud TPU 虚拟机上运行计算

本文档简要介绍了如何使用 JAX 和 Cloud TPU。

在按照本快速入门之前,您必须创建 Google Cloud Platform 帐号、安装 Google Cloud CLI 并配置 gcloud 命令。如需了解详情,请参阅设置账号和 Cloud TPU 项目

安装 Google Cloud CLI

Google Cloud CLI 包含用于与 Google Cloud 产品和服务交互的工具和库。如需了解详情,请参阅安装 Google Cloud CLI

配置 gcloud 命令

运行以下命令以将 gcloud 配置为使用您的 Google Cloud 项目,并安装 TPU 虚拟机预览所需的组件。

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

启用 Cloud TPU API

  1. Cloud Shell 中使用以下 gcloud 命令启用 Cloud TPU API。(您也可以通过 Google Cloud 控制台启用该功能)。

    $ gcloud services enable tpu.googleapis.com
    
  2. 运行以下命令以创建服务身份。

    $ gcloud beta services identity create --service tpu.googleapis.com
    

使用 gcloud 创建 Cloud TPU 虚拟机

借助 Cloud TPU 虚拟机,您的模型和代码可直接在 TPU 主机上运行。您通过 SSH 直接连接到 TPU 主机。您可以直接在 TPU 主机上运行任意代码、安装软件包、查看日志和调试代码。

  1. 从 Cloud Shell 或安装了 Google Cloud CLI 的计算机终端运行以下命令,创建 TPU 虚拟机。

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    必填字段

    zone
    您计划在其中创建 Cloud TPU 的区域
    accelerator-type
    加速器类型指定要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    version
    Cloud TPU 软件版本。对于所有 TPU 类型,请使用 tpu-ubuntu2204-base

连接到 Cloud TPU 虚拟机

使用以下命令通过 SSH 连接到 TPU 虚拟机:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

必需的字段

tpu_name
要连接的 TPU 虚拟机的名称。
zone
您创建 Cloud TPU 的可用区

在 Cloud TPU 虚拟机上安装 JAX

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

系统检查

验证 JAX 是否可以访问 TPU 并可以运行基本操作:

启动 Python 3 解释器:

(vm)$ python3
>>> import jax

显示可用的 TPU 核心数:

>>> jax.device_count()

系统随即会显示 TPU 核心的数量。如果您使用的是 v4 TPU,则此参数应为 4。如果您使用的是 v2 或 v3 TPU,则此参数应为 8

执行简单的计算:

>>> jax.numpy.add(1, 1)

将显示 numpy add 的结果:

命令的输出为:

Array(2, dtype=int32, weak_type=true)

退出 Python 解释器:

>>> exit()

在 TPU 虚拟机上运行 JAX 代码

您现在可以运行所需的任何 JAX 代码。要运行 JAX 中的标准机器学习模型,Flax 示例是一个很好的起点。例如,要训练一个基本的 MNIST 卷积网络,请使用以下代码:

  1. 安装 Flax 示例依赖项

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. 安装 FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. 运行 FLAX MNIST 训练脚本

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

该脚本会下载数据集并开始训练。脚本输出应如下所示:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit
    
  2. 删除您的 Cloud TPU。

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. 通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

性能说明

以下是对于在 JAX 中使用 TPU 尤其重要的一些细节。

填充

TPU 性能下降的最常见原因之一是引入了意外填充:

  • Cloud TPU 中会平铺数组。这需要将其中一个维度填充为 8 的倍数,并将另一个维度填充为 128 的倍数。
  • 矩阵乘法单元在用于大型矩阵对时性能最佳,大型矩阵对可最大限度地减少填充需求。

bfloat16 dtype

默认情况下,TPU 上的 JAX 中的矩阵乘法会将 bfloat16 与 float32 累积结合使用。这可以通过相关 jax.numpy 函数调用(matmul、dot、einsum 等)中的精度参数来控制。特别是:

  • precision=jax.lax.Precision.DEFAULT:使用混合 bfloat16 精度(最快)
  • precision=jax.lax.Precision.HIGH:使用多个 MXU 传递以获得更高的精确度
  • precision=jax.lax.Precision.HIGHEST:使用更多 MXU 传递来实现 float32 完全精度

JAX 还添加了 bfloat16 dtype,您可以使用它来将数组显式转换为 bfloat16,例如 jax.numpy.array(x, dtype=jax.numpy.bfloat16)

在 Colab 中运行 JAX

当您在 Colab 笔记本中运行 JAX 代码时,Colab 会自动创建旧版 TPU 节点。TPU 节点具有不同的架构。如需了解详情,请参阅系统架构

后续步骤

如需详细了解 Cloud TPU,请参阅: