Training FairSeq Transformer on Cloud TPU using PyTorch

This tutorial focuses on the FairSeq version of the Transformer model and the WMT 18 translation task, translating English to German.

Objectives

  • Prepare the dataset.
  • Run the training job.
  • Verify the output results.

Costs

This tutorial uses billable components of Google Cloud, including:

  • Compute Engine
  • Cloud TPU

Use the pricing calculator to generate a cost estimate based on your projected usage. New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Account.

    If you don't already have one, sign up for a new account.

  2. In the Cloud Console, on the project selector page, select or create a Cloud project.

    Go to the project selector page

  3. Make sure that billing is enabled for your Google Cloud project. Learn how to confirm billing is enabled for your project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.

Set up a Compute Engine instance

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a variable for your project's ID, replacing project-id with the ID of your Google Cloud project.

    export PROJECT_ID=project-id
    
  3. Configure the gcloud command-line tool to use the project where you want to create your Cloud TPU.

    gcloud config set project ${PROJECT_ID}
    
  4. From the Cloud Shell, launch the Compute Engine resource required for this tutorial.

    gcloud compute --project=${PROJECT_ID} instances create transformer-tutorial \
    --zone=us-central1-a  \
    --machine-type=n1-standard-16  \
    --image-family=torch-xla \
    --image-project=ml-images  \
    --boot-disk-size=200GB \
    --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. Connect to the new Compute Engine instance.

    gcloud compute ssh transformer-tutorial --zone=us-central1-a
    

Launch a Cloud TPU resource

  1. From the Compute Engine virtual machine, launch a Cloud TPU resource using the following command:

    (vm) $ gcloud compute tpus create transformer-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.6 \
    --accelerator-type=v3-8
    
  2. Identify the IP address for the Cloud TPU resource.

    (vm) $ gcloud compute tpus list --zone=us-central1-a
    

    The IP address is located under the NETWORK_ENDPOINTS column. You will need this IP address when you create and configure the PyTorch environment.

Download the data

  1. Create a directory, pytorch-tutorial-data, to store the model data.

    (vm) $ mkdir $HOME/pytorch-tutorial-data
    
  2. Navigate to the pytorch-tutorial-data directory.

    (vm) $ cd $HOME/pytorch-tutorial-data
    
  3. Download the model data.

    (vm) $ wget https://dl.fbaipublicfiles.com/fairseq/data/wmt18_en_de_bpej32k.zip
    
  4. Extract the data.

    (vm) $ sudo apt-get install -y unzip && \
    unzip wmt18_en_de_bpej32k.zip
    
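Before moving on, you can sanity-check the extraction by listing the new directory (a quick sketch; the exact file names inside the archive are not listed in this tutorial, so this only confirms the directory exists and is non-empty):

```shell
# List the extracted dataset directory; it should contain the
# binarized WMT'18 en-de data rather than being empty.
(vm) $ ls "$HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k" | head
```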

Create and configure the PyTorch environment

  1. Start a conda environment.

    (vm) $ conda activate torch-xla-1.6
    
  2. Configure environment variables for the Cloud TPU resource, replacing ip-address with the IP address you identified in the previous section.

    (vm) $ export TPU_IP_ADDRESS=ip-address; \
    export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    
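Instead of copying the IP address by hand, the two exports above can be combined into a single step that queries the TPU directly. This is a sketch, assuming the `--format` projection below matches the `networkEndpoints` field exposed by your version of the gcloud tool:

```shell
# Query the TPU's IP address and build the XRT config in one step.
# Assumes the TPU is named transformer-tutorial in us-central1-a.
(vm) $ export TPU_IP_ADDRESS=$(gcloud compute tpus describe transformer-tutorial \
  --zone=us-central1-a \
  --format="value(networkEndpoints[0].ipAddress)")
(vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
```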

Train the model

To train the model, run the following script:

(vm) $ python /usr/share/torch-xla-1.6/tpu-examples/deps/fairseq/train.py \
  $HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --input_shapes 128x64 \
  --num_cores=8 \
  --metrics_debug \
  --log_steps=100
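The --input_shapes 128x64 flag pads every per-core batch to a fixed shape of 128 sentences of at most 64 tokens (matching --max-target-positions=64), which lets XLA compile a static graph. As a back-of-the-envelope check of what that implies per training step (a sketch; the token count is an upper bound that assumes fully packed batches):

```python
# Effective batch size implied by --input_shapes 128x64 with --num_cores=8.
sentences_per_core = 128   # first number in 128x64
max_len = 64               # second number: padded sentence length
num_cores = 8

global_batch = sentences_per_core * num_cores   # sentences per training step
token_budget = global_batch * max_len           # upper bound on tokens per step

print(global_batch)   # 1024
print(token_budget)   # 65536
```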

Verify output results

After the training job completes, you can find your model checkpoints in the following directory:

$HOME/checkpoints
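FairSeq typically writes numbered checkpoints (for example checkpoint1.pt through checkpoint25.pt with --max-epoch=25) plus best/last checkpoints; the exact names are an assumption here, not stated by this tutorial. A small sketch for picking the most recent one, demonstrated on a throwaway directory:

```shell
# Demo on a temporary directory; after training, point CKPT_DIR at
# $HOME/checkpoints instead.
CKPT_DIR=$(mktemp -d)
touch "${CKPT_DIR}/checkpoint1.pt"
sleep 1
touch "${CKPT_DIR}/checkpoint2.pt"

# ls -t sorts newest-first, so head picks the latest checkpoint.
latest=$(ls -t "${CKPT_DIR}"/checkpoint*.pt | head -n 1)
echo "latest=${latest}"
```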

Cleaning up

After you have finished with the resources you created, clean them up to avoid incurring unnecessary charges to your account:

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm) $ exit
    

    Your prompt should now be user@projectname, showing you are in the Cloud Shell.

  2. In your Cloud Shell, use the gcloud command-line tool to delete the Compute Engine instance.

    $ gcloud compute instances delete transformer-tutorial --zone=us-central1-a
    
  3. Use the gcloud command-line tool to delete the Cloud TPU resource.

    $ gcloud compute tpus delete transformer-tutorial --zone=us-central1-a
    

What's next

Try the PyTorch colabs: