使用 PyTorch 在 Cloud TPU 上预训练 Wav2Vec2


本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上预训练 FairSeq 的 Wavev2Vec2 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。

本教程中的模型基于《wav2vec 2.0:一种自我监督的语音表示学习的框架》论文。

目标

  • 创建并配置 PyTorch 环境。
  • 下载开源 LibriSpeech 数据。
  • 运行训练作业。

费用

在本文档中,您将使用 Google Cloud 的以下收费组件:

  • Compute Engine
  • Cloud TPU

您可使用价格计算器根据您的预计使用情况来估算费用。 Google Cloud 新用户可能有资格申请免费试用

准备工作

在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。

  1. 在 Google Cloud 控制台的项目选择器页面上,选择或创建 Google Cloud 项目。注意:如果您不打算保留在此过程中创建的资源,请创建新的项目,而不要选择现有的项目。完成本教程介绍的步骤后,您可以删除所创建的项目,并移除与该项目关联的所有资源。
  2. 转到项目选择器页面,确保您的 Cloud 项目已启用结算功能。了解如何确认您的项目是否已启用结算功能
  1. 登录您的 Google Cloud 账号。如果您是 Google Cloud 新手,请创建一个账号来评估我们的产品在实际场景中的表现。新客户还可获享 $300 赠金,用于运行、测试和部署工作负载。
  2. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  3. 确保您的 Google Cloud 项目已启用结算功能

  4. 在 Google Cloud Console 中的项目选择器页面上,选择或创建一个 Google Cloud 项目

    转到“项目选择器”

  5. 确保您的 Google Cloud 项目已启用结算功能

  6. 本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。

设置 Compute Engine 实例

  1. 打开一个 Cloud Shell 窗口。

    打开 Cloud Shell

  2. 为项目 ID 创建一个变量。

    export PROJECT_ID=project-id
    
  3. 配置 Google Cloud CLI,以使用要在其中创建 Cloud TPU 的项目。

    gcloud config set project ${PROJECT_ID}
    

    当您第一次在新的 Cloud Shell 虚拟机中运行此命令时,系统会显示 Authorize Cloud Shell 页面。点击页面底部的 Authorize,以允许 gcloud 使用您的凭据进行 API 调用。

  4. 从 Cloud Shell 启动本教程所需的 Compute Engine 资源。

    gcloud compute instances create wav2vec2-tutorial \
      --zone=us-central1-a \
      --machine-type=n1-standard-64 \
      --image-family=torch-xla \
      --image-project=ml-images  \
      --boot-disk-size=200GB \
      --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. 连接到新的 Compute Engine 实例。

    gcloud compute ssh wav2vec2-tutorial --zone=us-central1-a
    

启动 Cloud TPU 资源

  1. 在 Compute Engine 虚拟机中,设置 PyTorch 版本。

    (vm) $ export PYTORCH_VERSION=2.0
    
  2. 使用以下命令启动 Cloud TPU 资源。

    (vm) $ gcloud compute tpus create wav2vec2-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-2.0 \
    --accelerator-type=v3-8
    
  3. 通过运行以下命令确定 Cloud TPU 资源的 IP 地址:

    (vm) $ gcloud compute tpus describe wav2vec2-tutorial --zone=us-central1-a
    

    此命令将显示有关 TPU 的信息。查找标有 ipAddress 的行。将以下命令中的 ip-address 替换为该 IP 地址:

    (vm) $ export TPU_IP_ADDRESS=ip-address
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    

创建并配置 PyTorch 环境

  1. 启动 conda 环境。

    (vm) $ conda activate torch-xla-2.0
    
  2. 为 Cloud TPU 资源配置环境变量。

下载并准备数据

访问 ResNetR 网站,查看您可以在此任务中使用的备用数据集。在本教程中,我们将使用 dev-clean.tar.gz,因为它的预处理时间最短。

  1. Wav2Vec2 需要一些依赖项,请立即进行安装。

    (vm) $ pip install omegaconf hydra-core soundfile
    (vm) $ sudo apt-get install libsndfile-dev
    
  2. 下载数据集。

    (vm) $ curl https://www.openslr.org/resources/12/dev-clean.tar.gz --output dev-clean.tar.gz
    
  3. 对压缩文件解压缩。您的文件存储在 LibriSpeech 文件夹中。

    (vm) $ tar xf dev-clean.tar.gz
    
  4. 下载并安装最新的 fairseq 模型。

    (vm) $ git clone --recursive https://github.com/pytorch/fairseq.git
    (vm) $ cd fairseq
    (vm) $ pip install --editable .
    (vm) $ cd -
  5. 准备数据集。 此脚本会创建一个名为 manifest/ 的文件夹,其中包含指向原始数据的指针(在 LibriSpeech/ 下)。

    (vm) $ python fairseq/examples/wav2vec/wav2vec_manifest.py LibriSpeech/ --dest manifest/

运行训练作业

  1. 对 LibriSpeech 数据运行该模型,此脚本大约需要 2 个小时才能运行。

    (vm) $ fairseq-hydra-train \
    task.data=${HOME}/manifest \
    optimization.max_update=400000 \
    dataset.batch_size=4 \
    common.log_format=simple \
    --config-dir fairseq/examples/wav2vec/config/pretraining   \
    --config-name wav2vec2_large_librivox_tpu.yaml
    

清理

为避免因本教程中使用的资源导致您的 Google Cloud 帐号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm)$ exit
    

    您的提示符现在应为 user@projectname,表明您位于 Cloud Shell 中。

  2. 在 Cloud Shell 中,使用 Google Cloud CLI 删除 Compute Engine 虚拟机实例和 TPU:

    $ gcloud compute tpus delete wav2vec2-tutorial --zone=us-central1-a
    

后续步骤

扩展到 Cloud TPU Pod

如需将本教程中的预训练任务扩展到强大的 Cloud TPU Pod,请参阅在 Cloud TPU Pod 上训练 PyTorch 模型教程。

试用 PyTorch Colab: