Training Deep Learning-based recommender models of 100 trillion parameters over Google Cloud
Prof. Ce Zhang
Training recommender models of 100 trillion parameters
A recommender system is an important component of today's Internet services: at big tech companies, businesses with billions of dollars in revenue are driven directly by recommendation services. Production recommender systems today are dominated by deep-learning-based approaches. An embedding layer first maps extremely large-scale ID-type features to fixed-length embedding vectors, and a complex neural network then uses those embeddings to generate recommendations. Progress in recommender models is often driven by growing model size, from models with billions of parameters to, very recently, models with trillions. Every jump in model capacity has brought a significant improvement in quality, and the era of 100 trillion parameters is just around the corner.
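To make the embedding step concrete, here is a minimal sketch of a sparse embedding table that maps raw ID features to fixed-length vectors. It is illustrative only and is not Persia's implementation or API: the class name `SparseEmbeddingTable`, the SHA-256 bucketing, the lazy random initialization, and the sum-pooling in `embed_sample` are all assumptions chosen for clarity. The key idea it shows is that rows are created only for IDs actually seen, which is why the embedding layer can grow to dominate model size.

```python
import hashlib
import random
from typing import Dict, List


class SparseEmbeddingTable:
    """Hypothetical toy embedding table mapping ID features to fixed-length vectors.

    Rows are created lazily on first lookup, so memory grows with the number of
    distinct IDs seen rather than with the size of the whole ID space.
    """

    def __init__(self, dim: int, num_buckets: int = 2**63 - 1, seed: int = 0):
        self.dim = dim
        self.num_buckets = num_buckets
        self.seed = seed
        self.rows: Dict[int, List[float]] = {}

    def _bucket(self, feature_id: str) -> int:
        # Hash a raw ID (e.g. "user:42") into a bucket index.
        digest = hashlib.sha256(feature_id.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "little") % self.num_buckets

    def lookup(self, feature_id: str) -> List[float]:
        key = self._bucket(feature_id)
        if key not in self.rows:
            # Deterministic small random initialization per bucket.
            rng = random.Random(self.seed ^ key)
            self.rows[key] = [rng.uniform(-0.01, 0.01) for _ in range(self.dim)]
        return self.rows[key]


def embed_sample(table: SparseEmbeddingTable, feature_ids: List[str]) -> List[float]:
    """Sum-pool a sample's ID embeddings into one dense input for the network."""
    pooled = [0.0] * table.dim
    for fid in feature_ids:
        for i, v in enumerate(table.lookup(fid)):
            pooled[i] += v
    return pooled


table = SparseEmbeddingTable(dim=8)
x = embed_sample(table, ["user:42", "item:1001"])
print(len(x), len(table.rows))
```

The pooled vector `x` is what a downstream dense network would consume; only two rows exist in the table after this call, no matter how large the ID space is.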
Training recommender models at this scale creates unique challenges. The training computation is staggeringly heterogeneous. The embedding layer can account for more than 99.99% of the total model size, which makes it extremely memory-intensive, while the remaining dense neural network is increasingly compute-intensive, requiring more than 100 TFLOPs in each training iteration. Training such models therefore requires a sophisticated mechanism for managing a cluster with heterogeneous resources.
Recently, the Kwai Seattle AI Lab and the DS3 Lab at ETH Zurich collaborated on a novel system named “Persia”, which tackles this problem by carefully co-designing the training algorithm and the training system. At the algorithm level, Persia adopts a hybrid training algorithm that treats the embedding layer and the dense neural network modules differently: the embedding layer is trained asynchronously to improve the throughput of training samples, while the rest of the neural network is trained synchronously to preserve statistical efficiency. At the system level, a wide range of optimizations for memory management and communication reduction have been implemented to unleash the full potential of the hybrid algorithm.
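The hybrid idea can be illustrated with a single-process toy simulation. This is a minimal sketch under stated assumptions, not Persia's actual algorithm, code, or API: the model is a toy factorization (prediction = `dense_w` · `emb[id]`) with squared loss, "asynchronous" is modeled as workers reading embeddings from a one-step-stale snapshot and writing updates straight into the shared table with no barrier, and "synchronous" is modeled as averaging per-worker dense gradients (what an all-reduce would compute) before one identical update. All names (`hybrid_step`, `train`, the synthetic targets) are hypothetical.

```python
import random
from typing import Dict, List, Tuple

Vector = List[float]
Batch = List[Tuple[int, float]]  # (feature ID, label) pairs for one worker


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def hybrid_step(dense_w: Vector, emb: Dict[int, Vector], stale: Dict[int, Vector],
                worker_batches: List[Batch], lr: float) -> Vector:
    """One step of a toy hybrid scheme (squared loss, prediction = dense_w . emb[id]).

    Embeddings: workers read a stale snapshot and push updates directly into the
    shared table without waiting for each other (asynchronous-style).
    Dense weights: per-worker gradients are averaged, as an all-reduce would do,
    and one identical update is applied (synchronous).
    """
    dense_grads = []
    for batch in worker_batches:
        g = [0.0] * len(dense_w)
        for fid, y in batch:
            e = stale[fid]                      # possibly stale read, no waiting
            err = dot(dense_w, e) - y           # dLoss/dPrediction
            for i, e_i in enumerate(e):
                g[i] += err * e_i / len(batch)
            # Apply the embedding gradient (err * dense_w) immediately.
            emb[fid] = [v - lr * err * w for v, w in zip(emb[fid], dense_w)]
        dense_grads.append(g)
    # Simulated all-reduce: average the dense gradients across workers.
    avg = [sum(g[i] for g in dense_grads) / len(dense_grads) for i in range(len(dense_w))]
    return [w - lr * a for w, a in zip(dense_w, avg)]


def mse(dense_w: Vector, emb: Dict[int, Vector], targets: Dict[int, float]) -> float:
    return sum((dot(dense_w, emb[i]) - y) ** 2 for i, y in targets.items()) / len(targets)


def train(steps: int = 200, workers: int = 4, dim: int = 4, num_ids: int = 20,
          lr: float = 0.05, seed: int = 0) -> Tuple[float, float]:
    rng = random.Random(seed)
    # Synthetic per-ID targets stand in for real labels such as clicks.
    targets = {i: rng.uniform(-1.0, 1.0) for i in range(num_ids)}
    dense_w = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
    emb = {i: [rng.uniform(-0.1, 0.1) for _ in range(dim)] for i in range(num_ids)}
    initial = mse(dense_w, emb, targets)
    for _ in range(steps):
        stale = {k: list(v) for k, v in emb.items()}  # one step of staleness
        batches = [[(fid, targets[fid]) for fid in rng.choices(range(num_ids), k=8)]
                   for _ in range(workers)]
        dense_w = hybrid_step(dense_w, emb, stale, batches, lr)
    return initial, mse(dense_w, emb, targets)


initial, final = train()
print(f"MSE before: {initial:.4f}, after: {final:.4f}")
```

The design point the sketch tries to convey is the asymmetry: embedding rows are touched sparsely, so tolerating some staleness there removes synchronization from the memory-bound path, while the small, dense, compute-bound part keeps exact synchronous averaging.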
Deploying large-scale training on Google Cloud
The massive scale required by Persia posed multiple challenges, from the network bandwidth needed between components to the amount of RAM needed to store the embeddings. In addition, a sizable number of virtual machines had to be deployed, automated, and orchestrated to streamline the pipeline and optimize costs. Specifically, the workload runs on the following heterogeneous resources:
- 3,000 cores on compute-intensive virtual machines
- 8 A2 virtual machines, providing a total of 64 NVIDIA A100 GPUs
- 30 high-memory virtual machines, each with 12 TB of RAM, totalling 360 TB
- Orchestration with Kubernetes
All resources had to be launched concurrently in the same zone to minimize network latency. Google Cloud was able to provide the required capacity with very little notice.
Given the bursty nature of the training, Google Kubernetes Engine (GKE) was used to orchestrate the deployment of the 138 VMs and their software containers. Containerizing the workload also makes the training portable and repeatable.
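As a quick consistency check on the resource figures, the sketch below tallies the specialized node pools and derives what is left for the CPU pool. The `NodePool` dataclass and pool names are hypothetical, not a GKE API. The derived CPU-pool numbers rest on an assumption the post does not state: that every one of the 138 VMs not in the GPU or high-memory pools is a compute-intensive VM and that the 3,000 cores are spread evenly across them.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class NodePool:
    """Hypothetical description of one node pool (names are illustrative)."""
    name: str
    vm_count: int
    gpus_per_vm: int = 0
    ram_tb_per_vm: float = 0.0


def summarize(pools: List[NodePool], total_vms: int, total_cpu_cores: int) -> dict:
    specialized_vms = sum(p.vm_count for p in pools)
    cpu_vms = total_vms - specialized_vms  # assumed: all remaining VMs form the CPU pool
    return {
        "gpus": sum(p.vm_count * p.gpus_per_vm for p in pools),
        "ram_tb": sum(p.vm_count * p.ram_tb_per_vm for p in pools),
        "cpu_vms": cpu_vms,
        # Assumes the cores are spread evenly over the remaining VMs.
        "cores_per_cpu_vm": total_cpu_cores / cpu_vms,
    }


pools = [
    NodePool("a2-gpu", vm_count=8, gpus_per_vm=8),
    NodePool("ultramem", vm_count=30, ram_tb_per_vm=12.0),
]
summary = summarize(pools, total_vms=138, total_cpu_cores=3000)
print(summary)
# {'gpus': 64, 'ram_tb': 360.0, 'cpu_vms': 100, 'cores_per_cpu_vm': 30.0}
```

The GPU and RAM totals match the figures listed above; the 100-VM, 30-core split is only what the stated totals imply under the even-split assumption.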
The team chose to keep all embeddings in memory during training. This required highly specialized “Ultramem” VMs, though only for a relatively short period of time, and it was critical to scaling the training up to 100 trillion parameters while keeping the cost and duration of processing under control.
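Some back-of-the-envelope arithmetic shows why holding every embedding in RAM needs this much memory. The post does not say which numeric precision or optimizer state Persia stores per parameter, so the scenarios below are assumptions for illustration only, and the footprint ignores hash-table and indexing overhead. Under those assumptions, 360 TB leaves an average budget of 3.6 bytes per parameter at the 100-trillion scale, and the 99.99% figure bounds the dense network at roughly 10 billion parameters.

```python
PARAMS = 100 * 10**12   # 100 trillion parameters (from the post)
RAM_TB = 30 * 12        # 30 Ultramem VMs x 12 TB each (from the post)


def embedding_footprint_tb(num_params: int, bytes_per_value: int, optimizer_slots: int = 0) -> float:
    """Decimal TB needed for parameter values plus per-parameter optimizer state.

    Ignores hash-table overhead, indices, and replication, so real usage is higher.
    """
    return num_params * bytes_per_value * (1 + optimizer_slots) / 10**12


# Upper bound on the dense network if embeddings exceed 99.99% of the model.
dense_param_bound = PARAMS // 10_000   # 0.01% of 100 trillion
# Average RAM available per parameter across the high-memory pool.
bytes_budget_per_param = RAM_TB * 10**12 / PARAMS

# Hypothetical storage scenarios; the post does not specify Persia's format.
scenarios = [("fp32, no optimizer state", 4, 0),
             ("fp16, no optimizer state", 2, 0),
             ("fp16, one optimizer slot", 2, 1)]
for label, nbytes, slots in scenarios:
    tb = embedding_footprint_tb(PARAMS, nbytes, slots)
    print(f"{label}: {tb:.0f} TB (fits in {RAM_TB} TB: {tb <= RAM_TB})")
print(f"dense params <= {dense_param_bound:,}; RAM budget = {bytes_budget_per_param} bytes/param")
```

The takeaway is only that the budget is tight at this scale; how Persia actually lays out embeddings in memory is described in its own documentation, not here.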
Results and Conclusions
With the support of Google Cloud infrastructure, the team demonstrated that Persia scales up to 100 trillion parameters. The hybrid distributed training algorithm introduces elaborate system relaxations for efficient utilization of heterogeneous clusters while converging as fast as vanilla SGD. Google Cloud was essential for overcoming the limitations of on-premises hardware and proved to be an optimal computing environment for distributed machine learning training on a massive scale.
Persia has been released as an open-source project on GitHub, with setup instructions for Google Cloud, so that anyone in academia or industry can easily train deep learning recommender models at the 100-trillion-parameter scale.