
How Cohere is accelerating language model training with Google Cloud TPUs

July 27, 2022
Joanna Yoo

Machine Learning Engineer, Cohere

Vaibhav Singh

Sr. Product Manager

Over the past few years, advances in training large language models (LLMs) have moved natural language processing (NLP) from a bleeding-edge technology that few companies could access, to a powerful component of many common applications. From chatbots to content moderation to categorization, a general rule for NLP is that the larger the model, the greater the accuracy it’s able to achieve in understanding and generating language.

But in the quest to create larger and more powerful language models, scale has become a major challenge. Once a model becomes too large to fit on a single device, it requires distributed training strategies, which in turn require extensive compute resources with vast memory capacity and fast interconnects. You also need specialized algorithms to optimize the hardware and time resources.

Cohere engineers are working on solutions to this scaling challenge that have already yielded results. Cohere provides developers a platform for working with powerful LLMs without the infrastructure or deep ML expertise that such projects typically require. In a new technical paper, Scalable Training of Language Models using JAX pjit and TPUv4, engineers at Cohere demonstrate how their new FAX framework deployed on Google Cloud’s recently announced Cloud TPU v4 Pods addresses the challenges of scaling LLMs to hundreds of billions of parameters. Specifically, the report reveals breakthroughs in training efficiency that Cohere was able to achieve through tensor and data parallelism. 

This framework aims to accelerate the research, development, and production of large language models with two significant improvements: scalability and rapid prototyping. Cohere will be able to improve its models by training larger ones more quickly, delivering better models to its customers faster. The framework also supports rapid prototyping of models that address specific objectives (for example, a generative model that powers a customer-service chatbot) by making it easier to experiment with and test new ideas. The ability to switch back and forth among model types and optimize for different objectives will ultimately allow Cohere to offer models optimized for particular use cases.

The FAX framework relies heavily on the partitioned just-in-time compilation (pjit) feature of JAX, which abstracts the relationship between device and workload. This allows Cohere engineers to optimize efficiency and performance by aligning devices and processes in the ideal configuration for the task at hand. pjit works by compiling an arbitrary function into a single program (an XLA computation) that runs on multiple devices, even those residing on different hosts.
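To make the idea concrete, here is a minimal sketch of compiling one function into a single program that runs across every available device. It is not taken from FAX or from the Cohere paper: the `layer` function, the array shapes, and the `"data"` mesh axis name are illustrative assumptions. It uses `jax.jit` with sharding annotations, which is how recent JAX releases expose pjit-style partitioning; on a machine with one device it still runs, just without any real partitioning.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange every visible device along one logical axis (hypothetical name "data").
devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=("data",))

def layer(x, w):
    # An ordinary JAX function; nothing in it mentions devices.
    return jnp.tanh(x @ w)

# Split the batch dimension across devices; keep the weights replicated.
batch_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P())

x = jax.device_put(jnp.ones((8 * n_dev, 16)), batch_sharding)
w = jax.device_put(jnp.full((16, 4), 0.01), replicated)

# One compiled XLA program; the sharding annotations describe the data layout.
sharded_layer = jax.jit(
    layer,
    in_shardings=(batch_sharding, replicated),
    out_shardings=batch_sharding,
)
y = sharded_layer(x, w)
print(y.shape, y.sharding)
```

The function body stays the same whether it runs on one chip or many; only the sharding annotations change, which is the separation of workload from device layout described above.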

Cohere’s new solution also takes advantage of Google Cloud’s new TPU v4 Pods to perform tensor parallelism, which is more efficient than the earlier pipeline parallelism implementation. As the name suggests, the pipeline parallel approach uses accelerators in a linear fashion to scale a workload, like a single long assembly line. Each accelerator must finish processing a micro-batch of data before passing it along to the next accelerator, and the backward pass then runs through the accelerators in reverse order.
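A rough way to see the cost of this ordering is to count how long each accelerator waits while the pipeline fills and drains. The sketch below assumes a simple fill-and-drain schedule in which every stage takes the same time per micro-batch; that schedule and the function name `bubble_fraction` are illustrative assumptions, not details from the Cohere paper.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Share of time slots an accelerator sits idle in a fill-and-drain pipeline.

    Assumes equal time per stage: the pipeline needs
    (num_microbatches + num_stages - 1) slots in total, and each
    accelerator is idle for (num_stages - 1) of them.
    """
    total_slots = num_microbatches + num_stages - 1
    idle_slots = num_stages - 1
    return idle_slots / total_slots

# Idle share for a 4-stage pipeline as the number of micro-batches grows.
for m in (1, 4, 16, 64):
    print(f"stages=4 microbatches={m:3d} idle={bubble_fraction(4, m):.1%}")
```

Under these assumptions, more micro-batches shrink the idle share but never remove it, and adding stages makes it worse for a fixed number of micro-batches.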

Tensor parallelism eliminates the accelerator idle time of pipeline parallelism, also known as the pipeline bubble. Tensor parallelism involves partitioning large tensors (mathematical arrays that define the relationship among multiple objects such as the words in a paragraph) across accelerators to perform computations at the same time on multiple devices. If pipeline parallelism is an ever-lengthening assembly line, tensor parallelism is a series of parallel assembly lines — one making the engine, the other the body, etc. — that simultaneously come together to form a complete car in a fraction of the time.
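As an illustration of partitioning tensors across accelerators, the sketch below splits the two weight matrices of a small feed-forward block over a `"model"` mesh axis: the first by columns, the second by rows. This column-then-row layout is a commonly used scheme and is an assumption here, as are the shapes and names; the source does not describe Cohere's exact partitioning. The sharded result is compared against an unsharded reference.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=("model",))  # hypothetical axis name

# Hidden size is a multiple of the device count so it divides evenly.
d_in, d_hidden = 8, 4 * n_dev
key_x, key_w1, key_w2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (2, d_in))
w1 = jax.random.normal(key_w1, (d_in, d_hidden))
w2 = jax.random.normal(key_w2, (d_hidden, d_in))

# Each device holds a slice of w1's columns and the matching slice of w2's rows.
w1_sharded = jax.device_put(w1, NamedSharding(mesh, P(None, "model")))
w2_sharded = jax.device_put(w2, NamedSharding(mesh, P("model", None)))
x_replicated = jax.device_put(x, NamedSharding(mesh, P()))

@jax.jit
def feed_forward(x, w1, w2):
    # Every device can work on its own slice of the hidden dimension at the
    # same time; the compiler adds whatever cross-device communication is
    # needed to combine the partial results of the second matmul.
    h = jax.nn.relu(x @ w1)
    return h @ w2

y_parallel = feed_forward(x_replicated, w1_sharded, w2_sharded)
y_reference = jax.nn.relu(x @ w1) @ w2  # same math, no sharding
print(np.allclose(np.asarray(y_parallel), np.asarray(y_reference), rtol=1e-3, atol=1e-3))
```

All devices are busy on the same layer at once, rather than waiting for an upstream stage to hand over its output.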

These computations are then collated, a process made practical thanks to Google Cloud TPU v4 VMs, which more than double the computational power of their v3 predecessors. The superior performance of v4 chips has enabled Cohere to iterate on ideas and validate them 1.7X faster in computation than before.

“At Cohere, we build cutting-edge natural language processing (NLP) services, including APIs for language generation, classification, and search. These tools are built on top of a set of language models that Cohere trains from scratch on Cloud TPUs using JAX. The superior performance of v4 chips has enabled Cohere to iterate on ideas and validate them 1.7X faster in computation than before, allowing faster iterations for our researchers and higher quality results for our customers. The exceptionally low carbon footprint of Cloud TPU v4 Pods was another key factor for us.”

Aidan Gomez, CEO and co-founder, Cohere

Why Google Cloud for LLM training?

As part of a multiyear technology partnership, Cohere leverages Google Cloud’s advanced AI and ML infrastructure to power its platform. Cohere develops and deploys its products on Cloud TPUs, Google Cloud’s custom-designed machine learning chips that are optimized for large-scale ML. Cohere recently announced model improvements and scalability gains from training an LLM using FAX on Google Cloud TPUs; in that work, transitioning from TPU v3 to TPU v4 has so far enabled a total speedup of 1.7x. In addition to a significant performance boost, TPUs provide an excellent user experience with the new TPU VM architecture. Importantly, Google Cloud ensures that Cohere's state-of-the-art ML training is achieved with the highest standards of sustainability, powered by 90% carbon-free energy in the world's largest publicly available ML hub.

By adopting Cloud TPUs, Cohere is making LLM training faster, more economical, and more agile. This helps them provide larger and more accurate LLMs to developers, and put NLP technology in the hands of developers and businesses of all sizes.

To learn more about these LLM training advances, you can read the full paper, Scalable Training of Language Models using JAX pjit and TPUv4. To learn more about Cohere's best practices and AI principles, you can check this article co-authored with OpenAI and AI21 Labs.
