使用 JAX 在 Cloud TPU 虚拟机上运行计算
本文档简要介绍了如何搭配使用 JAX 和 Cloud TPU。
在按照本快速入门操作之前,您必须创建一个 Google Cloud Platform 账号,安装 Google Cloud CLI,并配置 gcloud
命令。如需了解详情,请参阅设置账号和 Cloud TPU 项目
安装 Google Cloud CLI
Google Cloud CLI 包含用于与 Google Cloud 产品和服务交互的工具和库。如需了解详情,请参阅安装 Google Cloud CLI。
配置 gcloud
命令
运行以下命令配置 gcloud
,以使用 Google Cloud 项目并安装 TPU 虚拟机预览版所需的组件。
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
启用 Cloud TPU API
在 Cloud Shell 中使用以下
gcloud
命令启用 Cloud TPU API。(您也可以从 Google Cloud 控制台启用)。$ gcloud services enable tpu.googleapis.com
运行以下命令以创建服务身份。
$ gcloud beta services identity create --service tpu.googleapis.com
使用 gcloud
创建 Cloud TPU 虚拟机
使用 Cloud TPU 虚拟机时,模型和代码直接在 TPU 主机上运行。您通过 SSH 直接连接到 TPU 主机。您可以直接在 TPU 主机上运行任意代码、安装软件包、查看日志和调试代码。
从 Cloud Shell 或安装了 Google Cloud CLI 的计算机终端运行以下命令,以创建 TPU 虚拟机。
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central1-a \ --accelerator-type=v3-8 \ --version=tpu-ubuntu2204-base
连接到 Cloud TPU 虚拟机
使用以下命令通过 SSH 连接到 TPU 虚拟机:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a
必需的字段
tpu_name
- 要连接的 TPU 虚拟机的名称。
zone
- 您创建 Cloud TPU 的可用区。
在 Cloud TPU 虚拟机上安装 JAX
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
系统检查
验证 JAX 是否可以访问 TPU 并运行基本操作:
启动 Python 3 解释器:
(vm)$ python3
>>> import jax
显示可用的 TPU 核心数:
>>> jax.device_count()
系统会显示 TPU 核心数。如果您使用的是 v4 TPU,则此值应为 4
。如果您使用的是 v2 或 v3 TPU,则此值应为 8
。
执行简单的计算:
>>> jax.numpy.add(1, 1)
将显示 numpy add 的结果:
命令的输出为:
Array(2, dtype=int32, weak_type=true)
退出 Python 解释器:
>>> exit()
在 TPU 虚拟机上运行 JAX 代码
现在,您可以运行任何 JAX 代码。要运行 JAX 中的标准机器学习模型,Flax 示例是一个很好的起点。例如,如需训练基本 MNIST 卷积网络,请执行以下操作:
安装 Flax 示例依赖项
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
安装 FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
运行 FLAX MNIST 训练脚本
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
该脚本会下载数据集并开始训练。脚本输出应如下所示:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
清理
完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。
断开与 Compute Engine 实例的连接(如果您尚未这样做):
(vm)$ exit
删除您的 Cloud TPU。
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central1-a
通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。
$ gcloud compute tpus tpu-vm list \ --zone=us-central1-a
性能说明
以下是对于在 JAX 中使用 TPU 尤其重要的一些细节。
填充
TPU 性能下降的最常见原因之一是引入了意外填充:
- Cloud TPU 中会平铺数组。这需要将其中一个维度填充为 8 的倍数,并将另一个维度填充为 128 的倍数。
- 矩阵乘法单元在用于大型矩阵对时性能最佳,大型矩阵对可最大限度地减少填充需求。
bfloat16 dtype
默认情况下,TPU 上 JAX 中的矩阵乘法使用 bfloat16 和 float32 累积。这可以通过相关 jax.numpy 函数调用(matmul、dot、einsum 等)中的精度参数来控制。特别是:
precision=jax.lax.Precision.DEFAULT
:使用混合 bfloat16 精度(最快)precision=jax.lax.Precision.HIGH
:使用多个 MXU 通道来实现更高的精度precision=jax.lax.Precision.HIGHEST
:使用更多 MXU 通道来实现完整的 float32 精度
JAX 还添加了 bfloat16 dtype,可用于将数组明确转换为 bfloat16
,例如 jax.numpy.array(x, dtype=jax.numpy.bfloat16)
。
在 Colab 中运行 JAX
当您在 Colab 笔记本中运行 JAX 代码时,Colab 会自动创建旧版 TPU 节点。TPU 节点具有不同的架构。如需了解详情,请参阅系统架构。
后续步骤
如需详细了解 Cloud TPU,请参阅: