Vertex AI supports training with various frameworks and libraries using a TPU VM. When configuring compute resources, you can specify TPU v2, TPU v3, or TPU v5e VMs. TPU v5e supports JAX 0.4.6+, TensorFlow 2.15+, and PyTorch 2.1+. For details on configuring TPU VMs for custom training, see Configure compute resources for custom training.
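For example, with the Vertex AI SDK for Python, a custom job that requests a TPU worker pool might look like the following. This is a minimal sketch: the project, bucket, and image URI are placeholders, and the v5e machine type and topology shown are one possible configuration.

from google.cloud import aiplatform

aiplatform.init(
    project="your-project",             # placeholder
    location="us-central1",
    staging_bucket="gs://your-bucket",  # placeholder
)

job = aiplatform.CustomJob(
    display_name="tpu-training",
    worker_pool_specs=[{
        "machine_spec": {
            # One possible TPU v5e configuration.
            "machine_type": "ct5lp-hightpu-4t",
            "tpu_topology": "2x2",
        },
        "replica_count": 1,
        "container_spec": {
            # Placeholder training container image.
            "image_uri": "us-docker.pkg.dev/your-project/your-repo/trainer:latest",
        },
    }],
)
job.run()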
TensorFlow training
Prebuilt container
Use a prebuilt training container that supports TPUs, and create a Python training application.
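A training application for a TPU VM typically initializes the TPU system and builds the model inside a TPUStrategy scope. The following is a minimal sketch with a stand-in model and dataset:

import tensorflow as tf

# On a TPU VM, the resolver connects to the local TPU runtime.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Stand-in model; replace with your own.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer="sgd", loss="mse")

# Stand-in dataset; replace with your input pipeline.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 10]), tf.random.normal([64, 1]))
).batch(8, drop_remainder=True)

model.fit(dataset, epochs=2)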
Custom container
Use a custom container in which you have installed versions of `tensorflow` and `libtpu` specially built for TPU VMs. These libraries are maintained by the Cloud TPU service and are listed in the Supported TPU configurations documentation.
Select the `tensorflow` version of your choice and its corresponding `libtpu` library. Next, install these in your Docker container image when you build the container.
For example, if you want to use TensorFlow 2.15, include the following instructions in your Dockerfile:
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so
# TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
# runtime, configure the following environment variables in your Dockerfile.
# For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
# ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
# ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
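To verify the image, you can run a quick check inside the container on a TPU VM. This is a sketch; initialization fails here if `libtpu` is missing or mismatched:

import tensorflow as tf

# Connect to the local TPU runtime; a missing or mismatched `libtpu`
# surfaces as an initialization error.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print(tf.config.list_logical_devices("TPU"))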
TPU Pod
`tensorflow` training on a TPU Pod requires additional setup in the training container. Vertex AI maintains a base Docker image that handles the initial setup.
Image URIs | Python version |
---|---|
us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest | 3.8 |
us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest | 3.10 |
Here are the steps to build your custom container:
- Choose the base image for the Python version of your choice. TPU TensorFlow wheels for TensorFlow 2.12 and earlier support Python 3.8; TensorFlow 2.13 and later support Python 3.10 or later. For the specific TensorFlow wheels, see Cloud TPU configurations.
- Extend the image with your trainer code and the startup command.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py
# The base image is set up so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]
PyTorch training
You can use prebuilt or custom containers for PyTorch when training with TPUs.
Prebuilt container
Use a prebuilt training container that supports TPUs, and create a Python training application.
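A PyTorch/XLA training application typically acquires the TPU device from torch_xla and steps the optimizer through xm.optimizer_step. The following is a minimal sketch with a stand-in model and data:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # acquire the TPU device
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Stand-in data; replace with your input pipeline.
x = torch.randn(64, 10, device=device)
y = torch.randn(64, 1, device=device)

for step in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Steps the optimizer; barrier=True flushes the lazy XLA graph.
    xm.optimizer_step(optimizer, barrier=True)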
Custom container
Use a custom container in which you have installed the PyTorch library.
For example, your Dockerfile might look like the following:
FROM python:3.10
# v5e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU
# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 \
  -f https://storage.googleapis.com/libtpu-releases/index.html
# Add your artifacts here
COPY trainer.py .
# Run the trainer code
CMD ["python3", "trainer.py"]
TPU Pod
The training runs on all hosts of the TPU Pod (see Run PyTorch code on TPU Pod slices).
Vertex AI waits for a response from all of the hosts before it marks the job complete.
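The entry point typically uses torch_xla's multiprocessing launcher, which starts one worker process per local TPU device on each host. The sketch below elides the actual training logic:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # ... your training logic on `device` ...

if __name__ == "__main__":
    xmp.spawn(_mp_fn)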
JAX training
Prebuilt container
There are no prebuilt containers for JAX.
Custom container
Use a custom container in which you have installed the JAX library.
For example, your Dockerfile might look like the following:
# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Add your artifacts here
COPY trainer.py trainer.py
# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]
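To confirm that JAX can see the TPU, you can run a quick check inside the container on a TPU VM (a sketch):

import jax

# Fails if `libtpu` was not installed by the `jax[tpu]` extra.
print(jax.devices())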
TPU Pod
The training runs on all hosts of the TPU Pod (see Run JAX code on TPU Pod slices).
Vertex AI monitors only the first host of the TPU Pod to determine job completion. You can use the following code snippet to make sure that all hosts exit at the same time:
import jax
import jax.numpy as jnp

# Your training logic
...

if jax.process_count() > 1:
    # Make sure all hosts stay up until the end of main.
    x = jnp.ones([jax.local_device_count()])
    x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
    assert x[0] == jax.device_count()
Environment variables
The following table details the environment variables that you can use within the container:
Name | Example value |
---|---|
TPU_NODE_NAME | my-first-tpu-node |
TPU_CONFIG | {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"} |
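For example, a training script might read these variables as follows; this is a sketch, and TPU_CONFIG is parsed as the JSON string shown in the table:

import json
import os

tpu_node_name = os.environ.get("TPU_NODE_NAME")
tpu_config = json.loads(os.environ.get("TPU_CONFIG", "{}"))
print(tpu_node_name, tpu_config.get("zone"))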
Custom Service Account
A custom service account can be used for TPU training. For details, refer to the page on how to use a custom service account.
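With the Vertex AI SDK for Python, for instance, the service account is supplied when you run the job. This is a sketch: the email is a placeholder, and `job` is a CustomJob as sketched earlier.

# Placeholder service account email.
job.run(service_account="trainer-sa@your-project.iam.gserviceaccount.com")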
Private IP (VPC network peering) for training
A private IP can be used for TPU training. Refer to the page on how to use a private IP for custom training.
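With the SDK, the peered network is likewise supplied when you run the job (a sketch; the project number and network name are placeholders):

# Placeholder project number and network name.
job.run(network="projects/123456789/global/networks/your-network")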
VPC Service Controls
Projects with VPC Service Controls enabled can submit TPU training jobs.
Limitations
The following limitations apply when you train using a TPU VM:
TPU types
For more information about TPU accelerators, such as memory limits, see TPU types.