
Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM


Many deep learning advancements can be attributed to increases in (1) data size and (2) computational power. Training with larger datasets can be extremely beneficial. Larger datasets help stabilize model performance during training, and research shows that for moderate to large-scale models and datasets, generalization error decreases as a power law of training data size, which means we can predict how model accuracy will improve as the dataset grows.
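As a rough illustration of what "power law" means here (the symbols are introduced for this sketch and are not taken from the cited research): if generalization error ε depends on training-set size m as below, with α and β fitted empirically for a given model family and task, then every doubling of the dataset shrinks the error by the same constant factor, which is what makes the improvement predictable.

```latex
\varepsilon(m) \approx \alpha\, m^{\beta}, \qquad \beta < 0
\qquad\Longrightarrow\qquad
\frac{\varepsilon(2m)}{\varepsilon(m)} \approx \frac{\alpha\,(2m)^{\beta}}{\alpha\, m^{\beta}} = 2^{\beta} < 1
```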


Figure 1: Learning curve and dataset size for word language models (source)

In practice, this means that as we look to improve model performance with larger datasets, (1) we need access to hardware accelerators, such as GPUs or TPUs, and (2) we need to architect a system that efficiently stores and delivers this data to the accelerators. There are a few reasons why we may choose to stream data from remote storage to our accelerator devices:

  • Data size: data can be too large to fit on a single machine, requiring remote storage and efficient network access
  • Streamlined workflows: transferring data to disk can be time consuming and resource intensive, so we want to make fewer copies of the data
  • Collaboration: disaggregating data from accelerator devices means we can more efficiently share accelerator nodes across workloads and teams

Streaming training data from remote storage to accelerators can alleviate these issues, but it introduces a host of new challenges:

  • Network overhead: many datasets consist of millions of individual files, and randomly accessing these files can introduce network bottlenecks. We need sequential access patterns.
  • Throughput: modern accelerators are fast; the challenge is feeding them fast enough to keep them fully utilized. We need parallel I/O and pipelined access to data.
  • Randomness vs. sequential access: the optimization algorithms in deep learning jobs benefit from randomness, but random file access introduces network bottlenecks. Sequential access alleviates network bottlenecks, but can reduce the randomness needed for training optimization. We need to balance the two.

How do we architect a system that addresses these challenges at scale?


Figure 2: Scaling to larger datasets, more devices 

In this post, we will cover:

  • The challenges associated with scaling deep learning jobs to distributed training settings
  • Using the new Cloud TPU VM interface
  • How to stream training data from Google Cloud Storage (GCS) to PyTorch / XLA models running on Cloud TPU Pod slices

You can find accompanying code for this article in this GitHub repository.

Model and dataset

In this article, we will train a PyTorch / XLA ResNet-50 model on a v3-32 TPU Pod slice where training data is stored in GCS and streamed to the TPU VMs at training time. ResNet-50 is a 50-layer convolutional neural network commonly used for computer vision tasks and machine learning performance benchmarking. To demonstrate an end-to-end example, we will use the CIFAR-10 dataset. The original dataset consists of 60,000 32x32 color images divided into 10 classes, each class containing 6,000 images. We have upsampled this dataset, creating a training and test set of 1,280,000 and 50,000 images, respectively. CIFAR is used because it is publicly accessible and well known; however, in the GitHub repository, we provide guidance for adapting this solution to your workloads, as well as larger datasets such as ImageNet.

Cloud TPU

TPUs, or Tensor Processing Units, are ML ASICs specifically designed for large-scale model training. As they excel at any task where large matrix multiplications dominate, they can accelerate deep learning jobs and reduce the total cost of training. If you're new to TPUs, check this article to understand how they work. 

The v3-32 TPU used in this example consists of 32 TPU v3 cores and 512 GiB of total TPU memory (16 GiB of high-bandwidth memory per core). This TPU Pod slice consists of 4 TPU boards (each board has 8 TPU cores). Each TPU board is connected to a high-performance CPU-based host machine for tasks like loading and preprocessing data to feed to the TPUs.

Figure 3: Cloud TPU VM architecture (source)

We will access the TPU through the new Cloud TPU VMs. When we use Cloud TPU VMs, a VM is created for each TPU board in the configuration. Each VM consists of 48 vCPUs and 340 GB of memory, and comes preinstalled with the latest PyTorch / XLA image. Because there is no user VM, we ssh directly into the TPU host to run our model and code. This root access eliminates the need for a network, VPC, or firewall between our code and the TPU VM, which can significantly improve the performance of our input pipeline. For more details on Cloud TPU VMs, see the System Architecture documentation.

PyTorch / XLA

PyTorch / XLA is a Python library that uses the XLA (Accelerated Linear Algebra) deep learning compiler to connect PyTorch and Cloud TPUs. Check out the GitHub repository for tutorials, best practices, Docker images, and code for popular models (e.g., ResNet-50 and AlexNet).

Data parallel distributed training

Distributed training typically refers to training workloads that use multiple accelerator devices (e.g., GPUs or TPUs). In our example, we are executing a data parallel distributed training job with stochastic gradient descent. In data parallel training, our model fits on a single TPU device and we replicate the model across each device in our distributed configuration. When we add more devices, our goal is to reduce overall training time by distributing non-overlapping partitions of the training batch to each device for parallel processing. Because our model is replicated across devices, the replicas need to communicate after each training step to keep their weights synchronized. In distributed data parallel jobs, this device communication is typically done either asynchronously or synchronously.
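Why does splitting a batch across replicas still give the same update as one big device? A minimal sketch, assuming a toy one-parameter least-squares model and hypothetical helper names (`grad`, `partition`): with equal-sized, non-overlapping partitions, averaging the per-replica mean gradients reproduces the full-batch mean gradient exactly.

```python
# Toy data-parallel check: mean of per-replica gradients == full-batch gradient
# when partitions are equal-sized. Model y ≈ w * x and all names are hypothetical.

def grad(w, batch):
    """Mean gradient of 0.5 * (w*x - y)**2 with respect to w over (x, y) pairs."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)

def partition(batch, num_replicas):
    """Split a global batch into equal, non-overlapping per-replica slices."""
    per_replica = len(batch) // num_replicas
    return [batch[r * per_replica:(r + 1) * per_replica] for r in range(num_replicas)]

global_batch = [(float(i), 2.0 * i + 1.0) for i in range(8)]
w = 0.5

# Each replica computes a gradient on its own slice of the global batch...
local_grads = [grad(w, shard) for shard in partition(global_batch, num_replicas=4)]
# ...and averaging them (sum across replicas, then divide) recovers the full-batch gradient.
synced = sum(local_grads) / len(local_grads)
full = grad(w, global_batch)
print(synced, full)
```

If partitions had different sizes, a plain average of the local means would weight samples unevenly, which is one reason equal per-device sample counts matter later in this post.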

Cloud TPUs execute synchronous device communication over the dedicated high-speed network connecting the chips. In our model code, once the backward pass has computed the local gradients, we call PyTorch / XLA’s xm.optimizer_step(optimizer) to initiate this synchronous update.


Figure 4: Synchronous all-reduce on Cloud TPU interconnect

After the local gradients are computed, the xm.optimizer_step() function synchronizes them across cores by applying an AllReduce(SUM) operation, and then calls the PyTorch optimizer.step() method, which updates the local weights with the synchronized gradients. On the TPU, the XLA compiler generates AllReduce operations over the dedicated network connecting the chips. Ultimately, the globally averaged gradients are applied to each model replica’s parameter weights, ensuring the replicas start from the same state in every training iteration. We can see the call to this function in the training loop:

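  import torch_xla.core.xla_model as xm

  for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()           # zero the parameter gradients
      output = model(data)            # forward pass
      loss = loss_fn(output, target)  # compute the loss
      loss.backward()                 # backward pass: compute local gradients
      xm.optimizer_step(optimizer)    # all-reduce gradients, then optimizer.step()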
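To make the synchronous step concrete, here is a plain-Python simulation of the pattern described above: an AllReduce(SUM) over per-replica gradients, scaling by 1/world_size, and then the same SGD update on every replica. It is a sketch of the idea only; `all_reduce_sum` and `sgd_step` are hypothetical helpers, not PyTorch / XLA APIs, and the gradient values are made up.

```python
# Simulated synchronous all-reduce across replicas (no TPU or torch required).

def all_reduce_sum(per_replica_grads):
    """Element-wise sum of every replica's gradient vector; each replica receives the result."""
    total = [sum(vals) for vals in zip(*per_replica_grads)]
    return [list(total) for _ in per_replica_grads]

def sgd_step(weights, grads, lr):
    """Plain SGD update applied locally on one replica."""
    return [w - lr * g for w, g in zip(weights, grads)]

world_size = 4
weights = [[1.0, -2.0] for _ in range(world_size)]                 # replicas start identical
local_grads = [[0.5, 1.0], [1.5, -1.0], [0.0, 2.0], [2.0, 0.0]]    # differ per data slice

reduced = all_reduce_sum(local_grads)                              # AllReduce(SUM)
averaged = [[g / world_size for g in grads] for grads in reduced]  # scale to the global mean
weights = [sgd_step(w, g, lr=0.1) for w, g in zip(weights, averaged)]

# Every replica ends the step with the same parameters.
print(all(w == weights[0] for w in weights), weights[0])
```

Because every replica receives the same reduced gradient and applies the same deterministic update, the replicas stay identical without ever exchanging the weights themselves.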

Input pipeline performance

As previously mentioned, the challenge with TPUs is feeding them the training data fast enough to keep them busy. This problem exists when we store training data on a local disk and becomes even more pronounced when we stream data from remote storage. Let’s first review a typical machine learning training loop.


Figure 5: Common machine learning training loop and hardware configuration

In this illustration, we see the following steps:

  • Training data is either stored on local disk or remote storage 
  • The CPU (1) requests and reads the data, augments it with various transformations, batches it, and feeds it to the model 
  • Once the model has the transformed, batched training data, (2) the accelerator takes over 
  • The accelerator (2a) computes the forward pass, (2b) loss, and (2c) backward pass 
  • After computing the gradients, (3) the parameter weights are updated (the learning!) 
  • And we repeat the cycle over again 
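The loop above only keeps the accelerator busy if step (1) for the next batch overlaps with step (2) for the current one. A minimal producer/consumer sketch of that overlap, assuming hypothetical `load_batch` and `train_step` stand-ins: a CPU thread fills a bounded prefetch queue while the training loop drains it. In the real script, the PyTorch DataLoader's worker processes play the producer role.

```python
# Prefetching input pipeline sketch: CPU-side loading overlaps with "accelerator" work.
import queue
import threading

def load_batch(i):
    """Stand-in for read + augment + batch on the CPU."""
    return [i * 10 + j for j in range(4)]

def train_step(batch):
    """Stand-in for forward, loss, backward, and weight update on the accelerator."""
    return sum(batch)

def producer(q, num_batches):
    for i in range(num_batches):
        q.put(load_batch(i))  # blocks while the prefetch buffer is full
    q.put(None)               # sentinel: no more data

def run(num_batches=5, prefetch=2):
    q = queue.Queue(maxsize=prefetch)
    t = threading.Thread(target=producer, args=(q, num_batches), daemon=True)
    t.start()
    results = []
    while (batch := q.get()) is not None:
        results.append(train_step(batch))
    t.join()
    return results

print(run())  # [6, 46, 86, 126, 166]
```

The bounded queue is the design choice that matters: it lets the CPU run ahead by a few batches without unbounded memory growth, and if the queue is usually empty, the input pipeline is the bottleneck.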

While this pattern can be adapted in several ways (e.g., some transformations could be computed on the accelerator), the prevailing theme is that an ideal architecture seeks to maximize utilization of the most expensive component, the accelerator. As a result, most performance bottlenecks tend to occur in the CPU-driven input pipeline. To help with this, we are going to use the WebDataset library. WebDataset is a PyTorch dataset implementation designed to improve streaming data access for deep learning workloads, especially in remote storage settings. Let’s see how it helps.

WebDataset format

WebDataset shards are just POSIX tar archives, and they can be created with the well-known tar command. They don't require any data conversion; the data format is the same in the tar file as it is on disk. For example, our training images are still in PPM, PNG, or JPEG format when they are stored and transferred to the input pipeline. The tar format provides performance improvements for both small and large datasets, as well as for data stored on either local disk or remote storage, such as GCS. Let’s outline three key pipeline performance enhancements we can achieve with WebDataset.
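In the WebDataset convention, all files belonging to one sample share a basename (the sample "key") and differ only by extension, e.g. an image file plus a `.cls` file holding the label. A minimal sketch of writing such a shard with Python's standard `tarfile` module; the `write_shard` helper and the placeholder bytes are hypothetical, not part of the WebDataset API.

```python
# Write a WebDataset-style shard: one tar member per field, grouped by shared key.
import io
import tarfile

def write_shard(fileobj, samples):
    """samples: iterable of (key, {extension: bytes}) pairs, written in order."""
    with tarfile.open(fileobj=fileobj, mode="w") as tar:
        for key, fields in samples:
            for ext, payload in fields.items():
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))

buf = io.BytesIO()
write_shard(buf, [
    ("000000", {"png": b"<png bytes>", "cls": b"3"}),
    ("000001", {"png": b"<png bytes>", "cls": b"7"}),
])
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tar:
    print(tar.getnames())  # ['000000.png', '000000.cls', '000001.png', '000001.cls']
```

Keeping all files of a sample adjacent in the archive is what lets a reader assemble complete samples in a single forward pass.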

(1) Sequential I/O

GCS is capable of sustaining high throughput, but there is some network overhead when initiating a connection. If we are accessing millions of individual image files, this is not ideal. Alternatively, we can achieve sequential I/O by requesting a tar file containing our individual image files. Once we request the tar file, we get sequential reads of the individual files within that tar file, which allows for faster object I/O over the network. This reduces the number of network connections to establish with GCS, and thus reduces potential network bottlenecks. 
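A quick back-of-the-envelope count shows the scale of the savings, using the sizes from this article: 1,280,000 upsampled CIFAR training images and 640 training shards (as in the `cifar-train-{000000..000639}` shard pattern used later). Any per-request setup cost is paid once per shard instead of once per image.

```python
# Number of storage requests for one pass over the training set.
num_images = 1_280_000
num_shards = 640

requests_per_file = num_images       # one GET per individual image file
requests_per_shard = num_shards      # one GET per tar shard
samples_per_shard = num_images // num_shards

print(samples_per_shard)                        # 2000
print(requests_per_file // requests_per_shard)  # 2000 (times fewer requests)
```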


Figure 6: Comparing random and pipelined access to data files

(2) Pipelined data access 

With file-based I/O we randomly access image files, which is good for training optimization, but for each image file there is a client request and storage server response. Our sequential storage achieves higher throughput because with a single client request for a tar file, the data samples in that file flow sequentially to the client. This pattern gives us pipelined access to our individual image files, resulting in higher throughput. 
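The sketch below illustrates pipelined access: a shard is consumed as a forward-only stream (tar's `"r|"` stream mode, so no seeking is needed), and consecutive members with the same key are grouped into one sample as they arrive. This is roughly what WebDataset does for each shard, but `make_demo_shard` and `iter_samples` are hypothetical helpers, and the in-memory shard stands in for a GCS object stream.

```python
# Stream samples out of a tar shard sequentially, grouping files by shared key.
import io
import tarfile

def make_demo_shard():
    """Build a tiny two-sample shard in memory (stand-in for a remote object)."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for key, label in [("000000", b"3"), ("000001", b"7")]:
            for ext, payload in [("png", b"<png bytes>"), ("cls", label)]:
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    return buf

def iter_samples(stream):
    """Yield {'__key__': key, ext: bytes, ...} dicts from a sequential tar stream."""
    current = None
    with tarfile.open(fileobj=stream, mode="r|") as tar:  # stream mode: forward reads only
        for member in tar:
            if not member.isfile():
                continue
            key, _, ext = member.name.partition(".")
            if current is not None and current["__key__"] != key:
                yield current  # previous sample is complete
                current = None
            if current is None:
                current = {"__key__": key}
            current[ext] = tar.extractfile(member).read()  # read before advancing
    if current is not None:
        yield current

for sample in iter_samples(make_demo_shard()):
    print(sample["__key__"], sample["cls"])
```

Each sample becomes available as soon as its bytes arrive, so downstream decoding can start while the rest of the shard is still streaming in.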

(3) Sharding

Storing TBs of data in a single sequential file would be difficult to work with, and it would prevent us from achieving parallel I/O. Sharding the dataset can help us in several ways:

  1. Aggregate network I/O by opening shards in parallel 
  2. Accelerate data preprocessing by processing shards in parallel
  3. Randomly access shards, but read sequentially within each shard 
  4. Distribute shards efficiently across worker nodes and devices
  5. Guarantee equal number of training samples on each device

Because we control the number of shards and the number of samples in each shard, we can distribute equal-sized shards and guarantee that each device receives the same number of samples in each training epoch. Sharding the tar files helps us balance the tradeoff between random file access and sequential reads: random access to the shards combined with in-memory shuffling provides enough randomness for training optimization, while sequential reads within each shard reduce network overhead.
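The two levels of randomness can be sketched in a few lines: shuffle the shard order, read each shard sequentially, and mix samples through a bounded in-memory buffer that emits a random element once it is full. This is conceptually similar to WebDataset's `shardshuffle=True` plus `.shuffle(n)`, but the function, shard contents, and buffer size below are hypothetical.

```python
# Shard-level shuffle + bounded shuffle buffer over sequentially read shards.
import random

def shuffled_samples(shards, buffer_size, rng):
    order = list(shards)
    rng.shuffle(order)                  # random access at the shard level
    buffer = []
    for shard in order:
        for sample in shard:            # sequential reads within each shard
            buffer.append(sample)
            if len(buffer) >= buffer_size:
                # swap a random element to the end and emit it
                idx = rng.randrange(len(buffer))
                buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
                yield buffer.pop()
    rng.shuffle(buffer)                 # drain the remainder in random order
    yield from buffer

shards = [[f"s{s}-{i}" for i in range(4)] for s in range(5)]  # 5 shards x 4 samples
out = list(shuffled_samples(shards, buffer_size=6, rng=random.Random(0)))
print(len(out), len(set(out)))  # 20 20: every sample exactly once, in mixed order
```

A larger buffer mixes samples from more shards at the cost of more memory; the buffer only needs to hold samples, never whole shards.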

Distributing shards across devices and workers

Because we are essentially creating a PyTorch IterableDataset, we can use the PyTorch DataLoader to load data on the devices for each training epoch. Traditional PyTorch Datasets distribute data at the sample level, but we are going to distribute at the shard level. We will create two functions to handle this distribution logic and pass them to the `splitter=` and `nodesplitter=` arguments when we create our dataset object. All these functions need to do is take a list of shards and return a subset of those shards. (To see how the following snippets fit into the model script, see the accompanying GitHub repository.)

We will split shards across workers with:

  import torch

  def my_worker_splitter(urls):
      """Split urls per worker
      Selects a subset of urls based on Torch get_worker_info.
      Used as a shard selection function in Dataset.
      replaces wds.split_by_worker"""
      urls = list(urls)
      worker_info = torch.utils.data.get_worker_info()
      if worker_info is not None:
          wid = worker_info.id                  # index of this DataLoader worker
          num_workers = worker_info.num_workers
          return urls[wid::num_workers]         # every num_workers-th shard
      return urls  # not in a DataLoader worker process: keep all shards

We will split shards across devices with:

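  import torch_xla.core.xla_model as xm

  def my_node_splitter(urls):
      """Split urls correctly per accelerator node (TPU core)
      :param urls: list of shard URLs
      :return: slice of urls for this core
      """
      urls = list(urls)
      rank = xm.get_ordinal()             # index of this TPU core
      num_replicas = xm.xrt_world_size()  # total number of TPU cores
      return urls[rank::num_replicas]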
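To see how the two splitters compose, here is a pure-Python simulation that needs no TPU. It assumes the node-level split (by TPU core) is applied before the worker-level split and ignores shard shuffling; the shard names follow this article's `cifar-train-{000000..000639}.tar` pattern, with 32 cores and 8 DataLoader workers per core.

```python
# Simulate node-level then worker-level shard splitting and check coverage.
shards = [f"cifar-train-{i:06d}.tar" for i in range(640)]
num_replicas, num_workers = 32, 8

assignment = {}
for rank in range(num_replicas):
    node_shards = shards[rank::num_replicas]          # what my_node_splitter returns
    for wid in range(num_workers):
        for url in node_shards[wid::num_workers]:      # what my_worker_splitter returns
            assignment.setdefault(url, []).append((rank, wid))

# Every shard is read by exactly one (core, worker) pair per epoch.
assert len(assignment) == len(shards)
assert all(len(owners) == 1 for owners in assignment.values())
print(shards[0::num_replicas][0::num_workers])
# ['cifar-train-000000.tar', 'cifar-train-000256.tar', 'cifar-train-000512.tar']
```

With these numbers each core receives 20 shards, which do not divide evenly across 8 workers (some workers read three shards, others two); choosing a shard count that is a multiple of cores × workers keeps per-worker shard counts equal.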

With these two functions we will create a data loader for both train and validation data. Here is the train loader:

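  import torch
  import torchvision.transforms as transforms
  import webdataset as wds
  import torch_xla.core.xla_model as xm

  def identity(x):
      return x

  def make_train_loader(cifar_img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
      num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
      epoch_size = FLAGS.trainsize // num_dataset_instances

      image_transform = transforms.Compose([
          transforms.RandomCrop(cifar_img_dim, padding=4),
          transforms.ToTensor(),  # further augmentation/normalization omitted here
      ])
      dataset = (
          wds.WebDataset(FLAGS.wds_traindir, splitter=my_worker_splitter,
                         nodesplitter=my_node_splitter,
                         shardshuffle=True, length=epoch_size)
          .shuffle(shuffle)                  # in-memory sample shuffle buffer
          .decode("pil")                     # image bytes -> PIL images, .cls -> int
          .to_tuple("ppm;jpg;jpeg;png", "cls")
          .map_tuple(image_transform, identity)
          .batched(batch_size, partial=True)
      )
      loader = torch.utils.data.DataLoader(dataset, batch_size=None,
                                           shuffle=False, drop_last=False,
                                           num_workers=FLAGS.num_workers)
      return loader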

Here is an explanation of some of the variables used in these snippets:

  • xm.xrt_world_size() is the total number of devices, or TPU cores
  • FLAGS.num_workers is the number of subprocesses spawned per TPU core for loading and preprocessing data
  • epoch_size specifies the number of training samples each dataset instance (one per DataLoader worker on each TPU core) should expect in each epoch
  • shardshuffle=True means we will shuffle the shards, while .shuffle(10000) shuffles samples inline
  • .batched(batch_size, partial=True) explicitly batches data in the Dataset by batch_size and ‘partial=True’ handles partial batches, typically found in the last shard
  • Our loader is a standard PyTorch DataLoader. Because our WebDataset Dataset accounts for batching, shuffling, and partial batches, we do not use these arguments in PyTorch’s DataLoader
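Plugging this article's settings into the formulas in make_train_loader (1,280,000 training images, 32 TPU cores, NUM_WORKERS=8, BATCH_SIZE=128) gives the nominal per-instance numbers below. This is plain arithmetic on those formulas, not output from the training script.

```python
# Nominal epoch bookkeeping implied by make_train_loader's formulas.
import math

trainsize, world_size, num_workers, batch_size = 1_280_000, 32, 8, 128

num_dataset_instances = world_size * num_workers                 # 256 dataset instances
epoch_size = trainsize // num_dataset_instances                  # 5000 samples per instance
samples_per_core = epoch_size * num_workers                      # 40000 samples per TPU core
batches_per_worker = math.ceil(epoch_size / batch_size)          # 40 (39 full + 1 partial)
last_batch = epoch_size - (batches_per_worker - 1) * batch_size  # 8 samples in the partial batch

print(num_dataset_instances, epoch_size, samples_per_core, batches_per_worker, last_batch)
```

The 8-sample remainder is exactly the kind of partial batch that `partial=True` keeps instead of dropping; with 8 workers per core, that amounts to roughly 320 batches per core per epoch.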

Performance comparison

The table in Figure 7 compares the performance of three different training configurations for a PyTorch / XLA ResNet-50 model training on the ImageNet dataset. Configuration A provides baseline metrics and represents a model reading from local storage, randomly accessing individual image files. Configuration B uses the same setup as A, except the training data is sharded into 640 POSIX tar files and the WebDataset library is used to sample and distribute shards to the model replicas on Cloud TPU devices. Configuration C uses the same sampling and distribution logic as B, but sources training data from remote storage in GCS. The metrics represent an average of each configuration over five 90-epoch training jobs.


Figure 7: Training performance comparison

Comparing configurations A and B, these results show that simply using a sharded, sequentially readable data format improves pipeline and model throughput (average examples per second) by 11.2%. They also show that we can take advantage of remote storage without negatively impacting model training performance: comparing configurations A and C shows that we were able to maintain pipeline and model throughput, training time, and model accuracy.

To highlight the impacts of sequential and parallel I/O, we held many configuration settings constant. There are still several areas to investigate and improve. In a later post we will show how to use the Cloud TPU profiler tool to further optimize PyTorch / XLA training jobs.

End-to-end example

Let’s walk through a full example.

To follow this example, you can use this notebook to create a sharded CIFAR dataset.

Before you begin

In the Cloud Shell, run the following commands to configure gcloud to use your GCP project, install components needed for the TPU VM preview, and enable the TPU API. For additional TPU 1VM setup details, see these instructions.
  gcloud config set account YOUR_EMAIL_ACCOUNT
  export PROJECT_ID=YOUR_PROJECT_ID
  gcloud config set project ${PROJECT_ID}
  gcloud components install alpha
  gcloud services enable tpu.googleapis.com

Connecting to a Cloud TPU VM

The default network comes preconfigured to allow ssh access to all VMs. If you don’t use the default network, or the default network settings were edited, you may need to explicitly enable SSH access by adding a firewall rule:

  gcloud compute firewall-rules create --network=network allow-ssh --allow=tcp:22

Currently in the TPU VM preview, we recommend disabling OS login to allow native scp (required for PyTorch / XLA Pods).

  gcloud compute project-info add-metadata \
    --metadata enable-oslogin=FALSE --project ${PROJECT_ID}

Creating a TPU 1VM slice

We will create our TPU Pod slice in europe-west4-a because this zone supports both TPU VMs and v3-32 TPU Pod slices.

  export REGION=europe-west4
  export ZONE=europe-west4-a
  export TPU_NAME=my-1vm-tpu
  export ACCELERATOR_TYPE=v3-32
  export RUNTIME_VERSION=v2-alpha

  • TPU_NAME: name of the TPU node
  • ZONE: location of the TPU node
  • ACCELERATOR_TYPE: find the list of supported accelerator types here
  • RUNTIME_VERSION: for PyTorch / XLA, use v2-alpha for single TPUs and TPU pods. This is a stable version for our public preview release.

PyTorch / XLA requires all TPU VMs to be able to access the model code and data. Using gcloud, we will include a metadata startup-script which installs the necessary packages and code on each TPU VM. 

  gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} \
    --metadata startup-script='#! /bin/bash
  pip install webdataset==0.1.54
  pip install google-cloud-storage
  pip install tensorboardX
  cd /usr/share/
  git clone --recursive
  cd pytorch/
  git clone --recursive
  git clone --recursive
  '

This command will create a v3-32 TPU Pod slice and 4 VMs, one dedicated to each TPU board. 

To ssh into a TPU VM, we will use the gcloud ssh command below. By default, this command will connect to the first TPU VM worker (denoted with w-0). To ssh into any other VM associated with the TPU Pod, append `--worker ${WORKER_NUMBER}` to the command, where WORKER_NUMBER is 0-based. See here for more details on managing TPU VMs.

  gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

Once in the VM, run the following command to generate the SSH keys needed to ssh between VM workers on a pod:

  gcloud compute config-ssh

PyTorch training

Check to make sure the metadata startup script has cloned all the repositories. After running the following commands, we should see the torchxla_tpu directory.

  cd /usr/share/pytorch
  ls

To train the model, let’s first set up some environment variables:

  export BUCKET=          # TODO ex: tpu-demo-xxxx
  export TRAIN_SHARDS=    # TODO ex: 'train/cifar-train-{000000..000639}.tar'
  export VAL_SHARDS=      # TODO ex: 'val/cifar-val-{000000..000049}.tar'
  export WDS_TRAIN_DIR="pipe:gsutil cat gs://${BUCKET}/${TRAIN_SHARDS}"
  export WDS_VAL_DIR="pipe:gsutil cat gs://${BUCKET}/${VAL_SHARDS}"
  export LOGDIR="${LOGDIR:-gs://${BUCKET}/log-$(date '+%Y%m%d%H%M%S')}"

  • BUCKET: name of GCS bucket storing our sharded dataset. We will also store training logs and model checkpoints here (see guidelines on GCS object names and folders)
  • {split}_SHARDS: train/val shards, using brace notation to enumerate the shards
  • WDS_{split}_DIR: uses a pipe: URL so that WebDataset runs gsutil cat and streams the train/val shards from GCS
  • LOGDIR: location in GCS bucket for storing training logs
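WebDataset performs the brace expansion and the `pipe:` handling itself; the sketch below only illustrates what those two steps amount to. `expand_braces` (one zero-padded numeric range only) and `open_url` are simplified, hypothetical stand-ins, and `BUCKET` is a placeholder bucket name.

```python
# How a shard spec becomes a list of per-shard byte streams (simplified illustration).
import re
import subprocess

def expand_braces(spec):
    """Expand a single zero-padded numeric range such as {000000..000639}."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", spec)
    if m is None:
        return [spec]
    lo, hi = m.group(1), m.group(2)
    width = len(lo)
    return [spec[:m.start()] + str(i).zfill(width) + spec[m.end():]
            for i in range(int(lo), int(hi) + 1)]

def open_url(url):
    """For 'pipe:CMD', run CMD in a shell and return its stdout as a byte stream."""
    if url.startswith("pipe:"):
        proc = subprocess.Popen(url[len("pipe:"):], shell=True, stdout=subprocess.PIPE)
        return proc.stdout
    return open(url, "rb")

urls = expand_braces("pipe:gsutil cat gs://BUCKET/train/cifar-train-{000000..000639}.tar")
print(len(urls), urls[0], urls[-1], sep="\n")
```

Because each shard is read through a pipe, bytes flow from GCS into the tar reader as `gsutil cat` produces them; nothing has to be copied to local disk first.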

  export TPU_NAME=my-1vm-tpu         # Name of TPU
  export NUM_EPOCHS=10               # Total number of epochs
  export BATCH_SIZE=128              # Samples per train batch
  export TEST_BATCH_SIZE=64          # Samples per test batch
  export NUM_WORKERS=8               # Workers per TPU VM to prep/load data
  export TRAIN_SIZE=1280000          # Total number of training samples
  export TEST_SIZE=50000             # Total number of test samples

Optionally, we can pass environment variables for storing model checkpoints and loading from a previous checkpoint file:

  export SAVE_MODEL='/tmp/'   # local file to upload to GCS
  export LOAD_CHKPT_FILE=     # object in GCS bucket
  export LOAD_CHKPT_DIR=      # local directory/filename

When we choose to save model checkpoints, a checkpoint file will be saved at the end of each epoch if the validation accuracy improves. Each time a checkpoint is created, the PyTorch / XLA utility API will save the file locally, overwriting any previous file if it exists. Then, using the Cloud Storage Python SDK, we will upload the file to the specified $LOGDIR, overwriting any previous file if it exists. Our example saves a dictionary of relevant information like this:

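  import torch_xla.core.xla_model as xm

  if FLAGS.save_model != "":
      if accuracy > best_valid_acc:
          best_valid_acc = accuracy
          xm.save(  # PyTorch / XLA save: moves tensors to CPU, writes on the master ordinal
              {
                  "epoch": epoch,
                  "nepochs": FLAGS.num_epochs,
                  "model_state_dict": model.state_dict(),
                  "best_valid_acc": best_valid_acc,
              },
              FLAGS.save_model,
          )
          if xm.is_master_ordinal():
              _upload_blob_gcs(FLAGS.logdir, FLAGS.save_model, '')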

Here is the function that uses the Cloud Storage SDK to upload each model checkpoint to GCS:

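  import os
  from google.cloud import storage
  from google.cloud.storage import Blob

  def _upload_blob_gcs(gcs_uri, source_file_name, destination_blob_name):
      """Uploads a file to GCS bucket"""
      client = storage.Client()
      blob = Blob.from_string(os.path.join(gcs_uri, destination_blob_name), client=client)
      blob.upload_from_filename(source_file_name)  # overwrites any existing object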

If we want to resume training from a previous checkpoint, we use the LOAD_CHKPT_FILE variable to specify the GCS object to download and the LOAD_CHKPT_DIR variable to specify the local directory to place this file. Once the model is initialized, we deserialize the dictionary with torch.load(), load the model’s parameter dictionary with load_state_dict(), and move the model to the devices with .to(device).

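  import torch
  import torch_xla.core.xla_model as xm

  device = xm.xla_device()  # the TPU core assigned to this process
  if FLAGS.load_chkpt_file != "":
      _read_blob_gcs(FLAGS.model_bucket, FLAGS.load_chkpt_file,
                     FLAGS.load_chkpt_dir)  # download object from GCS
      checkpoint = torch.load(FLAGS.load_chkpt_dir)  # deserialize
      model.load_state_dict(checkpoint['model_state_dict'])  # load params
      model = model.to(device)  # move model to devices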

Here is the function that uses the Cloud Storage SDK to download the checkpoint and save it to a local directory:

  from google.cloud import storage

  def _read_blob_gcs(BUCKET, CHKPT_FILE, DESTINATION):
      """Downloads a file from GCS to local directory"""
      client = storage.Client()
      bucket = client.get_bucket(BUCKET)
      blob = bucket.get_blob(CHKPT_FILE)
      blob.download_to_filename(DESTINATION)

We can use other information from our dictionary to configure the training job, such as updating the best validation accuracy and epoch:

  if FLAGS.load_chkpt_file != "":
      best_valid_acc = checkpoint['best_valid_acc']
      start_epoch = checkpoint['epoch']
  else:
      best_valid_acc = 0.0
      start_epoch = 1

If we don’t want to save or load these files, we can omit them from the command line arguments. Details on saving and loading PyTorch / XLA checkpoint files can be found here.

Now we are ready to train.

  python3 -m torch_xla.distributed.xla_dist --tpu=$TPU_NAME \
   --restart-tpuvm-pod-server --env XLA_USE_BF16=1 \
   -- python3 /usr/share/pytorch/torchxla_tpu/ \
   --num_epochs=$NUM_EPOCHS \
   --batch_size=$BATCH_SIZE \
   --num_workers=$NUM_WORKERS \
   --log_steps=10 \
   --test_set_batch_size=$TEST_BATCH_SIZE \
   --wds_traindir="$WDS_TRAIN_DIR" --wds_testdir="$WDS_VAL_DIR" \
   --save_model=$SAVE_MODEL --model_bucket=$BUCKET \
   --trainsize=$TRAIN_SIZE --testsize=$TEST_SIZE \
   --logdir=$LOGDIR 2>&1 | tee -a /tmp/out-wds-1.log

  • --restart-tpuvm-pod-server restarts the XRT_SERVER (XLA Runtime) and is useful when running consecutive TPU jobs (especially if that server was left in a bad state). Since the XRT_SERVER is persistent for the pod setup, environment variables won’t be picked up until the server is restarted.
  • The training script closely follows the PyTorch / XLA distributed, multiprocessing script, but is adapted to include support for WebDataset and CIFAR
  • TPUs have hardware support for the Brain Floating Point Format (bfloat16), which can be used by setting XLA_USE_BF16=1
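For intuition about what bfloat16 trades away: it keeps float32's sign bit and 8-bit exponent (so the same numeric range) but only 7 explicit mantissa bits, i.e. the top 16 bits of a float32. The sketch below converts a Python float through that format using round-to-nearest-even; it illustrates the number format only and is not how XLA performs the conversion.

```python
# Round a value to bfloat16 precision by keeping the top 16 bits of its float32 pattern.
import struct

def to_bfloat16(x):
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    rounding = 0x7FFF + ((bits >> 16) & 1)               # round to nearest, ties to even
    bf16_bits = ((bits + rounding) >> 16) & 0xFFFF       # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]

print(to_bfloat16(3.14159))  # 3.140625: only about 2-3 significant decimal digits survive
print(to_bfloat16(1e38))     # still finite: same exponent range as float32
```

The wide exponent range is why bfloat16 usually trains without the loss scaling that float16 often needs, while the short mantissa is the precision cost.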

During training, output for each step looks like this:

  [0] | Training Device=xla:0/2 Epoch=8 Step=310 Loss=0.26758 Rate=1079.01 GlobalRate=1420.67 Time=18:02:10

  • In the full log, each line begins with the IP address of the VM worker that produced it
  • [0] refers to VM worker 0. Recall that there are 4 VM workers in our example
  • Training Device=xla:0/2 refers to TPU core 2. In our example there are 32 TPU cores, so you should see up to xla:0/31 (since they are 0-based)
  • Rate=1079.01 refers to the exponential moving average of examples per second for this TPU core
  • GlobalRate=1420.67 refers to the average number of examples per second for this core so far during this epoch
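If you prefer to analyze these lines programmatically rather than by eye, a small parser is enough. This sketch assumes only the line format shown above (a `[N]` worker index followed by `Key=value` fields); `parse_step_line` is a hypothetical helper, not part of the training script.

```python
# Parse a per-step training log line into its worker index and Key=value fields.
import re

STEP_LINE = ("[0] | Training Device=xla:0/2 Epoch=8 Step=310 Loss=0.26758 "
             "Rate=1079.01 GlobalRate=1420.67 Time=18:02:10")

def parse_step_line(line):
    worker = re.search(r"\[(\d+)\]", line)
    fields = dict(re.findall(r"(\w+)=(\S+)", line))  # e.g. {'Device': 'xla:0/2', ...}
    fields["worker"] = int(worker.group(1)) if worker else None
    return fields

fields = parse_step_line(STEP_LINE)
print(fields["worker"], fields["Device"], float(fields["Rate"]), float(fields["GlobalRate"]))
```

Grouping parsed lines by `fields["worker"]` gives the same per-worker view as filtering the log by IP address.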

At the end of each epoch’s train loop, you will see output like this:

  [0] Epoch 8 train end 18:02:10, Epoch Time=0:00:28, Replica Train Samples=39664, Reduced GlobalRate=45676.50
  • Replica Train Samples tells us how many training samples this replica processed
  • Reduced GlobalRate is the GlobalRate aggregated (summed) across all replicas for this epoch, i.e., the throughput of the whole TPU Pod slice
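A quick consistency check relates these figures, using this article's numbers (1,280,000 training samples, 32 replicas, an epoch time of about 28 s). The aggregate rate lands close to the Reduced GlobalRate in the epoch log, and dividing it by 32 lands close to the per-core GlobalRate in the step log. Both epoch time and log values are rounded, so only approximate agreement is expected.

```python
# Back-of-the-envelope throughput check against the sample epoch log line.
trainsize, replicas, epoch_seconds = 1_280_000, 32, 28

samples_per_replica = trainsize / replicas    # compare: Replica Train Samples=39664
aggregate_rate = trainsize / epoch_seconds    # compare: Reduced GlobalRate=45676.50
per_replica_rate = aggregate_rate / replicas  # compare: GlobalRate=1420.67 per core

print(round(samples_per_replica), round(aggregate_rate), round(per_replica_rate))
```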

Once training is complete, you will see the following output:

  [0] Total Train Time: 0:03:59
  [0] Max Accuracy: 79.67%
  [0] Avg. Global Rate: 48718.11 examples per second

The logs for each VM worker are produced asynchronously, so it can be difficult to read them sequentially. To view the logs sequentially for any TPU VM worker, we can execute the following command, where IP_ADDRESS is the address to the left of our [0].

  grep "IP_ADDRESS" /tmp/out-wds-1.log

We can save the filtered lines to a .txt file and copy it to a GCS bucket like this:

  grep "IP_ADDRESS" /tmp/out-wds-1.log > /tmp/out-wds-1.log.txt
  gsutil cp /tmp/out-wds-1.log.txt gs://${BUCKET}/YOUR_FILE_NAME.txt

Cleaning up

We can clean up our TPU VM resources in one simple command.

First, disconnect from the TPU VM, if you have not already done so:

  exit

In the Cloud Shell, use the following command to delete the TPU VM resources:

  gcloud alpha compute tpus tpu-vm delete ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

If you wish to delete the GCS bucket and its contents, run the following command in the Cloud Shell terminal:

  gsutil rm -r gs://${BUCKET}

What’s next?

In this article we explored the challenges of using remote storage in distributed deep learning training jobs. We discussed the advantages of using sharded, sequentially readable data formats to solve the challenges with remote storage access and how the WebDataset library makes this easier with PyTorch. We then walked through an example demonstrating how to stream training data from GCS to TPU VMs and train a PyTorch / XLA model on Cloud TPU Pod slices. 


In the next installment of this series, we will revisit this example and work with Cloud TPU Tools to further optimize our training job. We will demonstrate how variables such as shard size, shard count, batch size, and number of workers impact the input pipeline, resource utilization, examples per second, accuracy, loss, and overall model convergence.   

Have a question or want to chat? Find the authors here - Jordan and Shane

Special thanks to Karl Weinmeister, Rajesh Thallam, and Vaibhav Singh for their contributions to this post, as well as Daniel Sohn, Zach Cain, and the rest of the PyTorch / XLA team for their efforts to enhance the PyTorch experience on Cloud TPUs.