使用 PyTorch 在 Cloud TPU 虚拟机上运行计算

本文档简要介绍了如何搭配使用 PyTorch 和 Cloud TPU。

准备工作

在运行本文档中的命令之前,您必须创建 Google Cloud 账号、安装 Google Cloud CLI 并配置 gcloud 命令。如需了解详情,请参阅设置 Cloud TPU 环境

使用 gcloud 创建 Cloud TPU

  1. 定义一些环境变量,以便更轻松地使用这些命令。

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    环境变量说明

    PROJECT_ID
    您的 Google Cloud 项目 ID。
    ACCELERATOR_TYPE
    加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    ZONE
    拟在其中创建 Cloud TPU 的可用区
    RUNTIME_VERSION
    Cloud TPU 运行时版本
    TPU_NAME
    用户为 Cloud TPU 分配的名称。
  2. 运行以下命令,创建 TPU 虚拟机:

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION

连接到 Cloud TPU 虚拟机

使用以下命令通过 SSH 连接到 TPU 虚拟机:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
   --project=$PROJECT_ID \
   --zone=$ZONE

在 TPU 虚拟机上安装 PyTorch/XLA

$ (vm) sudo apt-get update
$ (vm) sudo apt-get install libopenblas-dev -y
$ (vm) pip install numpy
$ (vm) pip install torch torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html

验证 PyTorch 是否可以访问 TPU

使用以下命令验证 PyTorch 是否可以访问您的 TPU。

$ (vm) PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"

该命令的输出应如下所示:

['xla:0', 'xla:1', 'xla:2', 'xla:3']

执行基本计算

  1. 在当前目录中创建一个名为 tpu-test.py 的文件,并将以下脚本复制粘贴到其中。

    import torch
    import torch_xla.core.xla_model as xm
    
    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    
  2. 运行脚本:

    (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py

    脚本的输出显示计算结果:

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

清理

为避免因本页中使用的资源导致您的 Google Cloud 账号产生费用,请按照以下步骤操作。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 Cloud TPU。

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE

此命令的输出应确认 TPU 已被删除。

后续步骤

详细了解 Cloud TPU 虚拟机: