在 TPU Pod 切片上运行 PyTorch 代码

PyTorch/XLA 要求所有 TPU 虚拟机都能够访问模型代码和数据。您可以使用启动脚本下载将模型数据分发到所有 TPU 虚拟机所需的软件。

如果您要将 TPU 虚拟机连接到虚拟私有云 (VPC),则必须在项目中添加防火墙规则,以允许端口 8470-8479 的入站流量。如需详细了解如何添加防火墙规则,请参阅使用防火墙规则

设置环境

  1. 在 Cloud Shell 中,运行以下命令以确保您运行的是当前版本的 gcloud

    $ gcloud components update

    如果您需要安装 gcloud,请使用以下命令:

    $ sudo apt install -y google-cloud-sdk
  2. 创建一些环境变量:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32

创建 TPU 虚拟机

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

配置和运行训练脚本

  1. 将 SSH 证书添加到您的项目中:

    ssh-add ~/.ssh/google_compute_engine
  2. 在所有 TPU VM 工作器上安装 PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  3. 在所有 TPU VM 工作器上克隆 XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.5 https://github.com/pytorch/xla.git"
  4. 在所有工作器上运行训练脚本

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    训练大约需要 5 分钟。完成后,您应该会看到类似于下面这样的消息:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 Cloud TPU 和 Compute Engine 资源。

    $ gcloud compute tpus tpu-vm delete  \
      --zone=${ZONE}
  3. 通过运行 gcloud compute tpus execution-groups list 验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出不应包含本教程中创建的任何资源:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}