使用 JAX 在 Cloud TPU 虚拟机上运行计算
本文档简要介绍了如何搭配使用 JAX 和 Cloud TPU。
准备工作
在运行本文档中的命令之前,您必须创建一个 Google Cloud账号,安装 Google Cloud CLI,并配置 gcloud
命令。如需了解详情,请参阅设置 Cloud TPU 环境。
使用 gcloud
创建 Cloud TPU 虚拟机
定义一些环境变量,以便更轻松地使用命令。
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
从 Cloud Shell 或安装了 Google Cloud CLI 的计算机终端运行以下命令,以创建 TPU 虚拟机。
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
连接到 Cloud TPU 虚拟机
使用以下命令通过 SSH 连接到 TPU 虚拟机:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
在 Cloud TPU 虚拟机上安装 JAX
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
系统检查
验证 JAX 是否可以访问 TPU 并运行基本操作:
启动 Python 3 解释器:
(vm)$ python3
>>> import jax
显示可用的 TPU 核心数:
>>> jax.device_count()
系统会显示 TPU 核心数。显示的核心数取决于您使用的 TPU 版本。如需了解详情,请参阅 TPU 版本。
执行计算:
>>> jax.numpy.add(1, 1)
将显示 numpy add 的结果:
命令的输出为:
Array(2, dtype=int32, weak_type=true)
退出 Python 解释器:
>>> exit()
在 TPU 虚拟机上运行 JAX 代码
现在,您可以运行任何 JAX 代码。要运行 JAX 中的标准机器学习模型,Flax 示例是一个很好的起点。例如,如需训练基本 MNIST 卷积网络,请执行以下操作:
安装 Flax 示例依赖项
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
安装 FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
运行 FLAX MNIST 训练脚本
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
该脚本会下载数据集并开始训练。脚本输出应如下所示:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
清理
为避免因本页中使用的资源导致您的 Google Cloud 账号产生费用,请按照以下步骤操作。
完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。
断开与 Compute Engine 实例的连接(如果您尚未这样做):
(vm)$ exit
删除您的 Cloud TPU。
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
性能说明
以下是对于在 JAX 中使用 TPU 尤其重要的一些细节。
填充
TPU 性能下降的最常见原因之一是引入了意外填充:
- Cloud TPU 中会平铺数组。这需要将其中一个维度填充为 8 的倍数,并将另一个维度填充为 128 的倍数。
- 矩阵乘法单元在用于大型矩阵对时性能最佳,大型矩阵对可最大限度地减少填充需求。
bfloat16 dtype
默认情况下,TPU 上 JAX 中的矩阵乘法使用 bfloat16 和 float32 累积。这可以通过相关 jax.numpy 函数调用(matmul、dot、einsum 等)中的精度参数来控制。具体而言:
precision=jax.lax.Precision.DEFAULT
:使用混合 bfloat16 精度(最快)precision=jax.lax.Precision.HIGH
:使用多个 MXU 通道来实现更高的精度precision=jax.lax.Precision.HIGHEST
:使用更多 MXU 通道来实现完整的 float32 精度
JAX 还添加了 bfloat16 dtype,可用于将数组明确转换为 bfloat16
,例如 jax.numpy.array(x, dtype=jax.numpy.bfloat16)
。
后续步骤
如需详细了解 Cloud TPU,请参阅: