在 Docker 容器中运行 Cloud TPU 应用

Docker 容器可将您的代码和所有所需依赖项组合到一个可分发软件包中,从而简化应用配置。您可以在 TPU 虚拟机中运行 Docker 容器,以简化 Cloud TPU 应用的配置和共享。本文档介绍了如何为 Cloud TPU 支持的每个机器学习框架设置 Docker 容器。

在 Docker 容器中训练 TensorFlow 模型

TPU 设备

  1. 在当前目录中创建一个名为 Dockerfile 的文件,并将以下文本粘贴到其中

    FROM python:3.8
    RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
    RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
    RUN git clone https://github.com/tensorflow/models.git
    WORKDIR ./models
    RUN pip install -r official/requirements.txt
    ENV PYTHONPATH=/models
  2. 创建 Cloud Storage 存储桶

    gcloud storage buckets create gs://your-bucket-name --location=europe-west4
  3. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-vm-tf-2.18.0-pjrt
  4. 将 Dockerfile 复制到 TPU 虚拟机

    gcloud compute tpus tpu-vm scp ./Dockerfile your-tpu-name:
  5. 通过 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  6. 构建 Docker 映像

    sudo docker build -t your-image-name .
  7. 启动 Docker 容器

    sudo docker run -ti --rm --net=host --name your-container-name --privileged your-image-name bash
  8. 设置环境变量

    export STORAGE_BUCKET=gs://your-bucket-name
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export MODEL_DIR=${STORAGE_BUCKET}/resnet-2x
  9. 训练 ResNet

    python3 official/vision/train.py \
    --tpu=local \
    --experiment=resnet_imagenet \
    --mode=train_and_eval \
    --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
    --model_dir=${MODEL_DIR} \
    --params_override="task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*,trainer.train_steps=100"

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机
     $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU Pod

  1. 在当前目录中创建一个名为 Dockerfile 的文件,并将以下文本粘贴到其中

    FROM python:3.8
    RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
    RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
    RUN git clone https://github.com/tensorflow/models.git
    WORKDIR ./models
    RUN pip install -r official/requirements.txt
    ENV PYTHONPATH=/models
  2. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v3-32 \
    --version=tpu-vm-tf-2.18.0-pod-pjrt
  3. 将 Dockerfile 复制到 TPU 虚拟机

    gcloud compute tpus tpu-vm scp ./Dockerfile your-tpu-name:
  4. 通过 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  5. 构建 Docker 映像

    sudo docker build -t your-image-name .
  6. 启动 Docker 容器

    sudo docker run -ti --rm --net=host --name your-container-name --privileged your-image-name bash
  7. 训练 ResNet

    python3 official/vision/train.py \
    --tpu=local \
    --experiment=resnet_imagenet \
    --mode=train_and_eval \
    --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
    --model_dir=${MODEL_DIR} \
    --params_override="task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机
      $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

在 Docker 容器中训练 PyTorch 模型

TPU 设备

  1. 创建 Cloud TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 通过 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. 使用每夜 PyTorch/XLA 映像在 TPU 虚拟机中启动容器。

    sudo docker run -ti --rm --name your-container-name --privileged gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm bash
  4. 配置 TPU 运行时

    PyTorch/XLA 运行时有两个选项:PJRT 和 XRT。除非您有理由使用 XRT,否则我们建议您使用 PJRT。如需详细了解不同的运行时配置,请参阅您有理由使用 XRT。如需详细了解不同的运行时配置,请参阅 PJRT 运行时文档

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  5. 克隆 PyTorch XLA 代码库

    git clone --recursive https://github.com/pytorch/xla.git
  6. 训练 ResNet50

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机
     $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU Pod

