使用 Pax 在单主机 TPU 上训练


本文档简要介绍了如何在单主机 TPU(v2-8、v3-8、v4-8)上使用 Pax。

Pax 是一个框架,用于在 JAX 之上配置和运行机器学习实验。Pax 专注于通过与现有机器学习框架共享基础架构组件,并利用 Praxis 建模库实现模块化,从而简化大规模机器学习。

目标

  • 设置 TPU 资源以进行训练
  • 在单主机 TPU 上安装 Pax
  • 使用 Pax 训练基于 Transformer 的 SPMD 模型

准备工作

运行以下命令配置 gcloud,以使用您的 Cloud TPU 项目并安装在单主机 TPU 上训练运行 Pax 的模型所需的组件。

安装 Google Cloud CLI

Google Cloud CLI 包含用于与 Google Cloud CLI 产品和服务交互的工具和库。如果您之前尚未安装该工具,请按照安装 Google Cloud CLI 中的说明进行安装。

配置 gcloud 命令

(运行 gcloud auth list 可查看可用账号)。

$ gcloud config set account account

$ gcloud config set project project-id

启用 Cloud TPU API

Cloud Shell 中使用以下 gcloud 命令启用 Cloud TPU API。(您也可以从 Google Cloud 控制台启用)。

$ gcloud services enable tpu.googleapis.com

运行以下命令以创建服务身份(服务账号)。

$ gcloud beta services identity create --service tpu.googleapis.com

创建 TPU 虚拟机

使用 Cloud TPU 虚拟机时,模型和代码直接在 TPU 虚拟机上运行。您可以通过 SSH 直接连接到 TPU 虚拟机。您可以直接在 TPU 虚拟机上运行任意代码、安装软件包、查看日志和调试代码。

从 Cloud Shell 或安装了 Google Cloud CLI 的计算机终端运行以下命令,以创建 TPU 虚拟机。

根据合同中的可用性设置 zone,如有需要,请参阅 TPU 区域和可用区

accelerator-type 变量设置为 v2-8、v3-8 或 v4-8。

对于 v2 和 v3 TPU 版本,将 version 变量设置为 tpu-vm-base;对于 v4 TPU,将其设置为 tpu-vm-v4-base

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

连接到 Google Cloud TPU 虚拟机

使用以下命令通过 SSH 连接到 TPU 虚拟机:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

登录虚拟机后,shell 提示符会从 username@projectname 更改为 username@vm-name

在 Google Cloud TPU 虚拟机上安装 Pax

使用以下命令在 TPU 虚拟机上安装 Pax、JAX 和 libtpu

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

系统检查

通过检查 JAX 是否看到 TPU 核心来测试是否已正确安装所有组件:

(vm)$ python3 -c "import jax; print(jax.device_count())"

系统会显示 TPU 核心数,如果您使用的是 v2-8 或 v3-8,则应为 8;如果您使用的是 v4-8,则应为 4。

在 TPU 虚拟机上运行 Pax 代码

现在,您可以运行任何 Pax 代码。如需开始在 Pax 中运行模型,lm_cloud 示例是一个很好的起点。例如,以下命令会基于合成数据训练20 亿参数的基于转换器的 SPMD 语言模型。

以下命令显示了 SPMD 语言模型的训练输出。它大约需要 20 分钟完成 300 步训练。

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

在 v4-8 slice 上,输出应包含:

损失和步数时间

步数为 step_# 时的摘要张量 loss = loss
步数为 step_# 时的摘要张量 每秒步数 x

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

断开与 Compute Engine 实例的连接(如果您尚未这样做):

(vm)$ exit

删除您的 Cloud TPU。

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

后续步骤

如需详细了解 Cloud TPU,请参阅: