使用 PyTorch 在 Cloud TPU 虚拟机上运行计算

本快速入门介绍如何创建 Cloud TPU、如何安装 PyTorch,以及如何在 Cloud TPU 上运行简单的计算。如需更深入地了解如何在 Cloud TPU 上训练模型,请参阅 Cloud TPU PyTorch 教程之一。

准备工作

在按照本快速入门操作之前,您必须创建一个 Google Cloud Platform 账号,安装 Google Cloud CLI,并配置 gcloud 命令。如需了解详情,请参阅设置账号和 Cloud TPU 项目

使用 gcloud 创建 Cloud TPU

在默认用户项目、网络和可用区中创建 TPU 虚拟机 运行:

$ gcloud compute tpus tpu-vm create tpu-name \
   --zone=us-central1-a \
   --accelerator-type=v3-8 \
   --version=tpu-ubuntu2204-base

命令标志说明

zone
区域 创建 Cloud TPU 的位置。
accelerator-type
加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。 如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
version
Cloud TPU 软件版本

创建 TPU 时,如果您想指定默认网络和子网,可以传递额外的 --network--subnetwork 标志。如果您不想使用默认网络,则必须传递 --network 标志。--subnetwork 标志是可选的,可用于为您使用的任何网络(默认网络或用户指定的网络)指定默认子网。请参阅gcloud 如需详细了解这些标志,请参阅 API 参考文档页面

连接到 Cloud TPU 虚拟机

   $ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a

在 TPU 虚拟机上安装 PyTorch/XLA

   (vm)$ pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
   

设置 TPU 运行时配置

确保 PyTorch/XLA 运行时使用 TPU。

   (vm) $ export PJRT_DEVICE=TPU

验证 PyTorch 是否可以访问 TPU

  1. 在当前目录中创建一个名为 tpu-count.py 的文件,并将以下脚本复制粘贴到其中。

    import torch
    import torch_xla.core.xla_model as xm
    print(f'PyTorch can access {len(torch_xla.devices())} TPU cores')
    
  2. 运行脚本:

    (vm)$ python3 tpu-count.py

    脚本的输出显示计算结果:

    PyTorch can access 8 TPU cores
    

执行基本计算

  1. 在当前目录中创建一个名为 tpu-test.py 的文件,并将以下脚本复制粘贴到其中。

    import torch
    import torch_xla.core.xla_model as xm
    
    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    
  2. 运行脚本:

      (vm)$ python3 tpu-test.py

    脚本的输出显示计算结果:

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

清理

为避免因本页中使用的资源导致您的 Google Cloud 账号产生费用,请按照以下步骤操作。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 Cloud TPU。

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central1-a

此命令的输出应确认 TPU 已被删除。

后续步骤

详细了解 Cloud TPU 虚拟机: