Vertex AI は、TPU VM を使用したさまざまなフレームワークとライブラリを使用したトレーニングをサポートしています。コンピューティング リソースを構成するときに、TPU v2、TPU v3、または TPU v5e VM を指定できます。TPU v5e は、JAX 0.4.6 以降、TensorFlow 2.15 以降、PyTorch 2.1 以降をサポートしています。カスタム トレーニング用の TPU VM の構成の詳細については、カスタム トレーニング用のコンピューティング リソースを構成するをご覧ください。
TensorFlow トレーニング
ビルド済みコンテナ
TPU をサポートするビルド済みのトレーニング コンテナを使用し、Python トレーニング アプリケーションを作成します。
カスタム コンテナ
TPU VM 専用にビルドされた tensorflow
バージョンと libtpu
バージョンがインストールされているカスタム コンテナを使用します。これらのライブラリは Cloud TPU サービスによって管理されており、サポートされている TPU 構成のドキュメントに記載されています。
目的の tensorflow
バージョンと、それに対応する libtpu
ライブラリを選択します。コンテナをビルドするときに、これらを Docker コンテナ イメージにインストールします。
たとえば、TensorFlow 2.12 を使用する場合は Dockerfile に次の指示を含めます。
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so
# TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
# runtime, configure the following environment variables in your Dockerfile.
# For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
# ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
# ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
TPU Pod
TPU Pod
で tensorflow
トレーニングを行う場合は、トレーニング コンテナに追加の設定が必要です。Vertex AI は、初期設定を処理するベース Docker イメージを保持します。
イメージの URI | Python バージョンと TPU バージョン |
---|---|
|
3.8 |
|
3.10 |
カスタム コンテナをビルドする手順は次のとおりです。
- 使用する Python バージョンのベースイメージを選択します。TensorFlow 2.12 以前の TPU TensorFlow ホイールは Python 3.8 をサポートしています。TensorFlow 2.13 以降は、Python 3.10 以降をサポートしています。特定の TensorFlow ホイールについては、Cloud TPU の構成をご覧ください。
- トレーナー コードと起動コマンドを使用してイメージを拡張します。
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py
# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]
PyTorch トレーニング
TPU でトレーニングする場合は、PyTorch 用のビルド済みコンテナまたはカスタム コンテナを使用できます。
ビルド済みコンテナ
TPU をサポートするビルド済みのトレーニング コンテナを使用し、Python トレーニング アプリケーションを作成します。
カスタム コンテナ
PyTorch
ライブラリをインストールしたカスタム コンテナを使用します。
たとえば、Dockerfile は次のようになります。
FROM python:3.10
# v5e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU
# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
-f https://storage.googleapis.com/libtpu-releases/index.html
# Add your artifacts here
COPY trainer.py .
# Run the trainer code
CMD ["python3", "trainer.py"]
TPU Pod
トレーニングは TPU Pod のすべてのホストで実行されます(TPU Pod スライスで PyTorch コードを実行するをご覧ください)。
Vertex AI は、すべてのホストからのレスポンスを待ってから、ジョブの完了を決定します。
JAX トレーニング
ビルド済みコンテナ
JAX 用のビルド済みコンテナはありません。
カスタム コンテナ
JAX
ライブラリをインストールしたカスタム コンテナを使用します。
たとえば、Dockerfile は次のようになります。
# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Add your artifacts here
COPY trainer.py trainer.py
# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]
TPU Pod
トレーニングは TPU Pod のすべてのホストで実行されます(TPU Pod スライスでの JAX コードの実行をご覧ください)。
Vertex AI は、TPU Pod の最初のホストを監視してジョブの完了を判断します。次のコード スニペットを使用すると、すべてのホストが同時に終了するようにできます。
# Your training logic
...
if jax.process_count() > 1:
# Make sure all hosts stay up until the end of main.
x = jnp.ones([jax.local_device_count()])
x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
assert x[0] == jax.device_count()
環境変数
次の表に、コンテナ内で使用可能な環境変数の詳細を示します。
名前 | 値 |
---|---|
TPU_NODE_NAME | my-first-tpu-node |
TPU_CONFIG | {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"} |
カスタム サービス アカウント
TPU トレーニングにはカスタム サービス アカウントを使用できます。カスタム サービス アカウントの使用方法については、カスタム サービス アカウントの使用方法のページをご覧ください。
トレーニング用のプライベート IP(VPC ネットワーク ピアリング)
TPU トレーニングにプライベート IP を使用できます。カスタム トレーニングにプライベート IP を使用する方法のページをご覧ください。
VPC Service Controls
VPC Service Controls が有効になっているプロジェクトでは、TPU トレーニング ジョブを送信できます。
制限事項
TPU VM を使用してトレーニングする場合は、次の制限が適用されます。
TPU タイプ
メモリの上限など、TPU アクセラレータの詳細については、TPU タイプをご覧ください。