本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上训练 ResNet-50 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。
本教程中的模型基于用于图片识别的深度残差学习,率先引入了残差网络 (ResNet) 架构。本教程使用 50 层变体 ResNet-50,演示如何使用 PyTorch/XLA 训练模型。
目标
- 准备数据集。
- 运行训练作业。
- 验证输出结果。
费用
在本文档中,您将使用 Google Cloud 的以下收费组件:
- Compute Engine
- Cloud TPU
您可使用价格计算器根据您的预计使用情况来估算费用。
准备工作
在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。
创建 TPU 虚拟机
打开一个 Cloud Shell 窗口。
创建 TPU 虚拟机
gcloud compute tpus tpu-vm create your-tpu-name \ --accelerator-type=v4-8 \ --version=tpu-ubuntu2204-base \ --zone=us-central2-b \ --project=your-project
使用 SSH 连接到 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central2-b
在 TPU 虚拟机上安装 PyTorch/XLA:
(vm)$ pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
-
(vm)$ git clone --depth=1 --branch r2.5 https://github.com/pytorch/xla.git
使用虚构数据运行训练脚本
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
如果您能够使用虚构数据训练模型,可以尝试使用真实数据(例如 ImageNet)进行训练。如需了解如何下载 ImageNet,请参阅下载 ImageNet。在训练脚本命令中,--datadir
标志用于指定要用于训练的数据集的位置。以下命令假设 ImageNet 数据集位于 ~/imagenet
中。
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
断开与 TPU 虚拟机的连接:
(vm) $ exit
您的提示符现在应为
username@projectname
,表明您位于 Cloud Shell 中。删除您的 TPU 虚拟机。
$ gcloud compute tpus tpu-vm delete resnet50-tutorial \ --zone=us-central2-b
后续步骤