在 TPU Pod 上运行 PyTorch 代码时,您必须同时在所有 TPU 工作器上运行代码。为此,一种方法是将 gcloud compute tpus tpu-vm ssh 命令与 --worker=all--command 标志结合使用。以下过程介绍了如何创建 Docker 映像,以便更轻松地设置每个 TPU 工作器。

  1. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. 将当前用户添加到 docker 组

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command="sudo usermod -a -G docker $USER"
  3. 在所有 TPU 工作器上的容器中运行训练脚本。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -e PJRT_DEVICE=TPU gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm python /pytorch/xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker 命令标志:

    • --rm 在容器进程终止后移除容器。
    • --privileged 将 TPU 设备公开给容器。
    • --net=host 会将容器的所有端口绑定到 TPU 虚拟机,以允许 Pod 中的主机之间进行通信。
    • -e 设置环境变量。

训练脚本完成后,请务必清理资源。

使用以下命令删除 TPU 虚拟机:

$ gcloud compute tpus tpu-vm delete your-tpu-name \
  --zone=us-central2-b

在 Docker 容器中训练 JAX 模型

TPU 设备

  1. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 通过 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. 在 TPU 虚拟机中启动 Docker 守护程序

    sudo systemctl start docker
  4. 启动 Docker 容器

    sudo docker run -ti --rm --name your-container-name --privileged --network=host python:3.8 bash
  5. 安装 JAX

    pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  6. 安装 FLAX

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. 运行 FLAX MNIST 训练脚本

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机

    $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU Pod

在 TPU Pod 上运行 JAX 代码时,您必须同时在所有 TPU 工作器上运行 JAX 代码。为此,一种方法是使用带有 --worker=all--command 标志的 gcloud compute tpus tpu-vm ssh 命令。以下过程介绍了如何创建 Docker 映像,以便更轻松地设置每个 TPU 工作器。

  1. 在当前目录中创建一个名为 Dockerfile 的文件,并将以下文本粘贴到其中

    FROM python:3.8
    RUN pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    WORKDIR ./flax/examples/mnist
  2. 构建 Docker 映像

    docker build -t your-image-name .
  3. 请先为 Docker 映像添加标记,然后再将其推送到 Artifact Registry。如需详细了解如何使用 Artifact Registry,请参阅使用容器映像

    docker tag your-image-name europe-west-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  4. 将 Docker 映像推送到 Artifact Registry

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type==v2-8 \
    --version=tpu-ubuntu2204-base
  6. 在所有 TPU 工作器上从 Artifact Registry 拉取 Docker 映像。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="sudo usermod -a -G docker ${USER}"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  7. 在所有 TPU 工作器上运行容器。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image:your-tag bash"
  8. 在所有 TPU 工作器上运行训练脚本:

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

训练脚本完成后,请务必清理资源。

  1. 在所有工作器上关闭容器:

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. 使用以下命令删除 TPU 虚拟机:

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

使用 JAX 稳定堆栈在 Docker 容器中训练 JAX 模型

您可以使用 JAX 稳定堆栈基本映像构建 MaxTextMaxDiffusion Docker 映像。

JAX 稳定堆栈通过将 JAX 与 orbaxflaxoptaxlibtpu.so 等核心软件包捆绑在一起,为 MaxText 和 MaxDiffusion 提供了一致的环境。这些库已过测试,以确保兼容性,并为构建和运行 MaxText 和 MaxDiffusion 提供了稳定的基础。 这样可以消除因软件包版本不兼容而导致的潜在冲突。

JAX 稳定版堆栈包含一个已完全发布且经过认证的 libtpu.so,这是驱动 TPU 程序编译、执行和 ICI 网络配置的核心库。libtpu 版本取代了 JAX 之前使用的每夜 build,并通过 HLO/StableHLO IR 中的 PJRT 级资格测试确保 XLA 计算在 TPU 上的功能一致。

如需使用 JAX 稳定堆栈构建 MaxText 和 MaxDiffusion Docker 映像,请在运行 docker_build_dependency_image.sh 脚本时,将 MODE 变量设置为 stable_stack,并将 BASEIMAGE 变量设置为要使用的基准映像。

以下示例将 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 指定为基础映像:

   bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
   

如需查看可用的 JAX 稳定堆栈基础映像的列表,请参阅 Artifact Registry 中的 JAX 稳定堆栈映像

后续步骤