本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上训练 ResNet-50 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。
本教程中的模型基于用于图片识别的深度残差学习,率先引入了残差网络 (ResNet) 架构。本教程使用 50 层变体 ResNet-50,演示如何使用 PyTorch/XLA 训练模型。
目标
- 准备数据集。
- 运行训练作业。
- 验证输出结果。
费用
在本文档中,您将使用 Google Cloud的以下收费组件:
- Compute Engine
- Cloud TPU
您可使用价格计算器根据您的预计使用情况来估算费用。
准备工作
在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Verify that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Verify that billing is enabled for your Google Cloud project.
本演示使用 Google Cloud的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。
创建 TPU 虚拟机
打开一个 Cloud Shell 窗口。
创建 TPU 虚拟机
gcloud compute tpus tpu-vm create your-tpu-name \ --accelerator-type=v3-8 \ --version=tpu-ubuntu2204-base \ --zone=us-central1-a \ --project=your-project
使用 SSH 连接到 TPU 虚拟机:
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central1-a
在 TPU 虚拟机上安装 PyTorch/XLA:
(vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
-
(vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
使用虚构数据运行训练脚本
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
与 TPU 虚拟机断开连接:
(vm) $ exit
您的提示符现在应为
username@projectname
,表明您位于 Cloud Shell 中。删除您的 TPU 虚拟机。
$ gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=us-central1-a
后续步骤