Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM
Jordan Totten
Machine Learning Specialist
Shane Hansen
Customer Engineer
Introduction
Many deep learning advancements can be attributed to increases in (1) data size and (2) computational power. Training deep learning models on larger datasets can be extremely beneficial. Larger datasets not only help stabilize model performance during training; research also shows that for moderate- to large-scale models and datasets, model performance improves as a power law with training data size, meaning we can predict improvements in model accuracy as the dataset grows.
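To make the power-law claim concrete, suppose the error follows ε(m) ≈ α·m^β for dataset size m, with β < 0. Then log ε is linear in log m, so an ordinary least-squares fit in log-log space recovers α and β and lets us extrapolate to larger datasets. The sketch below uses synthetic points generated from an assumed α = 5.0 and β = -0.3, not measurements from the cited study. The function names are hypothetical.

```python
import math


def fit_power_law(sizes, errors):
    """Fit error ~= alpha * size**beta by least squares in log-log space.

    Taking logs turns the power law into a line:
    log(error) = log(alpha) + beta * log(size).
    Returns (alpha, beta).
    """
    xs = [math.log(s) for s in sizes]
    ys = [math.log(e) for e in errors]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # The slope of the least-squares line is the power-law exponent.
    beta = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    alpha = math.exp(mean_y - beta * mean_x)
    return alpha, beta


def predict_error(alpha, beta, size):
    """Extrapolate the fitted curve to a new dataset size."""
    return alpha * size ** beta


# Synthetic learning curve generated from an assumed alpha=5.0, beta=-0.3.
sizes = [10_000, 50_000, 100_000, 500_000, 1_000_000]
errors = [5.0 * s ** -0.3 for s in sizes]

alpha, beta = fit_power_law(sizes, errors)
print(f"alpha={alpha:.3f}, beta={beta:.3f}")
print(f"predicted error at 10M samples: {predict_error(alpha, beta, 10_000_000):.4f}")
```

Real learning curves are noisy, and power-law behavior typically holds only over a range of dataset sizes. Treat an extrapolation like this as an estimate rather than a guarantee.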
Figure 1: Learning curve and dataset size for word language models (source)
In practice this means as we look to improve model performance with larger datasets, (1) we need access to hardware accelerators, such as GPUs or TPUs, and (2) we need to architect a system that efficiently stores and delivers this data to the accelerators. There are a few reasons why we may choose to stream data from remote storage to our accelerator devices:
- Data size: data can be too large to fit on a single machine, requiring remote storage and efficient network access
- Streamlined workflows: transferring data to disk can be time-consuming and resource-intensive, so we want to make fewer copies of the data
- Collaboration: disaggregating data from accelerator devices means we can more efficiently share accelerator nodes across workloads and teams
Streaming training data from remote storage to accelerators can alleviate these issues, but it introduces a host of new challenges:
- Network overhead: Many datasets consist of millions of individual files, and randomly accessing these files can introduce network bottlenecks. We need sequential access patterns
- Throughput: Modern accelerators are fast; the challenge is feeding them fast enough to keep them fully utilized. We need parallel I/O and pipelined access to data
- Randomness vs. sequential access: The optimization algorithms in deep learning jobs benefit from randomness, but random file access introduces network bottlenecks. Sequential access alleviates network bottlenecks but can reduce the randomness needed for training optimization. We need to balance the two
How do we architect a system that addresses these challenges at scale?
Figure 2: Scaling to larger datasets, more devices
In this post, we will cover:
- The challenges associated with scaling deep learning jobs to distributed training settings
- Using the new Cloud TPU VM interface
- How to stream training data from Google Cloud Storage (GCS) to PyTorch / XLA models running on Cloud TPU Pod slices
You can find accompanying code for this article in this GitHub repository.
Model and dataset
In this article, we will train a PyTorch / XLA ResNet-50 model on a v3-32 TPU Pod slice where training data is stored in GCS and streamed to the TPU VMs at training time. ResNet-50 is a 50-layer convolutional neural network commonly used for computer vision tasks and machine learning performance benchmarking. To demonstrate an end-to-end example, we will use the CIFAR-10 dataset. The original dataset consists of 60,000 32x32 color images divided into 10 classes, each class containing 6,000 images. We have upsampled this dataset, creating a training and test set of 1,280,000 and 50,000 images, respectively. CIFAR is used because it is publicly accessible and well known; however, in the GitHub repository, we provide guidance for adapting this solution to your workloads, as well as larger datasets such as ImageNet.
Cloud TPU
TPUs, or Tensor Processing Units, are ML ASICs specifically designed for large-scale model training. As they excel at any task where large matrix multiplications dominate, they can accelerate deep learning jobs and reduce the total cost of training. If you're new to TPUs, check this article to understand how they work.
The v3-32 TPU used in this example consists of 32 TPU v3 cores and 512 GiB of total TPU memory. This TPU Pod slice consists of 4 TPU boards (each board has 8 TPU cores). Each TPU board is connected to a high-performance CPU-based host machine for tasks like loading and preprocessing data to feed to the TPUs.
We will access the TPU through the new Cloud TPU VMs. When we use Cloud TPU VMs, a VM is created for each TPU board in the configuration. Each VM consists of 48 vCPUs and 340 GB of memory, and comes preinstalled with the latest PyTorch / XLA image. Because there is no user VM, we ssh directly into the TPU host to run our model and code. This root access eliminates the need for a network, VPC, or firewall between our code and the TPU VM, which can significantly improve the performance of our input pipeline. For more details on Cloud TPU VMs, see the System Architecture.
PyTorch / XLA
PyTorch / XLA is a Python library that uses the XLA (Accelerated Linear Algebra) deep learning compiler to connect PyTorch and Cloud TPUs. Check out the GitHub repository for tutorials, best practices, Docker Images, and code for popular models (e.g., ResNet-50 and AlexNet).
Data parallel distributed training
Distributed training typically refers to training workloads that use multiple accelerator devices (e.g., GPUs or TPUs). In our example, we are executing a data parallel distributed training job with stochastic gradient descent. In data parallel training, our model fits on a single TPU device, and we replicate the model across each device in our distributed configuration. When we add more devices, our goal is to reduce overall training time by distributing non-overlapping partitions of the training batch to each device for parallel processing. Because our model is replicated across devices, the replicas need to communicate after each training step to keep their weights synchronized. In distributed data parallel jobs, this device communication is typically done either asynchronously or synchronously.
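The synchronous variant can be simulated without any accelerators. This sketch uses a toy one-parameter linear model with a mean-squared-error loss, and all names in it are hypothetical. Each replica computes a gradient on its own non-overlapping slice of the batch. A simulated AllReduce(SUM) gives every replica the total, and each replica divides by the number of replicas and applies the same update. Because the slices are equal-sized, the averaged gradient equals the full-batch gradient, so the replicas never drift apart.

```python
def local_gradient(w, xs, ys):
    """Gradient of mean squared error for the model y_hat = w * x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_sum(values):
    """Simulate AllReduce(SUM): every replica receives the total."""
    total = sum(values)
    return [total] * len(values)


def data_parallel_step(weights, batch_x, batch_y, lr):
    """One synchronous SGD step across len(weights) replicas."""
    world_size = len(weights)
    shard = len(batch_x) // world_size  # assumes the batch divides evenly
    grads = []
    for rank, w in enumerate(weights):
        # Each replica sees a non-overlapping partition of the batch.
        xs = batch_x[rank * shard:(rank + 1) * shard]
        ys = batch_y[rank * shard:(rank + 1) * shard]
        grads.append(local_gradient(w, xs, ys))
    summed = all_reduce_sum(grads)
    # Divide by the world size to average, then apply the same update everywhere.
    return [w - lr * (g / world_size) for w, g in zip(weights, summed)]


batch_x = [float(i) for i in range(1, 9)]
batch_y = [3.0 * x for x in batch_x]  # the true weight is 3.0
replicas = [0.0] * 4  # identical initial weights on 4 replicas
for _ in range(50):
    replicas = data_parallel_step(replicas, batch_x, batch_y, lr=0.01)
print(replicas)
```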
Cloud TPUs execute synchronous device communication over the dedicated high-speed network connecting the chips. In our model code, we use PyTorch / XLA’s xm.optimizer_step(optimizer) to synchronize the gradients and apply this synchronous update.
Figure 4: Synchronous all-reduce on Cloud TPU interconnect
After the local gradients are computed, the xm.optimizer_step() function synchronizes them between cores by applying an AllReduce(SUM) operation. It then calls the PyTorch optimizer’s step() method, which updates the local weights with the synchronized gradients. On the TPU, the XLA compiler generates AllReduce operations over the dedicated network connecting the chips. Ultimately, the globally averaged gradients are applied to each model replica’s parameter weights, ensuring the replicas start from the same state in every training iteration. We can see the call to this function in the training loop:
Input pipeline performance
As previously mentioned, the challenge with TPUs is feeding them training data fast enough to keep them busy. This problem exists when we store training data on a local disk and becomes even more pronounced when we stream data from remote storage. Let’s first review a typical machine learning training loop.
Figure 5: Common machine learning training loop and hardware configuration
In this illustration, we see the following steps:
- Training data is either stored on local disk or remote storage
- The CPU (1) requests and reads the data, augments it with various transformations, batches it, and feeds it to the model
- Once the model has the transformed, batched training data, (2) the accelerator takes over
- The accelerator computes (2a) the forward pass, (2b) the loss, and (2c) the backward pass
- After computing the gradients, (3) the parameter weights are updated (the learning!)
- The cycle then repeats
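The numbered steps map onto a training loop roughly as follows. This toy, framework-free sketch fits a one-parameter linear model, and all names are hypothetical. Its only purpose is to label where each step happens. In the real job, step 1 runs on the CPU host and steps 2 and 3 run on the TPU.

```python
import random


def read_and_transform(raw):
    """Step 1 (CPU): read raw (x, y) samples and apply a tiny augmentation."""
    return [(x + random.uniform(-0.01, 0.01), y) for x, y in raw]


def make_batches(samples, batch_size):
    """Step 1 (CPU): group transformed samples into batches."""
    for i in range(0, len(samples), batch_size):
        yield samples[i:i + batch_size]


def train(raw, epochs=20, batch_size=4, lr=0.01):
    """Fit y = w * x, labeling each step of the loop from the figure."""
    w, loss = 0.0, None
    for _ in range(epochs):  # the cycle repeats every epoch
        data = read_and_transform(raw)
        random.shuffle(data)
        for batch in make_batches(data, batch_size):
            n = len(batch)
            preds = [w * x for x, _ in batch]  # (2a) forward pass
            loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / n  # (2b) loss
            grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n  # (2c) backward pass
            w -= lr * grad  # (3) update the weight
    return w, loss


random.seed(0)
raw = [(float(x), 2.0 * x) for x in range(1, 9)]  # the true weight is 2.0
w, loss = train(raw)
print(f"w={w:.3f}, final batch loss={loss:.5f}")
```

If the CPU work in step 1 is slower than steps 2 and 3, the accelerator sits idle between batches. This is the bottleneck the rest of this section addresses.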
While this pattern can be adapted in several ways (e.g., some transformations could be computed on the accelerator), the prevailing theme is that an ideal architecture seeks to maximize utilization of the most expensive component: the accelerator. Because of this, most performance bottlenecks occur in the CPU-driven input pipeline. To help with this, we are going to use the WebDataset library. WebDataset is a PyTorch dataset implementation designed to improve streaming data access for deep learning workloads, especially in remote storage settings. Let’s see how it helps.
WebDataset format
WebDatasets are just POSIX tar archive files, and they can be created with the well-known tar command. They don't require any data conversion; the data format is the same in the tar file as it is on disk. For example, our training images are still in PPM, PNG, or JPEG format when they are stored and transferred to the input pipeline. The tar format provides performance improvements for both small and large datasets, as well as data stored on either local disk or remote storage, such as GCS. Let’s outline three key pipeline performance enhancements we can achieve with WebDataset.
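Because a shard is just a tar file, Python’s standard tarfile module is enough to write one and stream it back sequentially. The sketch below follows the WebDataset convention of grouping a sample’s files under a shared basename, here a hypothetical `000000.png` plus a `000000.cls` label file. The bytes are placeholders rather than real images, and the helper names are hypothetical. Opening the archive in streaming mode (`"r|"`) reads members strictly in order, which is the access pattern that makes shards efficient over the network.

```python
import io
import tarfile


def write_shard(samples):
    """Pack (key, image_bytes, label) samples into an in-memory tar shard."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for key, image_bytes, label in samples:
            # Files that share a basename (the "key") form one training sample.
            for name, payload in ((f"{key}.png", image_bytes),
                                  (f"{key}.cls", str(label).encode())):
                info = tarfile.TarInfo(name=name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
    return buf.getvalue()


def read_shard_sequentially(shard_bytes):
    """Stream a shard front to back and regroup its files into samples."""
    samples = {}
    # "r|" is tarfile's streaming mode: members are read strictly in order.
    with tarfile.open(fileobj=io.BytesIO(shard_bytes), mode="r|") as tar:
        for member in tar:
            key, ext = member.name.split(".", 1)
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    return samples


shard = write_shard([(f"{i:06d}", b"fake-png-bytes", i % 10) for i in range(3)])
for key, sample in read_shard_sequentially(shard).items():
    print(key, sample["cls"].decode(), len(sample["png"]))
```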
(1) Sequential I/O
GCS is capable of sustaining high throughput, but there is some network overhead when initiating a connection. If we are accessing millions of individual image files, this is not ideal. Alternatively, we can achieve sequential I/O by requesting a tar file containing our individual image files. Once we request the tar file, we get sequential reads of the individual files within that tar file, which allows for faster object I/O over the network. This reduces the number of network connections to establish with GCS, and thus reduces potential network bottlenecks.
Figure 6: Comparing random and pipelined access to data files
(2) Pipelined data access
With file-based I/O we randomly access image files, which is good for training optimization, but for each image file there is a client request and storage server response. Our sequential storage achieves higher throughput because with a single client request for a tar file, the data samples in that file flow sequentially to the client. This pattern gives us pipelined access to our individual image files, resulting in higher throughput.
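A back-of-the-envelope model shows why this matters. Suppose each request pays a fixed overhead before bytes start flowing, and the bytes then arrive at a fixed throughput. Reading N files one request at a time pays the overhead N times, while reading the same bytes as S shards pays it only S times. The overhead, throughput, image size, and shard count below are hypothetical placeholders, not measurements of GCS. The model also ignores parallel requests.

```python
def transfer_seconds(num_requests, total_bytes, overhead_s, throughput_bps):
    """Serial time = per-request overhead + time to stream all the bytes."""
    return num_requests * overhead_s + total_bytes / throughput_bps


NUM_IMAGES = 1_280_000   # size of the upsampled CIFAR training set
BYTES_PER_IMAGE = 3_000  # hypothetical average encoded image size
NUM_SHARDS = 320         # hypothetical shard count
OVERHEAD_S = 0.02        # hypothetical per-request overhead (20 ms)
THROUGHPUT_BPS = 100e6   # hypothetical sustained throughput (100 MB/s)

total_bytes = NUM_IMAGES * BYTES_PER_IMAGE
per_file = transfer_seconds(NUM_IMAGES, total_bytes, OVERHEAD_S, THROUGHPUT_BPS)
per_shard = transfer_seconds(NUM_SHARDS, total_bytes, OVERHEAD_S, THROUGHPUT_BPS)
print(f"one request per image: {per_file:,.0f} s")
print(f"one request per shard: {per_shard:,.0f} s")
```

With one request per image, the per-request overhead dominates. With one request per shard, the total is bounded mainly by throughput.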
(3) Sharding
Storing TBs of data in a single sequential file would be difficult to work with and would prevent us from achieving parallel I/O. Sharding the dataset can help us in several ways:
- Aggregate network I/O by opening shards in parallel
- Accelerate data preprocessing by processing shards in parallel
- Randomly access shards, but read sequentially within each shard
- Distribute shards efficiently across worker nodes and devices
- Guarantee equal number of training samples on each device
Because we can control the number of shards and the number of samples in those shards, we can distribute equal-sized shards and guarantee that each device receives the same number of samples in each training epoch. Sharding the tar files helps us balance the tradeoff between random file access and sequential reads. Random access to the shards, combined with in-memory shuffling, provides enough randomness for training optimization, while sequential reads within each shard reduce network overhead.
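A quick check shows how a shard layout produces equal work per device. The example combines the 1,280,000-image training set with the 32 cores of a v3-32 and assumes a hypothetical layout of 320 shards. That gives 4,000 samples per shard and 10 shards, or 40,000 samples, per core per epoch. The helper name below is hypothetical. It simply refuses layouts that do not divide evenly.

```python
def per_device_plan(num_samples, num_shards, num_devices):
    """Return (samples_per_shard, shards_per_device, samples_per_device).

    Raises ValueError if the layout cannot give every device the same
    number of equal-sized shards.
    """
    if num_samples % num_shards:
        raise ValueError("samples do not split into equal-sized shards")
    if num_shards % num_devices:
        raise ValueError("shards do not split evenly across devices")
    samples_per_shard = num_samples // num_shards
    shards_per_device = num_shards // num_devices
    return samples_per_shard, shards_per_device, samples_per_shard * shards_per_device


print(per_device_plan(1_280_000, 320, 32))
```

The last value is the number of samples each device should expect per epoch under this layout.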
Distributing shards across devices and workers
As we are essentially creating a PyTorch IterableDataset, we can use the PyTorch DataLoader to load data onto the devices for each training epoch. Traditional PyTorch Datasets distribute data at the sample level, but we are going to distribute at the shard level. We will create two functions to handle this distribution logic and pass them to the `splitter=` and `nodesplitter=` arguments when we create our dataset object. All these functions need to do is take a list of shards and return a subset of those shards. (To see how the following snippets fit into the model script, check out test_train_mp_wds_cifar.py in the accompanying GitHub repository.)
We will split shards across workers with:
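A minimal sketch of such a function deals shards out round-robin by worker index. In a real script, the worker index and count would come from PyTorch’s `torch.utils.data.get_worker_info()`, which returns `None` in the main process. Here they are plain parameters so the logic runs on its own. The function name is hypothetical.

```python
def split_by_worker(shards, worker_id=None, num_workers=None):
    """Return the subset of shards this DataLoader worker should read.

    Shards are assigned round-robin: worker k takes shards k, k + n, k + 2n, ...
    With no worker information (single-process loading), keep every shard.
    """
    shards = list(shards)
    if worker_id is None or num_workers is None:
        return shards
    return shards[worker_id::num_workers]


shards = [f"train-{i:06d}.tar" for i in range(10)]
print(split_by_worker(shards, worker_id=1, num_workers=4))
```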
We will split shards across devices with:
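The device-level split can follow the same round-robin pattern, keyed on each core’s global index. In PyTorch / XLA, that index and the device count would come from `xm.get_ordinal()` and `xm.xrt_world_size()`. In this standalone sketch (hypothetical function name), they are parameters. Conceptually, each core first takes its share of shards, and the worker split then divides that share among the core’s DataLoader workers.

```python
def split_by_device(shards, rank, world_size):
    """Return the shards for the device with global index `rank` (round-robin)."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is out of range for world size {world_size}")
    return list(shards)[rank::world_size]


shards = [f"train-{i:06d}.tar" for i in range(64)]
core_shards = split_by_device(shards, rank=5, world_size=32)  # core 5 of 32
# Each core's DataLoader workers can then divide core_shards the same way,
# e.g. worker 0 of 2 would take core_shards[0::2].
print(core_shards)
```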
With these two functions we will create a data loader for both train and validation data. Here is the train loader:
Here is an explanation of some of the variables used in these snippets:
- `xm.xrt_world_size()` is the total number of devices, or TPU cores
- `FLAGS.num_workers` is the number of subprocesses spawned per TPU core for loading and preprocessing data
- `epoch_size` specifies the number of training samples each device should expect for each epoch
- `shardshuffle=True` means we will shuffle the shards, while `.shuffle(10000)` shuffles samples inline
- `.batched(batch_size, partial=True)` explicitly batches data in the Dataset by `batch_size`, and `partial=True` handles partial batches, typically found in the last shard
- Our loader is a standard PyTorch DataLoader. Because our WebDataset Dataset accounts for batching, shuffling, and partial batches, we do not use these arguments in PyTorch’s DataLoader
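The last few options can be illustrated with a plain-Python sketch of an inline shuffle buffer and partial batching. It shows the general technique only. It is not WebDataset’s implementation, and the function names and buffer size are hypothetical. A shuffle buffer holds up to `bufsize` samples and emits a random one each time it fills, so randomness is local to a window of the stream. Partial batching emits the leftover samples at the end instead of dropping them.

```python
import random


def shuffle_buffer(samples, bufsize, rng=random):
    """Yield samples in a locally shuffled order using a fixed-size buffer."""
    buf = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= bufsize:
            # Swap a random element to the end and emit it.
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]
            yield buf.pop()
    rng.shuffle(buf)  # drain whatever is left at the end of the stream
    yield from buf


def batched(samples, batch_size, partial=True):
    """Group samples into lists of batch_size; keep a short final batch if partial."""
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and partial:
        yield batch


rng = random.Random(0)
stream = range(10)  # stands in for samples read sequentially from shards
for batch in batched(shuffle_buffer(stream, bufsize=4, rng=rng), batch_size=4):
    print(batch)
```

A larger buffer gives more randomness at the cost of more memory, which is the same trade-off the `.shuffle(10000)` setting controls.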
Performance comparison
The table in Figure 7 compares the performance of three different training configurations for a PyTorch / XLA ResNet-50 model trained on the ImageNet dataset. Configuration A provides baseline metrics and represents a model reading from local storage, randomly accessing individual image files. Configuration B uses a similar setup as A, except the training data is sharded into 640 POSIX tar files and the WebDataset library is used to sample and distribute shards to the model replicas on Cloud TPU devices. Configuration C uses the same sampling and distribution logic as B, but sources training data from remote storage in GCS. The metrics represent an average of each configuration over five 90-epoch training jobs.
Figure 7: Training performance comparison
Comparing configurations A and B, these results show that simply using a sharded, sequentially readable data format improves pipeline and model throughput (average examples per second) by 11.2%. They also show that we can take advantage of remote storage without negatively impacting model training performance. Comparing configurations A and C, we were able to maintain pipeline and model throughput, training time, and model accuracy.
To highlight the impacts of sequential and parallel I/O, we held many configuration settings constant. There are still several areas to investigate and improve. In a later post we will show how to use the Cloud TPU profiler tool to further optimize PyTorch / XLA training jobs.
End-to-end example
Let’s walk through a full example.
To follow this example, you can use this notebook to create a sharded CIFAR dataset.
Before you begin
In the Cloud Shell, run the following commands to configure gcloud to use your GCP project, install components needed for the TPU VM preview, and enable the TPU API. For additional TPU VM setup details, see these instructions.
Connecting to a Cloud TPU VM
The default network comes preconfigured to allow ssh access to all VMs. If you don’t use the default network, or the default network settings were edited, you may need to explicitly enable SSH access by adding a firewall rule:
Currently in the TPU VM preview, we recommend disabling OS login to allow native scp (required for PyTorch / XLA Pods).
Creating a TPU VM slice
We will create our TPU Pod slice in the europe-west4-a zone because it supports both TPU VMs and v3-32 TPU Pod slices.
- `TPU_NAME`: name of the TPU node
- `ZONE`: location of the TPU node
- `ACCELERATOR_TYPE`: find the list of supported accelerator types here
- `RUNTIME_VERSION`: for PyTorch / XLA, use v2-alpha for single TPUs and TPU pods. This is a stable version for our public preview release.
PyTorch / XLA requires all TPU VMs to be able to access the model code and data. Using gcloud, we will include a metadata startup-script that installs the necessary packages and code on each TPU VM.
This command will create a v3-32 TPU Pod slice and 4 VMs, one dedicated to each TPU board.
To ssh into a TPU VM, we will use the gcloud ssh command below. By default, this command will connect to the first TPU VM worker (denoted with w-0). To ssh into any other VM associated with the TPU Pod, append `--worker ${WORKER_NUMBER}` to the command, where WORKER_NUMBER is 0-based. See here for more details on managing TPU VMs.
Once in the VM, run the following command to generate the SSH keys used to ssh between VM workers on a Pod:
PyTorch training
Check to make sure the metadata startup script has cloned all the repositories. After running the following command, we should see the torchxla_tpu directory.
To train the model, let’s first set up some environment variables:
- `BUCKET`: name of the GCS bucket storing our sharded dataset. We will also store training logs and model checkpoints here (see guidelines on GCS object names and folders)
- `{split}_SHARDS`: train/val shards, using brace notation to enumerate the shards
- `WDS_{split}_DIR`: uses pipe to run a `gsutil` command for downloading the train/val shards
- `LOGDIR`: location in the GCS bucket for storing training logs
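Brace notation lets one string stand for a whole range of shard names. WebDataset expands such patterns into a list of URLs (it relies on the `braceexpand` package for this). The sketch below is a hypothetical, simplified expander that handles a single zero-padded numeric range. That is enough to see what a pattern like `train-{000000..000319}.tar` turns into. The bucket path and shard count are placeholders.

```python
import re

# Matches one numeric range such as {000000..000319}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shards(pattern):
    """Expand one {start..end} numeric range, keeping its zero padding."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start_text, end_text = match.groups()
    width = len(start_text)
    prefix, suffix = pattern[:match.start()], pattern[match.end():]
    return [f"{prefix}{i:0{width}d}{suffix}"
            for i in range(int(start_text), int(end_text) + 1)]


urls = expand_shards("gs://my-bucket/cifar/train-{000000..000319}.tar")
print(len(urls), urls[0], urls[-1])
```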
Optionally, we can pass environment variables for storing model checkpoints and loading from a previous checkpoint file:
When we choose to save model checkpoints, a checkpoint file will be saved at the end of each epoch if the validation accuracy improves. Each time a checkpoint is created, the PyTorch / XLA xm.save() utility API will save the file locally, overwriting any previous file if it exists. Then, using the Cloud Storage Python SDK, we will upload the file to the specified $LOGDIR, overwriting any previous file if it exists. Our example saves a dictionary of relevant information like this:
Here is the function that uses the Cloud Storage SDK to upload each model checkpoint to GCS:
If we want to resume training from a previous checkpoint, we use the LOAD_CHKPT_FILE variable to specify the GCS object to download and the LOAD_CHKPT_DIR variable to specify the local directory in which to place this file. Once the model is initialized, we deserialize the dictionary with torch.load(), load the model’s parameter dictionary with load_state_dict(), and move the model to the devices with .to(device).
Here is the function that uses the Cloud Storage SDK to download the checkpoint and save it to a local directory:
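The following sketch shows both transfer helpers built on the `google-cloud-storage` client library, using `storage.Client`, `Client.bucket`, `Bucket.blob`, `Blob.upload_from_filename`, and `Blob.download_to_filename`. The helper names and the `gs://` parsing function are hypothetical. The client library is imported inside the transfer functions so the URI parsing can be exercised without credentials or network access.

```python
import os


def split_gcs_uri(uri):
    """Split 'gs://bucket/path/to/object' into ('bucket', 'path/to/object')."""
    if not uri.startswith("gs://"):
        raise ValueError(f"not a GCS URI: {uri}")
    bucket, _, blob_name = uri[len("gs://"):].partition("/")
    if not bucket or not blob_name:
        raise ValueError(f"URI needs both a bucket and an object name: {uri}")
    return bucket, blob_name


def upload_checkpoint(local_path, gcs_uri):
    """Upload a local checkpoint file, replacing any existing object."""
    from google.cloud import storage  # requires the google-cloud-storage package

    bucket_name, blob_name = split_gcs_uri(gcs_uri)
    client = storage.Client()
    client.bucket(bucket_name).blob(blob_name).upload_from_filename(local_path)


def download_checkpoint(gcs_uri, local_dir):
    """Download a checkpoint object into local_dir and return the local path."""
    from google.cloud import storage  # requires the google-cloud-storage package

    bucket_name, blob_name = split_gcs_uri(gcs_uri)
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, os.path.basename(blob_name))
    client = storage.Client()
    client.bucket(bucket_name).blob(blob_name).download_to_filename(local_path)
    return local_path
```

For example, `download_checkpoint("gs://my-bucket/logs/checkpoint.pth", "/tmp/chkpt")` would fetch a hypothetical checkpoint object into `/tmp/chkpt` before it is passed to torch.load().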
We can use other information from our dictionary to configure the training job, such as updating the best validation accuracy and epoch:
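The checkpoint bookkeeping can be sketched independently of PyTorch: save only when validation accuracy improves, and restore the best accuracy and epoch when resuming. In this hypothetical helper, a plain dictionary stands in for the checkpoint, and the key names are assumptions. In the real script, the dictionary would also hold the model’s state_dict() and would be written with xm.save() and read with torch.load(), as described above.

```python
class CheckpointTracker:
    """Decide when to checkpoint and restore progress from a saved dictionary."""

    def __init__(self, checkpoint=None):
        # Resume from a previously loaded checkpoint dictionary, if any.
        checkpoint = checkpoint or {}
        self.best_valid_acc = checkpoint.get("best_valid_acc", 0.0)
        self.start_epoch = checkpoint.get("epoch", 0) + 1

    def maybe_checkpoint(self, epoch, valid_acc, model_state):
        """Return a checkpoint dict if valid_acc improved, else None."""
        if valid_acc <= self.best_valid_acc:
            return None
        self.best_valid_acc = valid_acc
        return {"epoch": epoch, "best_valid_acc": valid_acc,
                "model_state_dict": model_state}


tracker = CheckpointTracker()
saved = None
for epoch, acc in enumerate([0.52, 0.61, 0.58, 0.67], start=tracker.start_epoch):
    checkpoint = tracker.maybe_checkpoint(epoch, acc, model_state={"w": epoch})
    if checkpoint is not None:
        saved = checkpoint  # the real script would save and upload it here
        print(f"epoch {epoch}: saved (acc={acc})")

resumed = CheckpointTracker(saved)
print(resumed.best_valid_acc, resumed.start_epoch)
```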
If we don’t want to save or load these files, we can omit them from the command line arguments. Details on saving and loading PyTorch / XLA checkpoint files can be found here.
Now we are ready to train.
- `--restart-tpuvm-pod-server` restarts the XRT_SERVER (XLA Runtime) and is useful when running consecutive TPU jobs (especially if that server was left in a bad state). Since the XRT_SERVER is persistent for the pod setup, environment variables won’t be picked up until the server is restarted.
- `test_train_mp_wds_cifar.py` closely follows the PyTorch / XLA distributed, multiprocessing script, but is adapted to include support for WebDataset and CIFAR.
- TPUs have hardware support for the Brain Floating Point Format (bfloat16), which can be used by setting `XLA_USE_BF16=1`.
During training, output for each step looks like this:
- `10.164.0.25` refers to the IP address for this VM worker
- `[0]` refers to VM worker 0. Recall that there are 4 VM workers in our example
- `Training Device=xla:0/2` refers to TPU core 2. In our example there are 32 TPU cores, so you should see up to `xla:0/31` (since they are 0-based)
- `Rate=1079.01` refers to the exponential moving average of examples per second for this TPU core
- `GlobalRate=1420.67` refers to the average number of examples per second for this core so far during this epoch
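The difference between the two throughput metrics is easiest to see side by side. The sketch below computes an exponential moving average of per-step throughput (like `Rate`) and a cumulative average since the tracker started (like `GlobalRate`). The class name, the smoothing factor, the exact update rule, and the step timings are assumptions for illustration, not PyTorch / XLA’s implementation.

```python
class ThroughputTracker:
    """Track a smoothed per-step rate and a cumulative average rate."""

    def __init__(self, smoothing=0.4):
        self.smoothing = smoothing  # weight of the newest step (assumed value)
        self.rate = None            # exponential moving average, like "Rate"
        self.total_examples = 0
        self.total_seconds = 0.0

    def update(self, examples, seconds):
        step_rate = examples / seconds
        if self.rate is None:
            self.rate = step_rate
        else:
            self.rate = self.smoothing * step_rate + (1 - self.smoothing) * self.rate
        self.total_examples += examples
        self.total_seconds += seconds

    @property
    def global_rate(self):
        """Examples per second since the tracker started, like "GlobalRate"."""
        return self.total_examples / self.total_seconds


tracker = ThroughputTracker()
# Hypothetical step timings: a slow first step (e.g. XLA compilation), then steady.
for seconds in [2.0, 0.1, 0.1, 0.1, 0.1]:
    tracker.update(examples=128, seconds=seconds)
print(f"Rate={tracker.rate:.2f} GlobalRate={tracker.global_rate:.2f}")
```

Because the moving average discounts old steps, it reacts quickly to speedups or slowdowns. The cumulative average changes more slowly as the epoch progresses.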
At the end of each epoch’s train loop, you will see output like this:
- `Replica Train Samples` tells us how many training samples this replica processed
- `Reduced GlobalRate` is the average GlobalRate across all replicas for this epoch
Once training is complete, you will see the following output:
The logs for each VM worker are produced asynchronously, so it can be difficult to read them sequentially. To view the logs sequentially for any TPU VM worker, we can execute the following command, where IP_ADDRESS is the address to the left of our [0].
We can convert these to a .txt file and store them in a GCS bucket like this:
Cleaning up
We can clean up our TPU VM resources in one simple command.
First, disconnect from the TPU VM, if you have not already done so:
In the Cloud Shell, use the following command to delete the TPU VM resources:
If you wish to delete the GCS bucket and its contents, run the following command in the Cloud Shell terminal:
What’s next?
In this article we explored the challenges of using remote storage in distributed deep learning training jobs. We discussed the advantages of using sharded, sequentially readable data formats to solve the challenges with remote storage access and how the WebDataset library makes this easier with PyTorch. We then walked through an example demonstrating how to stream training data from GCS to TPU VMs and train a PyTorch / XLA model on Cloud TPU Pod slices.
In the next installment of this series, we will revisit this example and work with Cloud TPU Tools to further optimize our training job. We will demonstrate how variables such as shard size, shard count, batch size, and number of workers impact the input pipeline, resource utilization, examples per second, accuracy, loss, and overall model convergence.
Have a question or want to chat? Find the authors here - Jordan and Shane.
Special thanks to Karl Weinmeister, Rajesh Thallam, and Vaibhav Singh for their contributions to this post, as well as Daniel Sohn, Zach Cain, and the rest of the PyTorch / XLA team for their efforts to enhance the PyTorch experience on Cloud TPUs.