使用 PyTorch 在 Cloud TPU 上训练 Resnet50
使用集合让一切井井有条
根据您的偏好保存内容并对其进行分类。
本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上训练 ResNet-50 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。
本教程中的模型基于用于图片识别的深度残差学习,率先引入了残差网络 (ResNet) 架构。本教程使用 50 层变体 ResNet-50,演示如何使用 PyTorch/XLA 训练模型。
费用
在本文档中,您将使用 Google Cloud 的以下收费组件:
您可使用价格计算器根据您的预计使用情况来估算费用。
Google Cloud 新用户可能有资格申请免费试用。
准备工作
在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。
-
登录您的 Google Cloud 账号。如果您是 Google Cloud 新手,请创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
-
In the Google Cloud console, on the project selector page,
select or create a Google Cloud project.
Go to project selector
-
确保您的 Google Cloud 项目已启用结算功能。
-
In the Google Cloud console, on the project selector page,
select or create a Google Cloud project.
Go to project selector
-
确保您的 Google Cloud 项目已启用结算功能。
本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。执行上述操作时,请务必清理您创建的资源,
以免产生不必要的费用
创建 TPU 虚拟机
打开一个 Cloud Shell 窗口。
打开 Cloud Shell
创建 TPU 虚拟机
gcloud compute tpus tpu-vm create your-tpu-name \
--accelerator-type=v4-8 \
--version=tpu-ubuntu2204-base \
--zone=us-central2-b \
--project=your-project
使用 SSH 连接到您的 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central2-b
在 TPU 虚拟机上安装 PyTorch/XLA:
(vm)$ pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
克隆 PyTorch/XLA GitHub 代码库
(vm)$ git clone --depth=1 --branch r2.4 https://github.com/pytorch/xla.git
使用虚构数据运行训练脚本
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
如果您能够使用虚构数据训练模型,则可以尝试
例如 ImageNet,如需了解如何下载 ImageNet,请参阅下载 ImageNet。在训练脚本命令中,--datadir
标志用于指定要用于训练的数据集的位置。以下命令假定 ImageNet 数据集位于 ~/imagenet
。
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
断开与 TPU 虚拟机的连接:
(vm) $ exit
您的提示符现在应为 username@projectname
,表明您位于 Cloud Shell 中。
删除您的 TPU 虚拟机。
$ gcloud compute tpus tpu-vm delete resnet50-tutorial \
--zone=us-central2-b
如未另行说明,那么本页面中的内容已根据知识共享署名 4.0 许可获得了许可,并且代码示例已根据 Apache 2.0 许可获得了许可。有关详情,请参阅 Google 开发者网站政策。Java 是 Oracle 和/或其关联公司的注册商标。
最后更新时间 (UTC):2024-10-04。
[{
"type": "thumb-down",
"id": "hardToUnderstand",
"label":"Hard to understand"
},{
"type": "thumb-down",
"id": "incorrectInformationOrSampleCode",
"label":"Incorrect information or sample code"
},{
"type": "thumb-down",
"id": "missingTheInformationSamplesINeed",
"label":"Missing the information/samples I need"
},{
"type": "thumb-down",
"id": "translationIssue",
"label":"翻译问题"
},{
"type": "thumb-down",
"id": "otherDown",
"label":"其他"
}]
[{
"type": "thumb-up",
"id": "easyToUnderstand",
"label":"易于理解"
},{
"type": "thumb-up",
"id": "solvedMyProblem",
"label":"解决了我的问题"
},{
"type": "thumb-up",
"id": "otherUp",
"label":"其他"
}]
{"lastModified": "\u6700\u540e\u66f4\u65b0\u65f6\u95f4 (UTC)\uff1a2024-10-04\u3002"}
[[["易于理解","easyToUnderstand","thumb-up"],["解决了我的问题","solvedMyProblem","thumb-up"],["其他","otherUp","thumb-up"]],[["Hard to understand","hardToUnderstand","thumb-down"],["Incorrect information or sample code","incorrectInformationOrSampleCode","thumb-down"],["Missing the information/samples I need","missingTheInformationSamplesINeed","thumb-down"],["翻译问题","translationIssue","thumb-down"],["其他","otherDown","thumb-down"]],["最后更新时间 (UTC):2024-10-04。"]]