Training ResNet-50 on Cloud TPU with PyTorch


This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimised image classification models that use PyTorch and the ImageNet dataset.

The model in this tutorial is based on Deep Residual Learning for Image Recognition, the paper that first introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and demonstrates training the model using PyTorch/XLA.


Objectives

  • Prepare the dataset.
  • Run the training job.
  • Verify the output results.


This tutorial uses the following billable components of Google Cloud:

  • Compute Engine
  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator. New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Cloud project. Learn how to check if billing is enabled on a project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.

Set up a Compute Engine instance

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a variable for your project's ID.

    export PROJECT_ID=project-id

  3. Configure the Google Cloud CLI to use the project where you want to create Cloud TPU.

    gcloud config set project ${PROJECT_ID}

    The first time you run this command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make Google Cloud API calls with your credentials.

  4. Launch the Compute Engine resource required for this tutorial.

    gcloud compute instances create resnet50-tutorial \
    --zone=us-central1-a \
    --machine-type=n1-highmem-96 \
    --image-family=torch-xla \
    --image-project=ml-images \
    --boot-disk-size=300GB

    If you plan to train ResNet-50 on real data, choose the machine type with the most CPUs available to you. ResNet-50 training is typically input-bound, so it can be slow unless there are many workers to feed in data and enough RAM to sustain a large number of worker threads. For best results, select the n1-highmem-96 machine type.

    If you plan to download ImageNet, specify a disk size of at least 300GB. If you plan to only use fake data, you can specify the default disk size of 20GB. This tutorial suggests using both data sets.
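Because training is input-bound, the number of data-loading workers should scale with the CPUs on the VM. The following is a minimal sketch of one rule of thumb, not something prescribed by this tutorial: it assumes a hypothetical `num_tpu_cores=8` training processes sharing the machine and reserves one CPU for the operating system. Treat the result as a starting point for the --num_workers flag and tune from there.

```python
import os


def suggest_num_workers(cpu_count=None, num_tpu_cores=8, reserve=1):
    """Suggest a per-process data-loader worker count.

    Assumptions (not from the tutorial): workers are split evenly across
    `num_tpu_cores` training processes, and `reserve` CPUs are left free.
    """
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    usable = max(cpu_count - reserve, 1)
    # Always return at least one worker, even on very small machines.
    return max(usable // num_tpu_cores, 1)


if __name__ == "__main__":
    # A 96-vCPU machine such as n1-highmem-96.
    print(suggest_num_workers(cpu_count=96))
```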

Launch a Cloud TPU resource

  1. From the Compute Engine virtual machine, launch a Cloud TPU resource using the following command:

    (vm) $ gcloud compute tpus create resnet50-tutorial \
    --zone=us-central1-a \
    --network=default \
    --version=pytorch-1.13

  2. Identify the IP address for the Cloud TPU resource.

    (vm) $ gcloud compute tpus list --zone=us-central1-a

Create and configure the PyTorch environment

  1. Connect to the new Compute Engine instance.

    gcloud compute ssh resnet50-tutorial --zone=us-central1-a

From this point on, a prefix of (vm) $ means you should run the command on the Compute Engine VM instance.

  1. Start a conda environment.

    (vm) $ conda activate torch-xla-1.13

  2. Configure environmental variables for the Cloud TPU resource.

    (vm) $ export TPU_IP_ADDRESS=ip-address
    (vm) $ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
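The XRT_TPU_CONFIG value has the form "tpu_worker;0;<ip>:8470", so a mistyped or unsubstituted IP address only surfaces later, when training fails to connect. As a minimal sketch, the hypothetical helper below (`build_xrt_tpu_config` is an illustrative name, not part of PyTorch/XLA) validates the address with the standard library before building the string. The address used in the example is a placeholder.

```python
import ipaddress


def build_xrt_tpu_config(tpu_ip, port=8470, worker_index=0):
    """Return a string in the XRT_TPU_CONFIG format shown in this tutorial."""
    # Raises ValueError if tpu_ip is not a valid IPv4 or IPv6 address,
    # for example if the "ip-address" placeholder was left unchanged.
    ipaddress.ip_address(tpu_ip)
    return f"tpu_worker;{worker_index};{tpu_ip}:{port}"


if __name__ == "__main__":
    # 10.0.0.2 is a placeholder, not a real TPU endpoint.
    print(build_xrt_tpu_config("10.0.0.2"))
```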

Training with the fake_data set

We recommend using the fake data set for your initial run instead of the real ImageNet set, because fake_data is automatically installed on your VM and requires less time and fewer resources to process.

(vm) $ python /usr/share/torch-xla-1.13/pytorch/xla/test/ --fake_data --model=resnet50 --num_epochs=2 --batch_size=128 --log_steps=20

Training with the real data set

If everything looks OK using the --fake_data flag, you can try training on real data, such as ImageNet.

In general, the training script uses torchvision.datasets.ImageFolder, so you can use any dataset that is structured properly. See the ImageFolder documentation.
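ImageFolder expects one subdirectory per class, with that class's images inside it. Before starting a long run, it can help to confirm that a directory follows this layout. The sketch below uses only the standard library; `summarize_image_folder` is a hypothetical helper, and `IMAGE_EXTENSIONS` is deliberately only a subset of the extensions torchvision accepts.

```python
import tempfile
from pathlib import Path

# Assumption: a subset of the image extensions ImageFolder accepts by default.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def summarize_image_folder(root):
    """Map each class directory directly under `root` to its image count."""
    root = Path(root).expanduser()
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


if __name__ == "__main__":
    # Build a tiny throwaway tree so the sketch runs without a real dataset.
    with tempfile.TemporaryDirectory() as tmp:
        for class_name in ("cat", "dog"):
            class_dir = Path(tmp) / class_name
            class_dir.mkdir()
            (class_dir / "0001.jpeg").touch()
        print(summarize_image_folder(tmp))
```

A class that reports zero images, or an empty result, suggests the directory is not laid out the way ImageFolder expects.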

Some suggested command line modifications for using real data, assuming you stored the dataset at ~/imagenet:

(vm) $ python /usr/share/torch-xla-1.13/pytorch/xla/test/ --datadir=$HOME/imagenet --model=resnet50 --num_epochs=90 --num_workers=8 --batch_size=128 --log_steps=200
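The fake-data and real-data commands differ only in a few flags, so if you launch runs from a script it may be convenient to assemble them in one place. The following is a sketch under stated assumptions: `build_train_command` is a hypothetical helper, the script path is a required parameter because it is not spelled out in the commands here, and `train_script.py` in the example is a placeholder. Only flags that appear in this tutorial are used.

```python
import os
import shlex


def build_train_command(script, datadir=None, num_epochs=2, batch_size=128,
                        log_steps=20, num_workers=None, model="resnet50"):
    """Assemble the training command as an argument list for subprocess.run."""
    cmd = ["python", script]
    if datadir:
        # Expand "~" here, because the shell does not expand it after "=".
        cmd.append(f"--datadir={os.path.expanduser(datadir)}")
    else:
        cmd.append("--fake_data")
    cmd += [f"--model={model}", f"--num_epochs={num_epochs}"]
    if num_workers is not None:
        cmd.append(f"--num_workers={num_workers}")
    cmd += [f"--batch_size={batch_size}", f"--log_steps={log_steps}"]
    return cmd


if __name__ == "__main__":
    # "train_script.py" is a placeholder for the real training script path.
    fake = build_train_command("train_script.py")
    real = build_train_command("train_script.py", datadir="/data/imagenet",
                               num_epochs=90, num_workers=8, log_steps=200)
    print(shlex.join(fake))
    print(shlex.join(real))
```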

Requesting additional CPU quota

Plan and request additional resources a few days in advance to ensure that there is enough time to fulfill your request.

  1. Go to the Quotas page.

    Go to the Quotas page

  2. From the Service menu, select Cloud TPU API.
  3. Select the region or zones where you want to use the CPUs.
  4. From the Metric menu, select None and then enter CPUs in the search box.
  5. Select CPUs.
  6. In the list, select Compute Engine API - CPUs, then click Edit Quotas at the top of the page.
  7. Enter the amount of quota you are requesting and a description (required), then click Done. A request is sent to your service provider for approval.

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Disconnect from the Compute Engine instance, if you have not already done so:

    (vm) $ exit

    Your prompt should now be user@projectname, showing you are in the Cloud Shell.

  2. In your Cloud Shell, run ctpu delete with the --zone and --name flags you used when you set up the Compute Engine VM and Cloud TPU. This deletes both your VM and your Cloud TPU.

    $ ctpu delete --project=${PROJECT_ID} \
      --name=resnet50-tutorial

  3. Run ctpu status to confirm that no instances remain allocated, which avoids unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:

    $ ctpu status --project=${PROJECT_ID}
    2018/04/28 16:16:23 WARNING: Setting zone to "us-central1-a"
    No instances currently exist.
            Compute Engine VM:     --
            Cloud TPU:             --
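If cleanup is part of an automated workflow, the status output above can be checked programmatically. The sketch below parses text in the format shown and reports which resources are still listed. It assumes that an allocated resource shows some value other than "--" after its label, which this tutorial does not show; `allocated_resources` is a hypothetical helper.

```python
def allocated_resources(status_output):
    """Return the labels that `ctpu status` output still lists as allocated."""
    remaining = []
    for line in status_output.splitlines():
        label, sep, value = line.strip().partition(":")
        if sep and label in ("Compute Engine VM", "Cloud TPU"):
            # Assumption: "--" means nothing is allocated for this resource.
            if value.strip() != "--":
                remaining.append(label)
    return remaining


if __name__ == "__main__":
    sample = """\
2018/04/28 16:16:23 WARNING: Setting zone to "us-central1-a"
No instances currently exist.
        Compute Engine VM:     --
        Cloud TPU:             --
"""
    print(allocated_resources(sample))
```

An empty list means both resources are gone; any label in the list means the deletion has not finished or did not succeed.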

What's next

Try the PyTorch colabs: