Use this document to learn about the architecture of a machine learning (ML) solution that learns and serves item embeddings. Embeddings can help you understand what items your customers consider to be similar, which enables you to offer real-time "similar item" suggestions in your application.
This page is a detailed description of the solution architecture. For instructions to implement the solution, see the solution readme in the bqml-scann GitHub repo.
This solution shows you how to identify similar songs in a dataset, and then use this information to make song recommendations.
This document is for data scientists and ML engineers who want to build an ML system for item matching and recommendation use cases. It assumes that you have experience with the following technologies:
The solution uses the public BigQuery dataset, which contains more than 12 million playlist records. It uses the playlist data to learn embeddings for songs based on their co-occurrence on playlists. It then uses the learned embeddings to identify and recommend relevant songs based on a given song.
After the data is cleaned to remove tracks that lack a title or ID, the dataset includes over 100,000 songs that each appear in at least 15 playlists, and over 500,000 playlists that each contain between 2 and 100 songs.
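The solution performs this cleaning in BigQuery. The following Python sketch shows the same filtering logic on an in-memory sample, purely for illustration. The input shape (a dict that maps a playlist ID to a list of `(track_id, title)` tuples) and the function name `clean_playlists` are hypothetical, and applying the song filter before the playlist-size filter in a single pass is an assumption about ordering.

```python
from collections import Counter


def clean_playlists(playlists, min_song_playlists=15, min_len=2, max_len=100):
    """Filter playlists as the cleaning step describes (illustrative only).

    `playlists` is a hypothetical dict: playlist ID -> list of (track_id, title).
    """
    # Drop tracks that lack an ID or a title, and de-duplicate tracks within
    # each playlist while preserving their order.
    cleaned = {
        pid: list(dict.fromkeys(tid for tid, title in tracks if tid and title))
        for pid, tracks in playlists.items()
    }
    # Count how many playlists each song appears in.
    song_counts = Counter(tid for tracks in cleaned.values() for tid in tracks)
    # Keep only songs that appear in at least `min_song_playlists` playlists.
    kept = {
        pid: [tid for tid in tracks if song_counts[tid] >= min_song_playlists]
        for pid, tracks in cleaned.items()
    }
    # Keep only playlists whose size is within [min_len, max_len].
    return {
        pid: tracks
        for pid, tracks in kept.items()
        if min_len <= len(tracks) <= max_len
    }
```

A single pass is not necessarily a fixed point: dropping short playlists can push some songs below the 15-playlist threshold, so you might repeat both filters until the result stops changing.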
The ML system uses four primary components to perform real-time similarity matching and retrieval tasks:
- A component to learn item embeddings that capture song semantic similarities. The solution uses a BigQuery ML matrix factorization model for this.
- An item embedding lookup component, which returns the embedding vector that corresponds to a given song ID. The solution uses a Keras model to provide this. The model’s prediction function accepts one or more item IDs and returns the corresponding embedding vectors.
- An embedding matching component, which returns the X most similar embedding vectors to a given embedding vector and then maps these vectors to the corresponding song IDs. You set the number of vectors returned by using a variable. The solution uses the ScaNN framework to create an approximate nearest neighbors (ANN) index that identifies similar items in a large set of embeddings. The ANN index is deployed as a model so that it can be queried in real time. We chose this approach because scanning all the embeddings to find the nearest neighbors of an input item embedding vector becomes unwieldy when you have millions of items.
- An item information lookup component, which supplies the song title and artist information for the returned song IDs. The solution uses a Datastore database to provide fast lookup of song information.
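To make the inputs and outputs of the lookup and matching components concrete, the following sketch implements them in plain Python over a small in-memory dict. It performs the exhaustive scan that the ScaNN index is meant to replace, so it only illustrates what these components return, not how the solution implements them. The function names and the use of cosine similarity are assumptions.

```python
import heapq
import math


def lookup_embeddings(item_ids, embeddings):
    """Embedding lookup: return the embedding vector for each item ID."""
    return [embeddings[item_id] for item_id in item_ids]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def exact_top_k(query, embeddings, k, exclude=None):
    """Exhaustive matching: score every item and return the IDs of the k best.

    Each query costs time proportional to the number of items, which is the
    cost that an approximate nearest neighbors index is designed to avoid.
    """
    scored = (
        (cosine(query, vec), item_id)
        for item_id, vec in embeddings.items()
        if item_id != exclude
    )
    return [item_id for _, item_id in heapq.nlargest(k, scored)]
```

Passing the query item's own ID as `exclude` keeps the item from being recommended as its own nearest neighbor.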
The ML system works as follows:
- Imports data from the public dataset into a BigQuery dataset that you create.
- Exports song title and artist information to Datastore by using Dataflow. This makes the song information available for real-time lookup when the system makes a song recommendation.
- Using BigQuery, computes the pointwise mutual information (PMI) between songs, based on song co-occurrence on playlists.
- Creates a BigQuery ML matrix factorization model that uses the PMI information to learn embeddings for the songs.
- Using Dataflow, formats the embeddings as CSV files and exports them from BigQuery to Cloud Storage so they are available for lookup.
- Creates a TensorFlow Keras model to look up the song embeddings and deploys it to AI Platform Prediction.
- Using the ScaNN framework, creates an approximate nearest neighbors (ANN) index for the song embeddings. The index is created as a model and trained in AI Platform Training.
- Wraps the ANN index model in the ScaNN matching service application.
- Using Cloud Build, packages the ScaNN matching service as a custom container and deploys it to AI Platform Prediction.
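In the solution, Dataflow writes the embeddings from BigQuery to Cloud Storage as CSV files, and this document doesn't specify the column layout. The following sketch assumes a hypothetical layout of one item per row, with the item ID followed by its embedding values, and shows how such rows could be written and parsed with Python's standard `csv` module. The function names are illustrative only.

```python
import csv
import io


def embeddings_to_csv(rows):
    """Serialize (item_id, vector) pairs as CSV text: item_id,v0,v1,..."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    for item_id, vector in rows:
        # csv.writer converts each float with str(), which round-trips exactly.
        writer.writerow([item_id, *vector])
    return buffer.getvalue()


def csv_to_embeddings(text):
    """Parse CSV text in the same layout back into {item_id: [floats]}."""
    reader = csv.reader(io.StringIO(text))
    return {row[0]: [float(value) for value in row[1:]] for row in reader if row}
```

Keeping the item ID in the first column lets a consumer, such as the embedding lookup model or the index builder, recover the mapping between row and song ID.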
Optionally, you can also create a TensorFlow Extended pipeline to automate the system. For more information, see Pipeline implementation.
The following diagram shows an overview of the solution architecture:
Once deployed, the solution works as follows:
- A customer action on the client triggers a request for songs that are similar to a given song, and sends the ID of that song to the ScaNN matching service.
- The ScaNN matching service sends the song ID to the embedding lookup model, and gets back the embedding for that song.
- The ScaNN matching service sends the song embedding to the ANN model, and gets back the embeddings of the song's approximate nearest neighbors.
- The ScaNN matching service maps the returned embeddings to the corresponding song IDs.
- The ScaNN matching service returns the IDs of the similar songs to the client.
- The client uses those IDs to retrieve the song title and artist information from Datastore for display.
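The request flow above can be sketched as follows. `MatchingService`, its constructor arguments, and `render_results` are hypothetical stand-ins: in the solution, the embedding lookup is a Keras model deployed to AI Platform Prediction, the ANN search is the ScaNN index, and song information comes from Datastore. Here they are plain callables and dicts so that the sketch runs anywhere.

```python
from typing import Callable, Dict, List, Sequence


class MatchingService:
    """Hypothetical stand-in for the ScaNN matching service."""

    def __init__(
        self,
        lookup_embedding: Callable[[str], Sequence[float]],
        ann_search: Callable[[Sequence[float], int], List[int]],
        index_to_id: Dict[int, str],
    ):
        self.lookup_embedding = lookup_embedding  # song ID -> embedding vector
        self.ann_search = ann_search  # (embedding, k) -> row indexes of neighbors
        self.index_to_id = index_to_id  # row index in the ANN index -> song ID

    def similar_songs(self, song_id: str, k: int) -> List[str]:
        # Get the embedding for the requested song.
        embedding = self.lookup_embedding(song_id)
        # Ask the ANN index for k + 1 neighbors, because the query song is
        # usually returned as its own nearest neighbor.
        rows = self.ann_search(embedding, k + 1)
        # Map the returned index rows back to song IDs, dropping the query song.
        ids = [self.index_to_id[row] for row in rows]
        return [i for i in ids if i != song_id][:k]


def render_results(
    song_ids: List[str], song_info: Dict[str, Dict[str, str]]
) -> List[str]:
    """Client side: look up title and artist (song_info stands in for Datastore)."""
    return [f"{song_info[s]['title']} by {song_info[s]['artist']}" for s in song_ids]
```

Injecting the lookup and search functions keeps the service logic independent of where each model is deployed, which also makes it easy to test against a brute-force search.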
You can optionally automate the solution by using TensorFlow Extended (TFX), a comprehensive framework for automating ML systems. A TFX pipeline organizes a sequence of components that implement the ML system in a way that is scalable and provides good performance. Components are built using TFX standard components, custom Python function components, and fully custom components. You can also use these components outside of the pipeline if you want to.
You deploy and run the TFX pipeline on AI Platform Pipelines. Each pipeline component runs on its corresponding Google Cloud service: BigQuery, Dataflow, or AI Platform Training. The pipeline metadata is stored in ML Metadata (MLMD) hosted on Cloud SQL, and the artifacts produced by the pipeline components are stored in Cloud Storage.
The workflow of the TFX pipeline is as follows:
- Computes PMI on item co-occurrence data by using a custom Python function component.
- Trains a BigQuery ML matrix factorization model on the PMI data to learn item embeddings by using a custom Python function component.
- Extracts the embeddings from the model to a BigQuery table by using a custom Python function component.
- Exports the embeddings in TFRecord format by using the standard BigQueryExampleGen component.
- Imports the schema for the embeddings by using the standard ImporterNode component.
- Validates the embeddings against the imported schema by using the standard StatisticsGen and ExampleValidator components.
- Creates an embedding lookup SavedModel by using the standard Trainer component.
- Pushes the embedding lookup model to a model registry directory by using the standard Pusher component.
- Builds the ScaNN index by using the standard Trainer component.
- Evaluates and validates the latency and recall of the ScaNN index by using a custom TFX component.
- Pushes the ScaNN index to a model registry directory by using the standard Pusher component.
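This document doesn't show how the custom evaluation component is implemented. As a hedged sketch of what evaluating recall and latency can involve, the following code compares an index's results against exact (brute-force) neighbors and measures per-query latency. The function names, the recall@k definition averaged over queries, the 95th-percentile latency, and the threshold parameters are all assumptions chosen for illustration.

```python
import math
import time


def recall_at_k(approx_neighbors, exact_neighbors, k):
    """Average fraction of each query's exact top-k that the index also returned."""
    hits = [
        len(set(approx[:k]) & set(exact[:k])) / k
        for approx, exact in zip(approx_neighbors, exact_neighbors)
    ]
    return sum(hits) / len(hits)


def evaluate_index(search_fn, queries, exact_neighbors, k, min_recall, max_p95_ms):
    """Run every query through `search_fn`, then check recall and latency thresholds."""
    results, latencies_ms = [], []
    for query in queries:
        start = time.perf_counter()
        results.append(search_fn(query, k))
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    recall = recall_at_k(results, exact_neighbors, k)
    # Nearest-rank 95th percentile of the observed latencies.
    latencies_ms.sort()
    p95 = latencies_ms[max(0, math.ceil(0.95 * len(latencies_ms)) - 1)]
    passed = recall >= min_recall and p95 <= max_p95_ms
    return {"recall": recall, "p95_latency_ms": p95, "passed": passed}
```

In a TFX pipeline, a validation result like `passed` is typically what decides whether the downstream Pusher step runs.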
The following diagram shows the TFX pipeline components and their input/output artifacts that are used to implement the ML system workflow:
Understanding item embeddings
Many compelling recommendation use cases depend on matching semantically similar items, such as similar products, movies, or songs. To enable this, you can use machine learning to create a representation space where similar items are placed close to each other.
An embedding is a way to represent discrete items as vectors of floating point numbers, such that the embeddings of two items are similar if the items share the same context. More precisely, when items commonly occur together in a given context, they are considered semantically similar. For example, words that occur in the same textual context are often similar, movies watched by the same person are assumed to be similar, and products appearing together in shopping baskets tend to be similar.
Based on this understanding, a good way to learn item embeddings is to examine how frequently two items co-occur in a dataset. When you use co-occurrence to represent item similarity, you can use unsupervised machine learning to learn item embeddings, which means you can avoid having to create a labeled dataset. Once you have item embeddings, you can use them in supervised machine learning tasks, like classification, regression, and forecasting.
Using matrix factorization to learn item embeddings
Item-item collaborative filtering is a common technique for building recommender systems. In this approach, you analyze how users interact with items in a client application (for example, by liking, viewing, sharing, rating, or purchasing them) to understand which items have similar interaction patterns. You then use this information, rather than item descriptions and user profile information, to recommend items to users. Recommending items based on item descriptions and user profiles instead is called content-based filtering, another technique that is frequently used in recommender systems.
This solution uses item-item collaborative filtering with a matrix factorization model. Similar items have similar embedding vectors, so they map close to each other in the embedding space.
To understand collaborative filtering, take movie preferences as an example. The following illustration shows how a matrix factorization model uses collaborative filtering to learn the embedding matrix $U$ for users and the embedding matrix $V$ for movies from the user-item feedback matrix $A$, such that the product $UV^T$ is a good approximation of $A$:
In item-item collaborative filtering, you take this a step further: the matrix factorization model learns only the item embedding matrix $V$, such that the product $VV^T$ is a good approximation of the item co-occurrence matrix, as shown in the following illustration:
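As an illustration of the item-item objective, the following sketch learns $V$ by plain stochastic gradient descent on the squared error between $VV^T$ and a small symmetric matrix. This is a swapped-in technique chosen for readability: it is not the alternating least squares solver that BigQuery ML uses, and the function names, learning rate, and epoch count are arbitrary assumptions.

```python
import random


def factorize_symmetric(matrix, dim, epochs=500, lr=0.05, seed=0):
    """Learn V (n x dim) so that V @ V^T approximates a symmetric n x n matrix."""
    rng = random.Random(seed)
    n = len(matrix)
    v = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(n)]
    for _ in range(epochs):
        for i in range(n):
            for j in range(n):
                # Error of the current approximation for entry (i, j).
                err = sum(v[i][d] * v[j][d] for d in range(dim)) - matrix[i][j]
                for d in range(dim):
                    # Gradient step on 0.5 * err**2 with respect to v[i][d] and v[j][d].
                    grad_i, grad_j = err * v[j][d], err * v[i][d]
                    v[i][d] -= lr * grad_i
                    v[j][d] -= lr * grad_j
    return v


def mean_squared_error(matrix, v):
    """Mean squared difference between the matrix and V @ V^T."""
    n = len(matrix)
    total = 0.0
    for i in range(n):
        for j in range(n):
            pred = sum(a * b for a, b in zip(v[i], v[j]))
            total += (pred - matrix[i][j]) ** 2
    return total / (n * n)
```

For example, factorizing the block matrix `[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]` with `dim=2` should give items 0 and 1 nearly identical vectors and items 0 and 2 nearly orthogonal ones, because $VV^T$ must reproduce their co-occurrence pattern.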
In this solution, the matrix factorization model uses pointwise mutual information (PMI) to learn the embeddings. PMI is a statistical measure of association between two items. It has been used in a number of word embedding learning algorithms, such as Swivel and GloVe. The PMI between items x and y in dataset D is computed as follows:
$$\operatorname{pmi}(x, y) = \log|x, y| - \log|x| - \log|y| + \log|D|$$
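Here, $|x, y|$ is the number of times items x and y co-occur, $|x|$ and $|y|$ are the number of times each item occurs, and $|D|$ is the size of the dataset. The PMI between x and y therefore increases as their co-occurrence increases, and decreases as the independent occurrence of each item increases.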
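The following sketch computes PMI for every co-occurring pair in a list of playlists by applying the formula above directly. It assumes that an occurrence is counted once per playlist, so $|x|$ is the number of playlists that contain x, $|x, y|$ is the number of playlists that contain both items, and $|D|$ is the number of playlists. The solution's BigQuery implementation may define these counts differently, and the function name `compute_pmi` is hypothetical.

```python
import math
from collections import Counter
from itertools import combinations


def compute_pmi(playlists):
    """Return {(x, y): pmi} for every pair of items that co-occur in a playlist."""
    item_counts = Counter()  # |x|: playlists that contain item x
    pair_counts = Counter()  # |x, y|: playlists that contain both x and y
    for playlist in playlists:
        items = sorted(set(playlist))  # sorted, so each pair key is (smaller, larger)
        item_counts.update(items)
        pair_counts.update(combinations(items, 2))
    total = len(playlists)  # |D|: number of playlists (an assumption)
    return {
        (x, y): math.log(n_xy)
        - math.log(item_counts[x])
        - math.log(item_counts[y])
        + math.log(total)
        for (x, y), n_xy in pair_counts.items()
    }
```

For example, with the playlists `[["a", "b"], ["a", "b"], ["c", "d"], ["a", "c"]]`, the pair `("c", "d")` gets a higher PMI (log 2) than `("a", "b")` (log 4/3), even though it co-occurs less often, because `"d"` appears nowhere else.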
- Implement the solution described on this page by using the instructions in the bqml-scann readme.
- Learn about other smart analytics solutions.