使用 PyTorch 在 Cloud TPU 上训练 FairSeq Transformer

本教程重点介绍 FairSeq 版本的 Transformer,以及将英语翻译为德语的 WMT 18 翻译任务。

目标

  • 准备数据集。
  • 运行训练作业。
  • 验证输出结果。

费用

本教程使用 Google Cloud 的以下收费组件:

  • Compute Engine
  • Cloud TPU

请使用价格计算器根据您的预计使用情况来估算费用。 Google Cloud 新用户可能有资格申请免费试用

准备工作

在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。

  1. 登录您的 Google 帐号。

    如果您还没有 Google 帐号,请注册一个新帐号

  2. 在 Cloud Console 的项目选择器页面上,选择或创建 Cloud 项目。

    转到项目选择器页面

  3. 确保您的 Google Cloud 项目已启用结算功能。 了解如何确认您的项目已启用结算功能

  4. 本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在使用完您创建的资源以后清理这些资源,以免产生不必要的费用。

设置 Compute Engine 实例

  1. 打开一个 Cloud Shell 窗口。

    打开 Cloud Shell

  2. 为项目 ID 创建一个变量。

    export PROJECT_ID=project-id
    
  3. 配置 gcloud 命令行工具,以使用要在其中创建 Cloud TPU 的项目。

    gcloud config set project ${PROJECT_ID}
    
  4. 从 Cloud Shell 启动本教程所需的 Compute Engine 资源。

    gcloud compute --project=${PROJECT_ID} instances create transformer-tutorial \
    --zone=us-central1-a  \
    --machine-type=n1-standard-16  \
    --image-family=torch-xla \
    --image-project=ml-images  \
    --boot-disk-size=200GB \
    --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. 连接到新的 Compute Engine 实例。

    gcloud compute ssh transformer-tutorial --zone=us-central1-a
    

启动 Cloud TPU 资源

  1. 在 Compute Engine 虚拟机中,使用以下命令启动 Cloud TPU 资源:

    (vm) $ gcloud compute tpus create transformer-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.6 \
    --accelerator-type=v3-8
    
  2. 确定 Cloud TPU 资源的 IP 地址。

    (vm) $ gcloud compute tpus list --zone=us-central1-a
    

    该 IP 地址位于 NETWORK_ENDPOINTS 列下方。在创建和配置 PyTorch 环境时,您将需要该 IP 地址。

下载数据

  1. 创建一个用于存储模型数据的目录 pytorch-tutorial-data

    (vm) $ mkdir $HOME/pytorch-tutorial-data
    
  2. 导航到 pytorch-tutorial-data 目录。

    (vm) $ cd $HOME/pytorch-tutorial-data
    
  3. 下载模型数据。

    (vm) $ wget https://dl.fbaipublicfiles.com/fairseq/data/wmt18_en_de_bpej32k.zip
    
  4. 提取数据。

    (vm) $ sudo apt-get install unzip && \
    unzip wmt18_en_de_bpej32k.zip
    

创建并配置 PyTorch 环境

  1. 启动 conda 环境。

    (vm) $ conda activate torch-xla-1.6
    
  2. 为 Cloud TPU 资源配置环境变量。

    (vm) $ export TPU_IP_ADDRESS=ip-address; \
    export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    

训练模型

要训练模型,请运行以下脚本:

(vm) $ python /usr/share/torch-xla-1.6/tpu-examples/deps/fairseq/train.py \
  $HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --input_shapes 128x64 \
  --num_cores=8 \
  --metrics_debug \
  --log_steps=100

验证输出结果

训练作业完成后,可以在以下目录中查找模型检查点:

$HOME/checkpoints

清理

使用您创建的资源后,请进行清理,以免您的帐号产生不必要的费用:

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):

    (vm) $ exit
    

    您的提示符现在应为 user@projectname,表明您位于 Cloud Shell 中。

  2. 在 Cloud Shell 中,使用 gcloud 命令行工具删除 Compute Engine 实例。

    $  gcloud compute instances delete transformer-tutorial  --zone=us-central1-a
    
  3. 使用 gcloud 命令行工具删除 Cloud TPU 资源。

    $  gcloud compute tpus delete transformer-tutorial --zone=us-central1-a
    

后续步骤

试用 PyTorch Colab: