Training mit TPU-Beschleunigern

Vertex AI unterstützt das Training mit verschiedenen Frameworks und Bibliotheken mit einer TPU-VM. Beim Konfigurieren von Rechenressourcen können Sie TPU v2-, TPU v3- oder TPU v5e-VMs angeben. TPU v5e unterstützt JAX 0.4.6 und höher, TensorFlow 2.15 und höher sowie PyTorch 2.1 und höher. Weitere Informationen zum Konfigurieren von TPU-VMs für benutzerdefiniertes Training finden Sie unter Rechenressourcen für benutzerdefiniertes Training konfigurieren.

TensorFlow-Training

Vordefinierter Container

Verwenden Sie einen vordefinierten Trainingscontainer, der TPUs unterstützt, und erstellen Sie eine Python-Trainingsanwendung.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie Versionen von tensorflow und libtpu installiert haben, die speziell für TPU-VMs erstellt wurden. Diese Bibliotheken werden vom Cloud TPU-Dienst verwaltet und in der Dokumentation Unterstützte TPU-Konfigurationen aufgeführt.

Wählen Sie die gewünschte tensorflow-Version und die entsprechende libtpu-Bibliothek aus. Installieren Sie diese dann beim Erstellen des Containers im Docker-Container-Image.

Wenn Sie beispielsweise TensorFlow 2.12 verwenden möchten, fügen Sie Ihrem Dockerfile die folgende Anleitung hinzu:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

TPU-Pod

Das Training von tensorflow auf einem TPU Pod erfordert eine zusätzliche Einrichtung im Trainingscontainer. Vertex AI verwaltet ein Basis-Docker-Image, das die Ersteinrichtung übernimmt.

Image-URIs Python-Version und TPU-Version
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
3,8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
3,10

So erstellen Sie Ihren benutzerdefinierten Container:

  1. Wählen Sie das Basis-Image für die Python-Version Ihrer Wahl aus. TPU TensorFlow Wheels für TensorFlow 2.12 und niedriger unterstützen Python 3.8. TensorFlow 2.13 und höher unterstützen Python 3.10 oder höher. Informationen zu den jeweiligen TensorFlow-Räumen finden Sie unter Cloud TPU-Konfigurationen.
  2. Erweitern Sie das Image mit Ihrem Trainercode und dem Startbefehl.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

PyTorch-Training

Sie können für das Training mit TPUs vordefinierte oder benutzerdefinierte Container für PyTorch verwenden.

Vordefinierter Container

Verwenden Sie einen vordefinierten Trainingscontainer, der TPUs unterstützt, und erstellen Sie eine Python-Trainingsanwendung.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie die PyTorch-Bibliothek installiert haben.

Ihr Dockerfile könnte beispielsweise so aussehen:

FROM python:3.10

# v5e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

TPU-Pod

Das Training wird auf allen Hosts des TPU-Pods ausgeführt (siehe PyTorch-Code auf TPU-Pod-Slices ausführen).

Vertex AI wartet auf eine Antwort von allen Hosts, um den Abschluss des Jobs zu entscheiden.

JAX-Training

Vordefinierter Container

Es gibt keine vordefinierten Container für JAX.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie die JAX-Bibliothek installiert haben.

Ihr Dockerfile könnte beispielsweise so aussehen:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

TPU-Pod

Das Training wird auf allen Hosts des TPU-Pods ausgeführt (siehe JAX-Code auf TPU-Pod-Slices ausführen).

Vertex AI überwacht den ersten Host des TPU-Pods, um über den Abschluss des Jobs zu entscheiden. Mit dem folgenden Code-Snippet können Sie dafür sorgen, dass alle Hosts gleichzeitig beendet werden:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

Umgebungsvariablen

In der folgenden Tabelle werden die Umgebungsvariablen beschrieben, die Sie im Container verwenden können:

Name Wert
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

Benutzerdefiniertes Dienstkonto

Ein benutzerdefiniertes Dienstkonto kann für das TPU-Training verwendet werden. Informationen zur Verwendung eines benutzerdefinierten Dienstkontos finden Sie auf der Seite zur Verwendung eines benutzerdefinierten Dienstkontos.

Private IP-Adresse (VPC-Netzwerk-Peering) für das Training

Eine private IP-Adresse kann für das TPU-Training verwendet werden. Informationen dazu finden Sie auf der Seite Private IP-Adresse für benutzerdefiniertes Training verwenden.

VPC Service Controls

VPC Service Controls-fähige Projekte können TPU-Trainingsjobs senden.

Beschränkungen

Beim Trainieren mit einer TPU-VM gelten die folgenden Einschränkungen:

TPU-Typen

Weitere Informationen zu TPU-Beschleunigern wie das Arbeitsspeicherlimit finden Sie unter TPU-Typen.