在 TPU Pod 切片上运行 JAX 代码

在单个 TPU 板上运行 JAX 代码后,您可以通过在 TPU Pod 切片上运行代码来扩容代码。 TPU Pod 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU Pod 切片上运行 JAX 代码;如需了解更深入的信息,请参阅在多主机和多进程环境中使用 JAX

如果您想使用装载的 NFS 进行数据存储,则必须为所有服务设置 OS Login Pod 切片中的 TPU 虚拟机。如需了解详情,请参阅 使用 NFS 进行数据存储

创建 TPU Pod 切片

在运行本文档中的命令之前,请确保您已按照 设置账号和 Cloud TPU 项目中的说明。 在本地机器上运行以下命令。

使用 gcloud 命令可以创建 TPU Pod 切片。例如,要创建 v4-32 Pod 切片使用如下命令:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

在 Pod 切片上安装 JAX

创建 TPU Pod 切片之后,您必须在 TPU Pod 切片中的所有主机上安装 JAX。您可以使用 --worker=all 选项通过一个命令在所有主机上安装 JAX:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

在 Pod 切片上运行 JAX 代码

要在 TPU Pod 切片上运行 JAX 代码,您必须在 TPU Pod 切片中的每个主机上运行代码jax.device_count() 调用停止响应,直到 Pod 切片中每个主机上调用的方法。以下示例说明了如何 在 TPU Pod 切片上运行简单的 JAX 计算。

准备代码

您的 gcloud 版本不低于 344.0.0(对于 scp 命令)。 使用 gcloud --version 检查您的 gcloud 版本, 运行 gcloud components upgrade(如果需要)。

使用以下代码创建一个名为 example.py 的文件:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py 复制到 Pod 切片中的所有 TPU 工作器虚拟机

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

如果您以前未使用过 scp 命令,则可能会看到 错误,类似于以下内容:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

如需解决此错误,请运行 ssh-add 命令,如 错误消息并重新运行该命令。

在 Pod 切片上运行代码

在每个虚拟机上启动 example.py 程序:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

输出(使用 v4-32 Pod 切片生成):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

清理

完成后,您可以使用 gcloud 命令释放 TPU 虚拟机资源:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b