在 TPU Pod 切片上运行 PyTorch 代码

PyTorch/XLA 要求所有 TPU 虚拟机都能够访问模型代码和数据。您可以使用启动脚本下载将模型数据分发到所有 TPU 虚拟机所需的软件。

如果要将 TPU 虚拟机连接到虚拟私有云 (VPC),您必须在项目中添加防火墙规则,以允许端口 8470 - 8479 传入入站流量。如需详细了解如何添加防火墙规则,请参阅使用防火墙规则

设置您的环境

  1. 在 Cloud Shell 中,运行以下命令以确保您运行的是当前版本的 gcloud

    $ gcloud components update
    

    如果您需要安装 gcloud,请使用以下命令:

    $ sudo apt install -y google-cloud-sdk
  2. 创建一些环境变量:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

创建 TPU 虚拟机

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

配置并运行训练脚本

  1. 将您的 SSH 证书添加到您的项目中:

    ssh-add ~/.ssh/google_compute_engine
    
  2. 在所有 TPU 虚拟机工作器上安装 PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. 在所有 TPU 虚拟机工作器上克隆 XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
    
  4. 在所有工作器上运行训练脚本

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    训练大约需要 5 分钟。完成后,您应该会看到类似于下面这样的消息:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Compute Engine 的连接:

    (vm)$ exit
    
  2. 通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a