使用 PyTorch 在 Cloud TPU 上训练 Resnet50


本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上训练 ResNet-50 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。

本教程中的模型基于用于图片识别的深度残差学习,率先引入了残差网络 (ResNet) 架构。本教程使用 50 层变体 ResNet-50,演示如何使用 PyTorch/XLA 训练模型。

目标

  • 准备数据集。
  • 运行训练作业。
  • 验证输出结果。

费用

在本文档中,您将使用 Google Cloud 的以下收费组件:

  • Compute Engine
  • Cloud TPU

您可使用价格计算器根据您的预计使用情况来估算费用。 Google Cloud 新用户可能有资格申请免费试用

准备工作

在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。

  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud 新手,请创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
  2. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  3. 确保您的 Google Cloud 项目已启用结算功能

  4. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  5. 确保您的 Google Cloud 项目已启用结算功能

  6. 本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。

创建 TPU 虚拟机

  1. 打开一个 Cloud Shell 窗口。

    打开 Cloud Shell

  2. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central2-b \
    --project=your-project
    
  3. 使用 SSH 连接到您的 TPU 虚拟机:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central2-b
    
  4. 在 TPU 虚拟机上安装 PyTorch/XLA:

    (vm)$ pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
    
  5. 克隆 PyTorch/XLA GitHub 代码库

    (vm)$ git clone --depth=1 --branch r2.2 https://github.com/pytorch/xla.git
    
  6. 使用虚构数据运行训练脚本

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
    

如果您能够使用虚构数据训练模型,则可以尝试用真实数据(例如 ImageNet)进行训练。如需了解如何下载 ImageNet,请参阅下载 ImageNet。在训练脚本命令中,--datadir 标志指定要训练的数据集的位置。以下命令假定 ImageNet 数据集位于 ~/imagenet

   (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py  --datadir=~/imagenet --batch_size=256 --num_epochs=1
   

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 断开与 TPU 虚拟机的连接:

    (vm) $ exit
    

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 TPU 虚拟机。

    $ gcloud compute tpus tpu-vm delete resnet50-tutorial \
       --zone=us-central2-b
    

后续步骤

试用 PyTorch Colab: