在 TPU Pod 切片上运行 JAX 代码

在单个 TPU 板上运行 JAX 代码后,您可以通过在 TPU Pod 切片上运行代码来扩容代码。 TPU Pod 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU Pod 切片上运行 JAX 代码;如需了解更深入的信息,请参阅在多主机和多进程环境中使用 JAX

如果您要使用已装载的 NFS 存储数据,则必须为 Pod 切片中的所有 TPU 虚拟机设置 OS Login。如需了解详情,请参阅使用 NFS 存储数据

创建 TPU Pod 切片

在运行本文档中的命令之前,请确保已按照设置帐号和 Cloud TPU 项目中的说明操作。在本地机器上运行以下命令。

使用 gcloud 命令可以创建 TPU Pod 切片。例如,如需创建 v2-32 Pod 切片,请使用以下命令:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

在 Pod 切片上安装 JAX

创建 TPU Pod 切片之后,您必须在 TPU Pod 切片中的所有主机上安装 JAX。您可以使用 --worker=all 选项通过一个命令在所有主机上安装 JAX:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

在 Pod 切片上运行 JAX 代码

要在 TPU Pod 切片上运行 JAX 代码,您必须在 TPU Pod 切片中的每个主机上运行代码。这意味着您必须通过 SSH 连接到每个主机,并在每个主机上执行 JAX 代码。以下 Python 代码说明了如何使用 gcloud 命令的 --worker=all 选项在 TPU Pod 切片上运行简单的 JAX 计算。

准备代码

您需要 gcloud 344.0.0 版或更高版本(对于 scp 命令)。使用 gcloud --version 检查您的 gcloud 版本,并根据需要运行 gcloud components upgrade

example.py 写入本地机器:

   cat > example.py << EOF

   # The following code snippet will be run on all TPU hosts
   import jax

   # The total number of TPU cores in the Pod
   device_count = jax.device_count()

   # The number of TPU cores attached to this host
   local_device_count = jax.local_device_count()

   # The psum is performed over all mapped devices across the Pod
   xs = jax.numpy.ones(jax.local_device_count())
   r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

   # Print from a single host to avoid duplicated output
   if jax.process_index() == 0:
       print('global device count:', jax.device_count())
       print('local device count:', jax.local_device_count())
       print('pmap result:', r)
   EOF

example.py 复制到 Pod 切片中的所有虚拟机。

$ gcloud alpha compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

如果这是您第一次使用 scp 命令,您可能会看到类似于以下内容的错误:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent.
Please run "ssh-add /.../.ssh/google_compute_engine" to add it, and try again.
运行错误消息中显示的 ssh-add 命令,并重新运行该命令来解决该错误。

在 Pod 切片上运行代码

在每个虚拟机上启动 example.py 程序:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

输出(使用 v2-32 Pod 切片生成):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

这是在每个主机上运行 JAX Python 代码的方法之一,但您可以使用您喜欢的任何方法。无论您使用什么运行方法,上面的 jax.device_count() 调用在它在 Pod 切片中的每个主机上被调用之前将一直挂起,因为所有主机都必须存在才能初始化 TPU 运行时。

清除数据

完成后,您可以使用 gcloud 命令释放 TPU 虚拟机资源:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a