Pretraining Wav2Vec2 on Cloud TPU with PyTorch

This tutorial shows you how to pretrain FairSeq's Wav2Vec2 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimized speech models that use PyTorch and the LibriSpeech dataset.

The model in this tutorial is based on the wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations paper.


Objectives

  • Create and configure the PyTorch environment.
  • Download Open Source LibriSpeech data.
  • Run the training job.

Costs

This tutorial uses billable components of Google Cloud, including:

  • Compute Engine
  • Cloud TPU

Use the pricing calculator to generate a cost estimate based on your projected usage. New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud Console, on the project selector page, select or create a Google Cloud project. Note: If you don't plan to keep the resources that you create in this procedure, create a new project instead of selecting an existing one. After you finish these steps, you can delete the project, removing all resources associated with it.

    Go to project selector

  3. Make sure that billing is enabled for your Cloud project. Learn how to confirm that billing is enabled for your project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.

Set up a Compute Engine instance

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a variable for your project's ID.

    export PROJECT_ID=project-id
  3. Configure the gcloud command-line tool to use the project where you want to create the Cloud TPU.

    gcloud config set project ${PROJECT_ID}

    The first time you run this command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make GCP API calls with your credentials.

  4. From the Cloud Shell, launch the Compute Engine resource required for this tutorial.

    gcloud compute instances create wav2vec2-tutorial \
      --zone=us-central1-a \
      --machine-type=n1-standard-64 \
      --image-family=torch-xla \
      --image-project=ml-images \
      --boot-disk-size=200GB
  5. Connect to the new Compute Engine instance.

    gcloud compute ssh wav2vec2-tutorial --zone=us-central1-a

Launch a Cloud TPU resource

  1. In the Compute Engine virtual machine, set the PyTorch version.

    (vm) $ export PYTORCH_VERSION=1.9
  2. Launch a Cloud TPU resource using the following command:

    (vm) $ gcloud compute tpus create w2v2-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.9
  3. Identify the IP address for the Cloud TPU resource.

    (vm) $ gcloud compute tpus list --zone=us-central1-a

Create and configure the PyTorch environment

  1. Start a conda environment.

    (vm) $ conda activate torch-xla-1.9
  2. Configure environment variables for the Cloud TPU resource.

    (vm) $ export TPU_IP_ADDRESS=ip-address
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
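
The XRT_TPU_CONFIG value above is a semicolon-separated triple: the worker name (tpu_worker), the worker index (0), and the TPU endpoint (the node's IP address plus the XRT gRPC port, 8470). A minimal sketch of how the string is composed, using a placeholder IP address:

```shell
# Placeholder IP address; substitute the address reported by
# `gcloud compute tpus list` for your TPU node.
TPU_IP_ADDRESS=10.240.1.2
XRT_TPU_CONFIG="tpu_worker;0;${TPU_IP_ADDRESS}:8470"
echo "$XRT_TPU_CONFIG"
```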

Download and prepare data

Visit the OpenSLR website to see the alternative datasets you can use for this task. This tutorial uses dev-clean.tar.gz because it has the shortest preprocessing time.

  1. Wav2Vec2 requires some dependencies; install them now.

    (vm) $ pip install omegaconf hydra-core soundfile
    (vm) $ sudo apt-get install libsndfile-dev
  2. Download the dataset.

    (vm) $ curl https://www.openslr.org/resources/12/dev-clean.tar.gz --output dev-clean.tar.gz
  3. Extract the compressed files. Your files are stored in the LibriSpeech folder.

    (vm) $ tar xf dev-clean.tar.gz
  4. Download and install the latest fairseq toolkit.

    (vm) $ git clone --recursive https://github.com/pytorch/fairseq.git
    (vm) $ cd fairseq
    (vm) $ pip install --editable .
    (vm) $ cd -
  5. Prepare the dataset. This script creates a folder called manifest/ containing pointers to the raw data (under LibriSpeech/).

    (vm) $ python fairseq/examples/wav2vec/wav2vec_manifest.py LibriSpeech/ --dest manifest/
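
As a rough sketch of what the manifest script produces (based on fairseq's wav2vec manifest format; the file names and sample counts below are hypothetical): the first line of the manifest file is the root directory of the audio data, and each subsequent line holds a file path relative to that root and the file's sample count, separated by a tab.

```python
def make_manifest(root, entries):
    """Build a wav2vec-style manifest: the root directory on the first
    line, then one 'relative/path<TAB>num_samples' line per audio file."""
    lines = [root]
    for path, num_samples in entries:
        lines.append(f"{path}\t{num_samples}")
    return "\n".join(lines)

# Hypothetical dev-clean files; the real entries are written by the script.
manifest = make_manifest("LibriSpeech", [
    ("dev-clean/84/121123/84-121123-0000.flac", 23120),
    ("dev-clean/84/121123/84-121123-0001.flac", 41360),
])
print(manifest)
```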

Run the training job

  1. Run the model on the LibriSpeech data. This script takes about two hours to run.

    (vm) $ OMP_NUM_THREADS=1 python fairseq/train.py \
     manifest/ \
     --num-batch-buckets 3 \
     --tpu \
     --max-sentences 4 \
     --max-sentences-valid 4 \
     --required-batch-size-multiple 4 \
     --distributed-world-size 8 \
     --distributed-port 12597 \
     --update-freq 1 \
     --enable-padding \
     --log-interval 5 \
     --num-workers 6 \
     --task audio_pretraining \
     --criterion wav2vec \
     --arch wav2vec2 \
     --log-keys  "['prob_perplexity','code_perplexity','temp']" \
     --quantize-targets \
     --extractor-mode default \
     --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' \
     --final-dim 256 \
     --latent-vars 320 \
     --latent-groups 2 \
     --latent-temp '(2,0.5,0.999995)' \
     --infonce \
     --optimizer adam \
     --adam-betas '(0.9,0.98)' \
     --adam-eps 1e-06 \
     --lr-scheduler polynomial_decay \
     --total-num-update 400000 \
     --lr 0.0005 \
     --warmup-updates 32000 \
     --mask-length 10 \
     --mask-prob 0.65 \
     --mask-selection static \
     --mask-other 0 \
     --mask-channel-prob 0.1 \
     --encoder-layerdrop 0 \
     --dropout-input 0.0 \
     --dropout-features 0.0 \
     --feature-grad-mult 0.1 \
     --loss-weights '[0.1, 10]' \
     --conv-pos 128 \
     --conv-pos-groups 16 \
     --num-negatives 100 \
     --cross-sample-negatives 0 \
     --max-sample-size 250000 \
     --min-sample-size 32000 \
     --dropout 0.0 \
     --attention-dropout 0.0 \
     --weight-decay 0.01 \
     --max-tokens 1400000 \
     --max-epoch 10 \
     --save-interval 2 \
     --skip-invalid-size-inputs-valid-test \
     --ddp-backend no_c10d \
     --log-format simple
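
As a back-of-the-envelope check on the flags above, the effective batch size per optimizer step is the per-core batch (--max-sentences) multiplied by the number of TPU cores (--distributed-world-size) and the gradient-accumulation factor (--update-freq):

```python
# Values taken from the command-line flags above.
max_sentences = 4   # --max-sentences: utterances per core per step
world_size = 8      # --distributed-world-size: TPU cores
update_freq = 1     # --update-freq: gradient accumulation steps

effective_batch = max_sentences * world_size * update_freq
print(effective_batch)  # 32 utterances per optimizer update
```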

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm) $ exit

    Your prompt should now be user@projectname, showing you are in the Cloud Shell.

  2. In your Cloud Shell, use the gcloud command-line tool to delete the Compute Engine VM instance and the Cloud TPU:

    $ gcloud compute instances delete wav2vec2-tutorial --zone=us-central1-a
    $ gcloud compute tpus delete w2v2-tutorial --zone=us-central1-a

What's next

Scaling to Cloud TPU Pods

To scale the pretraining task in this tutorial to Cloud TPU Pods, see the Training PyTorch models on Cloud TPU Pods tutorial.

Try the PyTorch colabs: