在 TPU Pod 切片上运行 JAX 代码

在单个 TPU 板上运行 JAX 代码后,您可以在 TPU Pod 切片上运行代码,从而进行扩容。TPU Pod 切片是多个 TPU 板,通过专用高速网络连接相互连接。本文档介绍如何在 TPU Pod 切片上运行 JAX 代码;如需深入了解,请参阅在多主机和多进程环境中使用 JAX

创建 TPU Pod 切片

您可以使用 gcloud 命令创建 TPU Pod 切片。例如,使用以下命令创建 v2-32 Pod 切片:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

在 Pod 切片上安装 JAX

创建 TPU Pod 切片之后,您必须在 TPU Pod 切片中的所有主机上安装 JAX。您可以使用 --worker=all 选项通过一个命令在所有主机上安装 JAX:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install --upgrade jax jaxlib"

在 Pod 切片上运行 JAX 代码

要在 TPU Pod 切片上运行 JAX 代码,您必须在 TPU Pod 切片中的每个主机上运行代码。这意味着您必须通过 SSH 连接到每个主机,并在每个主机上执行 JAX 代码。以下 Python 代码说明了如何使用 gcloud 命令的 --worker=all 选项在 TPU Pod 切片上运行简单的 JAX 计算。

准备代码

$ read -r -d '' PYTHON_CMD << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
EOF

在 Pod 切片上运行代码

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 -c \"$PYTHON_CMD\""

输出(使用 v2-32 Pod 切片生成):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

这是在每个主机上运行 JAX Python 代码的方法之一,但您可以使用您喜欢的任何方法。无论您使用什么运行方法,上面的 jax.device_count() 调用在它在 Pod 切片中的每个主机上被调用之前将一直挂起,因为所有主机都必须存在才能初始化 TPU 运行时。

清除数据

完成后,您可以使用 gcloud 命令释放 TPU 虚拟机资源:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a