使用 TPU 加速器进行训练

Vertex AI 支持使用 TPU 虚拟机通过各种框架和库进行训练。 配置计算资源时,您可以指定 TPU v2TPU v3TPU v5e 虚拟机。TPU v5e 支持 JAX 0.4.6+、TensorFlow 2.15+ 和 PyTorch 2.1+。如需详细了解如何为自定义训练配置 TPU 虚拟机,请参阅为自定义训练配置计算资源

TensorFlow 训练

预构建容器

使用支持 TPU 的预构建训练容器,并创建一个 Python 训练应用

自定义容器

使用自定义容器,该容器应安装了专为 TPU 虚拟机构建的 tensorflowlibtpu 版本。这些库由 Cloud TPU 服务维护,并在支持的 TPU 配置文档中列出。

选择所需的 tensorflow 版本及其对应的 libtpu 库。接下来,在构建容器时,在 Docker 容器映像中安装这些组件。

例如,如果您要使用 TensorFlow 2.12,请在 Dockerfile 中添加以下说明:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

TPU Pod

如需在 TPU Pod 上进行 tensorflow 训练,您需要在训练容器中进行一些额外的设置。Vertex AI 会维护一个用于处理初始设置的基础 Docker 映像。

映像 URI Python 版本和 TPU 版本
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
3.10

以下是构建自定义容器的步骤:

  1. 为您选择的 Python 版本选择基础映像。适用于 TensorFlow 2.12 及更低版本的 TPU TensorFlow 轮子支持 Python 3.8。TensorFlow 2.13 及更高版本支持 Python 3.10 或更高版本。如需了解具体的 TensorFlow 轮子,请参阅 Cloud TPU 配置
  2. 使用训练程序代码和启动命令扩展映像。
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

PyTorch 训练

使用 TPU 进行训练时,您可以使用预构建或自定义容器。

预构建容器

使用支持 TPU 的预构建训练容器,并创建一个 Python 训练应用

自定义容器

使用在其中安装了 PyTorch 库的自定义容器

例如,您的 Dockerfile 可能如下所示:

FROM python:3.10

# v5e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

TPU Pod

训练会在 TPU Pod 的所有主机上运行(请参阅在 TPU Pod 切片上运行 PyTorch 代码)。

Vertex AI 会等待所有主机的响应以确定作业的完成情况。

JAX 训练

预构建容器

JAX 没有预构建容器。

自定义容器

使用在其中安装了 JAX 库的自定义容器

例如,您的 Dockerfile 可能如下所示:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

TPU Pod

训练会在 TPU Pod 的所有主机上运行(请参阅在 TPU Pod 切片上运行 JAX 代码)。

Vertex AI 会监控 TPU Pod 的第一个主机以确定作业的完成情况。您可以使用以下代码段来确保所有主机同时退出:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

环境变量

下表详细介绍了可在容器中使用的环境变量:

名称
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

自定义服务账号

自定义服务账号可用于 TPU 训练。如需了解如何使用自定义服务账号,请参阅如何使用自定义服务账号页面。

用于训练的专用 IP(VPC 网络对等互连)

专用 IP 可用于 TPU 训练。请参阅有关如何使用专用 IP 进行自定义训练的页面。

VPC Service Controls

启用了 VPC Service Controls 的项目可以提交 TPU 训练作业。

限制

使用 TPU 虚拟机进行训练时存在以下限制:

TPU 类型

如需详细了解 TPU 加速器(例如内存限制),请参阅 TPU 类型