Entrena con aceleradores de TPU

Vertex AI admite el entrenamiento con varios frameworks y bibliotecas mediante una VM de TPU. Cuando configuras recursos de procesamiento, puedes especificar VMs de TPU v2, TPU v3 o TPU v5e. TPU v5e es compatible con JAX 0.4.6+, TensorFlow 2.15+ y PyTorch 2.1+. Si deseas obtener detalles sobre cómo configurar las VMs de TPU para el entrenamiento personalizado, consulta Configura recursos de procesamiento para el entrenamiento personalizado.

Entrenamiento de TensorFlow

Contenedor previamente compilado

Usa un contenedor de entrenamiento compilado previamente que admita TPU y crea una aplicación de entrenamiento de Python.

Contenedor personalizado

Usa un contenedor personalizado en el que hayas instalado versiones de tensorflow y libtpu especialmente compiladas para VMs de TPU. El servicio de Cloud TPU mantiene estas bibliotecas y se enumeran en la documentación de Configuraciones de TPU compatibles.

Selecciona la versión de tensorflow que desees y su biblioteca libtpu correspondiente. A continuación, instálalos en tu imagen de contenedor de Docker cuando compiles el contenedor.

Por ejemplo, si deseas usar TensorFlow 2.12, incluye las siguientes instrucciones en tu Dockerfile:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

pod de TPU

El entrenamiento tensorflow en un TPU Pod requiere una configuración adicional en el contenedor de entrenamiento. Vertex AI mantiene una imagen base de Docker que controla la configuración inicial.

URIs de imágenes Versión de Python y versión de TPU
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
3.10

A continuación, se muestran los pasos para compilar tu contenedor personalizado:

  1. Elige la imagen base para la versión de Python que elijas. Las ruedas de TensorFlow de TPU para TensorFlow 2.12 y versiones anteriores son compatibles con Python 3.8. TensorFlow 2.13 y las versiones posteriores son compatibles con Python 3.10 o una versión posterior. Para las ruedas de TensorFlow específicas, consulta Configuraciones de Cloud TPU.
  2. Extiende la imagen con tu código de entrenador y comando de inicio.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

Entrenamiento de PyTorch

Puedes usar contenedores personalizados o compilados previamente para PyTorch cuando entrenas con TPU.

Contenedor previamente compilado

Usa un contenedor de entrenamiento compilado previamente que admita TPU y crea una aplicación de entrenamiento de Python.

Contenedor personalizado

Usa un contenedor personalizado en el que instalaste la biblioteca PyTorch.

Por ejemplo, tu Dockerfile podría verse de la siguiente manera:

FROM python:3.10

# v5e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

pod de TPU

El entrenamiento se ejecuta en todos los hosts del Pod de TPU (consulta Ejecuta el código JAX en porciones de Pod de TPU).

Vertex AI espera una respuesta de todos los hosts para decidir que se completó el trabajo.

Entrenamiento de JAX

Contenedor previamente compilado

No hay contenedores compilados previamente para JAX.

Contenedor personalizado

Usa un contenedor personalizado en el que instalaste la biblioteca JAX.

Por ejemplo, tu Dockerfile podría verse de la siguiente manera:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

pod de TPU

El entrenamiento se ejecuta en todos los hosts del pod de TPU (consulta Ejecuta el código JAX en porciones de pod de TPU).

Vertex AI observa el primer host del pod de TPU para decidir la finalización del trabajo. Puedes usar el siguiente fragmento de código para asegurarte de que todos los hosts salgan al mismo tiempo:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

Variables de entorno

En la siguiente tabla, se detallan las variables de entorno que puedes usar dentro del contenedor:

Nombre Valor
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

Cuenta de servicio personalizada

Se puede usar una cuenta de servicio personalizada para el entrenamiento de TPU. Si deseas obtener información sobre cómo usar una cuenta de servicio personalizada, consulta la página para usar una cuenta de servicio personalizada.

IP privada (intercambio de tráfico entre redes de VPC) para el entrenamiento

Se puede usar una IP privada para el entrenamiento de TPU. Consulta la página sobre cómo usar una IP privada para el entrenamiento personalizado.

Controles del servicio de VPC

Los proyectos habilitados de los Controles del servicio de VPC pueden enviar trabajos de entrenamiento de TPU.

Limitaciones

Las siguientes limitaciones se aplican cuando entrenas con una VM de TPU:

Tipos de TPU

Consulta los tipos de TPU para obtener más información sobre los aceleradores de TPU, como el límite de memoria.