Pre-training FairSeq RoBERTa on Cloud TPU using PyTorch

This tutorial shows you how to pre-train FairSeq's RoBERTa on a Cloud TPU. Specifically, it follows FairSeq's tutorial, pre-training the model on the public wikitext-103 dataset.

Objectives

  • Create and configure the PyTorch environment
  • Prepare the dataset
  • Run the training job
  • Verify that you can view the output results

Costs

This tutorial uses the following billable components of Google Cloud:

  • Compute Engine
  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator. New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Cloud project. Learn how to check if billing is enabled on a project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.

Set up a Compute Engine instance

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a variable for your project's ID.

    export PROJECT_ID=project-id
    
  3. Configure the Google Cloud CLI to use the project where you want to create the Cloud TPU.

    gcloud config set project ${PROJECT_ID}
    

    The first time you run this command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make Google Cloud API calls with your credentials.

  4. From the Cloud Shell, launch the Compute Engine resource required for this tutorial.

    gcloud compute instances create roberta-tutorial \
    --zone=us-central1-a \
    --machine-type=n1-standard-16 \
    --image-family=torch-xla \
    --image-project=ml-images \
    --boot-disk-size=200GB \
    --scopes=https://www.googleapis.com/auth/cloud-platform
    
  5. Connect to the new Compute Engine instance.

    gcloud compute ssh roberta-tutorial --zone=us-central1-a
    

Launch a Cloud TPU resource

  1. From the Compute Engine virtual machine, launch a Cloud TPU resource using the following command:

    (vm) $ gcloud compute tpus create roberta-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-2.0 \
    --accelerator-type=v3-8
    
  2. Identify the IP address for the Cloud TPU resource.

    (vm) $ gcloud compute tpus describe --zone=us-central1-a roberta-tutorial
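
Instead of reading the address out of the full describe output by hand, the lookup can be scripted. The following is a minimal sketch, not part of the original tutorial: `get_tpu_ip` is a hypothetical helper name, and the `networkEndpoints[0].ipAddress` field path is an assumption about the shape of the describe output for a TPU node, so confirm it against what the command prints in your project.

```shell
# Hypothetical helper: print the IP address of a TPU node's first network endpoint.
# Assumption: the describe output exposes networkEndpoints[].ipAddress.
# $1 = TPU name, $2 = zone
get_tpu_ip() {
  gcloud compute tpus describe "$1" \
    --zone="$2" \
    --format='value(networkEndpoints[0].ipAddress)'
}

# Usage on the VM (uncomment to run):
# get_tpu_ip roberta-tutorial us-central1-a
```

The function only wraps the describe command shown above with a `--format` projection, so it needs the same permissions and prints nothing if the field is absent.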
    

Create and configure the PyTorch environment

  1. Start a conda environment.

    (vm) $ conda activate torch-xla-2.0
    
  2. Configure environment variables for the Cloud TPU resource.

    (vm) $ export TPU_IP_ADDRESS=ip-address
    
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
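
A common slip at this step is exporting the literal placeholder `ip-address`. The sketch below shows one way to guard against that; `build_xrt_config` is a hypothetical helper, not part of the tutorial or of PyTorch/XLA, and it only checks that the value looks like a dotted IPv4 address before producing the same string as the export above.

```shell
# Hypothetical helper: build the XRT_TPU_CONFIG value for a single TPU worker.
# Returns non-zero if the argument does not look like a dotted IPv4 address.
build_xrt_config() {
  if ! printf '%s\n' "$1" | grep -Eq '^([0-9]{1,3}\.){3}[0-9]{1,3}$'; then
    echo "invalid TPU IP address: '$1'" >&2
    return 1
  fi
  # Same format as the export in this step: worker name, ordinal, host:port.
  printf 'tpu_worker;0;%s:8470\n' "$1"
}

# Usage on the VM (uncomment to run):
# XRT_TPU_CONFIG="$(build_xrt_config "$TPU_IP_ADDRESS")" && export XRT_TPU_CONFIG
```

Assigning first and exporting only on success keeps a failed check from leaving an empty `XRT_TPU_CONFIG` in the environment unnoticed.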
    

Set up the data

  1. Install FairSeq by running:

    (vm) $ pip install --editable /usr/share/torch-xla-2.0/tpu-examples/deps/fairseq
    
  2. Create a directory, pytorch-tutorial-data, to store the model data.

    (vm) $ mkdir $HOME/pytorch-tutorial-data
    (vm) $ cd $HOME/pytorch-tutorial-data
    
  3. Follow the instructions in the "Preprocess the data" section of FairSeq RoBERTa's README. Preparing the dataset takes approximately 10 minutes.
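
Before starting a training run, it can help to confirm that preprocessing actually produced a binarized dataset. The following is a rough sketch under stated assumptions: `check_data_dir` is a hypothetical helper, and the file names it looks for (`dict.txt`, `train.bin`, `train.idx`, `valid.bin`, `valid.idx`) are an assumption about what the README's preprocessing step writes, so adjust the list if your output differs.

```shell
# Hypothetical helper: check that a binarized dataset directory looks complete.
# Assumption: preprocessing writes dict.txt plus .bin/.idx pairs per split.
# $1 = dataset directory, for example data-bin/wikitext-103
check_data_dir() {
  dir="$1"
  missing=0
  for f in dict.txt train.bin train.idx valid.bin valid.idx; do
    if [ ! -s "$dir/$f" ]; then
      echo "missing or empty: $dir/$f" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Usage on the VM (uncomment to run):
# check_data_dir "$HOME/pytorch-tutorial-data/data-bin/wikitext-103"
```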

Train the model

To train the model, first set up some environment variables:

(vm) $ export TOTAL_UPDATES=125000    # Total number of training steps
(vm) $ export WARMUP_UPDATES=10000    # Warm up the learning rate over this many updates
(vm) $ export PEAK_LR=0.0005          # Peak learning rate, adjust as needed
(vm) $ export TOKENS_PER_SAMPLE=512   # Max sequence length
(vm) $ export UPDATE_FREQ=16          # Increase the batch size 16x
(vm) $ export DATA_DIR=${HOME}/pytorch-tutorial-data/data-bin/wikitext-103

Then, run the following script:

(vm) $ python3 \
      /usr/share/torch-xla-2.0/tpu-examples/deps/fairseq/train.py $DATA_DIR \
      --task=masked_lm --criterion=masked_lm \
      --arch=roberta_base --sample-break-mode=complete \
      --tokens-per-sample=$TOKENS_PER_SAMPLE \
      --optimizer=adam \
      --adam-betas='(0.9,0.98)' \
      --adam-eps=1e-6 \
      --clip-norm=0.0 \
      --lr-scheduler=polynomial_decay \
      --lr=$PEAK_LR \
      --warmup-updates=$WARMUP_UPDATES \
      --dropout=0.1 \
      --attention-dropout=0.1 \
      --weight-decay=0.01 \
      --update-freq=$UPDATE_FREQ \
      --train-subset=train \
      --valid-subset=valid \
      --num_cores=8 \
      --metrics_debug \
      --save-dir=checkpoints \
      --log_steps=30 \
      --log-format=simple \
      --skip-invalid-size-inputs-valid-test \
      --suppress_loss_report \
      --input_shapes 16x512 18x480 21x384 \
      --max-epoch=1

The training script runs for approximately 15 minutes. When it finishes, it prints a message similar to the following:

saved checkpoint /home/user/checkpoints/checkpoint1.pt
(epoch 1 @ 119 updates) (writing took 25.19265842437744 seconds)
| done training in 923.8 seconds
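
As a rough sanity check on the settings in the training command, the arithmetic below estimates how many sequences contribute to each optimizer update. It is a sketch under assumptions: it treats the first number of the 16x512 input shape as the per-core batch size and assumes all 8 cores process one such batch per step. The lowercase variable names are illustrative only and are not read by the training script.

```shell
# Values copied from the training command above.
num_cores=8            # --num_cores
update_freq=16         # --update-freq
batch_per_core=16      # assumption: first dimension of the 16x512 input shape
tokens_per_sample=512  # --tokens-per-sample

# Sequences accumulated across cores and gradient-accumulation steps per update.
seqs_per_update=$((batch_per_core * num_cores * update_freq))
# Upper bound on tokens per update if every sequence is full length.
tokens_per_update=$((seqs_per_update * tokens_per_sample))

echo "sequences per update: $seqs_per_update"       # 2048
echo "max tokens per update: $tokens_per_update"    # 1048576
```

The command does not pass TOTAL_UPDATES to the script, so this run stops after one epoch because of --max-epoch=1, which is 119 updates in the sample output above.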

Verify output results

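After the training job completes, the model checkpoints are in the checkpoints directory set by --save-dir, which is relative to the directory you ran the training script from. In the sample output above, that directory is: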

$HOME/checkpoints
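
One way to confirm that checkpoints were written is to list them, newest first. This is a minimal sketch; `list_checkpoints` is a hypothetical helper, and the `$HOME/checkpoints` default simply mirrors the path shown above.

```shell
# Hypothetical helper: list .pt checkpoint files in a directory, newest first.
# $1 = checkpoint directory (defaults to $HOME/checkpoints)
list_checkpoints() {
  dir="${1:-$HOME/checkpoints}"
  if ! ls "$dir"/*.pt >/dev/null 2>&1; then
    echo "no .pt checkpoints found in $dir" >&2
    return 1
  fi
  ls -t "$dir"/*.pt
}

# Usage on the VM (uncomment to run):
# list_checkpoints
```

If you started training from a different directory, pass that directory's checkpoints path as the argument.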

Clean up

To avoid incurring unnecessary charges to your account, delete the resources you created:

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm) $ exit
    

    Your prompt should now be user@projectname, showing you are in the Cloud Shell.

  2. In your Cloud Shell, use the Google Cloud CLI to delete the Compute Engine instance.

    $ gcloud compute instances delete roberta-tutorial --zone=us-central1-a
    
  3. Use Google Cloud CLI to delete the Cloud TPU resource.

    $ gcloud compute tpus delete roberta-tutorial --zone=us-central1-a
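
To confirm that both deletions took effect, you can check that the resource name no longer appears in the list commands. The sketch below is an illustration, not part of the original tutorial: `assert_deleted` is a hypothetical helper that reads resource names on standard input, and the `--format='value(name)'` projections in the usage lines are assumptions about the list output.

```shell
# Hypothetical helper: succeed only if no resource named "$1" is listed on stdin.
# Accepts short names or full resource paths, one per line.
assert_deleted() {
  # Strip any leading path so "projects/.../nodes/NAME" compares as "NAME".
  if sed 's#.*/##' | grep -Fxq "$1"; then
    echo "resource still exists: $1" >&2
    return 1
  fi
}

# Usage from Cloud Shell (uncomment to run):
# gcloud compute instances list --format='value(name)' | assert_deleted roberta-tutorial
# gcloud compute tpus list --zone=us-central1-a --format='value(name)' | assert_deleted roberta-tutorial
```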
    

What's next

Try the PyTorch colabs: