在 TPU Pod 切片上运行 JAX 代码

在运行本文档中的命令之前,请确保已按照设置账号和 Cloud TPU 项目中的说明操作。

在单个 TPU 板上运行 JAX 代码后,您可以通过在 TPU Pod 切片上运行代码来扩容代码。 TPU Pod 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU Pod 切片上运行 JAX 代码;如需了解更深入的信息,请参阅在多主机和多进程环境中使用 JAX

如果您要使用已装载的 NFS 存储数据,则必须为 Pod 切片中的所有 TPU 虚拟机设置 OS Login。如需了解详情,请参阅使用 NFS 存储数据

创建 Cloud TPU Pod 切片

  1. 创建一些环境变量:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    环境变量说明

    PROJECT_ID
    您的 Google Cloud 项目 ID。
    ACCELERATOR_TYPE
    加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    ZONE
    拟在其中创建 Cloud TPU 的可用区
    RUNTIME_VERSION
    Cloud TPU 运行时版本
    TPU_NAME
    用户为 Cloud TPU 分配的名称。
  2. 使用 gcloud 命令可以创建 TPU Pod 切片。例如,如需创建 v5p-32 Pod 切片,请使用以下命令:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

在 Pod 切片上安装 JAX

创建 TPU Pod 切片之后,您必须在 TPU Pod 切片中的所有主机上安装 JAX。您可以使用 gcloud compute tpus tpu-vm ssh 命令并使用 --worker=all--commamnd 参数执行此操作。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

在 Pod 切片上运行 JAX 代码

要在 TPU Pod 切片上运行 JAX 代码,您必须在 TPU Pod 切片中的每个主机上运行代码jax.device_count() 调用在它在 Pod 切片中的每个主机上被调用之前将一直挂起。以下示例展示了如何在 TPU Pod 切片上运行 JAX 计算。

准备代码

您需要 gcloud 344.0.0 版或更高版本(对于 scp 命令)。使用 gcloud --version 检查您的 gcloud 版本,并根据需要运行 gcloud components upgrade

使用以下代码创建一个名为 example.py 的文件:


import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py 复制到 Pod 切片中的所有 TPU 工作器虚拟机

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

如果您以前没有使用过 scp 命令,则可能会看到类似于以下内容的错误:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

如需解决此错误,请运行错误消息中显示的 ssh-add 命令,并重新运行该命令。

在 Pod 切片上运行代码

在每个虚拟机上启动 example.py 程序:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

输出(使用 v4-32 Pod 切片生成):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 Cloud TPU 和 Compute Engine 资源。

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. 通过运行 gcloud compute tpus execution-groups list 验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出不应包含本教程中创建的任何资源:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}