在 TPU Pod 切片上运行 PyTorch 代码
PyTorch/XLA 要求所有 TPU 虚拟机都能够访问模型代码和数据。您可以使用启动脚本下载将模型数据分发到所有 TPU 虚拟机所需的软件。
如果要将 TPU 虚拟机连接到虚拟私有云 (VPC),您必须在项目中添加防火墙规则,以允许端口 8470 - 8479 传入入站流量。如需详细了解如何添加防火墙规则,请参阅使用防火墙规则
设置您的环境
在 Cloud Shell 中,运行以下命令以确保您运行的是当前版本的
gcloud
:$ gcloud components update
如果您需要安装
gcloud
,请使用以下命令:$ sudo apt install -y google-cloud-sdk
创建一些环境变量:
$ export PROJECT_ID=project-id $ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
创建 TPU 虚拟机
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}
配置并运行训练脚本
将您的 SSH 证书添加到您的项目中:
ssh-add ~/.ssh/google_compute_engine
在所有 TPU 虚拟机工作器上安装 PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command=" pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
在所有 TPU 虚拟机工作器上克隆 XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
在所有工作器上运行训练脚本
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
训练大约需要 5 分钟。完成后,您应该会看到类似于下面这样的消息:
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
清理
完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。
断开与 Compute Engine 的连接:
(vm)$ exit
通过运行以下命令来验证资源已删除。确保您的 TPU 不再列出。删除操作可能需要几分钟时间才能完成。
$ gcloud compute tpus tpu-vm list \ --zone europe-west4-a