Analyzing and validating data at scale for machine learning with TensorFlow Data Validation

This document discusses how to use the TensorFlow Data Validation (TFDV) library for data exploration and descriptive analytics during experimentation. Data scientists and machine learning (ML) engineers can use TFDV in a production ML system to validate data that's used in a continuous training (CT) pipeline, and to detect skews and outliers in data received for prediction serving.

The document covers concepts, challenges, and resolutions pertaining to analysis and validation for ML at production scale using TFDV. It's intended for data scientists who have a background in ML training and evaluation.

For hands-on practice, you can work with the analyzing and validating data with TFDV tutorials.

The ML experimentation process

After you define an ML use case and its objective, you go through an iterative process before you put your ML continuous training pipeline and model serving into production. As a data scientist or ML engineer, you are probably familiar with this process, which consists of the following steps:

  1. Data exploration and analysis
  2. Data preparation and transformation
  3. Model training and tuning
  4. Model evaluation and validation

Each step depends on the preceding one, but the process is iterative. For example, after evaluating a model, you might try training a different model configuration, or you might revisit the data preparation step to design new features.

This document focuses on the first step: data exploration. For more details on ML productionization, see MLOps: Continuous delivery and automation pipelines in ML.

Exploratory data analysis

As a data scientist, you typically start your ML process with exploratory data analysis (EDA), where the goal is to understand the properties of the data that you'll use to train your model. You do this by computing statistical properties from your data (or from a sample of the data), and by displaying these summaries using visualizations.
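
For example, TFDV lets you compute and visualize these summary statistics directly in a notebook. The following is a minimal sketch; the CSV path and the use of a pandas sample are illustrative assumptions, not part of any specific pipeline.

```python
import pandas as pd
import tensorflow_data_validation as tfdv

# Load a manageable sample of the data for interactive exploration.
# 'housing_sample.csv' is a placeholder path.
sample_df = pd.read_csv('housing_sample.csv')

# Compute descriptive statistics and render the interactive Facets view.
sample_stats = tfdv.generate_statistics_from_dataframe(sample_df)
tfdv.visualize_statistics(sample_stats)
```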

The data analysis and exploration process leads to key insights and outcomes:

  • An understanding of the issues (if any) in the data that require attention during the data preparation step, such as missing values, outliers, and type conversions.
  • Insight into the feature engineering operations (for example, feature selection, extraction, or construction) that improve the quality of the data for your ML task.
  • An idea of what kind of ML model would be a good fit for your data, taking into account feature sparsity, linear dependency, and so on.
  • A schema with metadata that describes the properties of your data, including data types, allowed ranges, and additional characteristics. The schema is the key output of the data analysis process because subsequent ML steps depend on it.

For more information about data and feature engineering for ML on Google Cloud, see the Data preprocessing for ML reference guide.

Purpose of the schema

The schema describes the expected data types and format of the data, and the expected distribution of its values. It's used for these purposes:

  • It enables metadata-driven preprocessing and model creation, rather than requiring you to hard-code the handling of each feature. For example, based on the feature data types defined in the schema, all the numerical features can be normalized and embedding layers can be created for all the categorical features in the model, as in the sketch after this list.
  • It's used to validate new data that's received for training your model so that you can catch data anomalies and format mismatches. In addition, it's used to detect skews and outliers in data that's received for prediction over time.
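
As an illustration of metadata-driven preprocessing, the following sketch groups features by the data types recorded in a TFDV schema so that downstream code can normalize numeric features and embed categorical ones generically. The helper function and its use are assumptions for illustration, not part of the TFDV API.

```python
from tensorflow_metadata.proto.v0 import schema_pb2

def split_features_by_type(schema):
    """Groups feature names by schema type so preprocessing can be metadata-driven."""
    numeric_features, categorical_features = [], []
    for feature in schema.feature:
        if feature.type in (schema_pb2.INT, schema_pb2.FLOAT):
            numeric_features.append(feature.name)       # candidates for normalization
        elif feature.type == schema_pb2.BYTES:
            categorical_features.append(feature.name)   # candidates for embedding layers
    return numeric_features, categorical_features
```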

Challenges for maintaining model performance

In some use cases, the ML process described earlier is a one-off, manual process. That is, data scientists move from one step to the next one (forward or backward) manually, and they put a model into production only when it shows acceptable performance.

In this one-off process, a new version of the model is deployed only rarely. In practice, however, models often break down when they're deployed in the real world because they fail to adapt to changes in the dynamics of their environment, or in the data that describes that environment. For more information, see Why Machine Learning Models Crash And Burn In Production.

To maintain performance in production, data scientists need to retrain ML models repeatedly at intervals, with the frequency depending on the use case. Retraining the model captures the change in the dynamics of environments and the changes in the format of the data itself.

This introduces two challenges:

  • How to automatically detect anomalies in the new data through a programmatic data validation process.
  • How to perform this data validation process at big-data scale.

Automatically detecting data anomalies

Model performance can be negatively affected by anomalies and divergence between data splits for training, evaluation, and serving. In continuous training (CT) pipelines, new training data, as well as the evaluation data split, is expected to comply with the schema that's expected by the model. Anomalies in data can be categorized into the following:

  • Data format anomalies (schema skew): These anomalies occur when different data splits have different schemas. For example, the training data schema differs from the schema that's expected by the model.
  • Data values/distribution anomalies (data skew): These anomalies occur when the statistical properties of the data splits are different. For example, the serving data statistical properties differ from properties of the data that the model trained with.

To illustrate, imagine you are training a model that predicts US house prices in dollars based on parcel size (datatype float), number of rooms (datatype int), and house type (datatype string, values "detached" or "duplex"). The following two sections give examples of both data format and distribution anomalies.

Anomalies in data format

Assume that you've trained the model on two types of homes, detached and duplex. However, at retraining or serving time, the model encounters a new categorical value, "quadplex". This is an example of schema skew: the production data doesn't adhere to the definition of correct data that was derived from the training data.

Schema anomalies include the following:

  • Data type mismatch. The schema codifies a data field to be of a certain type, such as float. An anomaly is generated when a different data type is provided at serving time, such as the number of rooms provided as a float instead of an int, or the house type encoded as 1 or 2 instead of "detached" or "duplex".
  • Missing or unexpected features. This type of anomaly can occur if the serving data is consistently missing a feature such as house type, but your model was trained to use that feature for estimating the price of the house.
  • Data domain violation. This anomaly occurs when a field's value falls outside the allowed domain, and it can affect both numerical and categorical features. For example, housing prices should be larger than 0, and the schema expects house types to be either "detached" or "duplex". The sketch after this list shows how such constraints can be expressed in the schema.
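
To make expectations like these explicit, you can encode them as domains in the schema. The following sketch assumes a schema object that was produced earlier (for example, by tfdv.infer_schema) and reuses the house-price feature names from the example above.

```python
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

# schema is assumed to come from an earlier call such as tfdv.infer_schema(...).

# Restrict house_type to the vocabulary seen during training.
tfdv.set_domain(schema, 'house_type',
                schema_pb2.StringDomain(name='house_type',
                                        value=['detached', 'duplex']))

# Require house prices to be positive.
tfdv.set_domain(schema, 'price',
                schema_pb2.FloatDomain(name='price', min=0.0))
```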

Anomalies in data value distributions

Anomalies based on data value distributions include the following:

  • Feature skew. The data values of a given feature are different between data splits. This can happen if feature values come from different data sources. It can also occur if inconsistent transformations are applied during training and serving—for example, you train your house price estimation model on dollars, but a Japanese web app performs prediction serving in yen.

  • Distribution skew. The distribution of feature values differs between data splits. This can be due to a faulty sampling mechanism, or if you chose the wrong training dataset to represent real-world data. This also can occur if your model has gone stale—that is, the model has not been trained for a while using fresh data, but the data patterns and relationships have evolved.

For example, suppose that housing in a particular locale is booming, which causes a steep increase in housing prices across the market. Your current model's predictive performance will degrade because it wasn't trained using current price ranges.

In a case like this, when you compare the serving data to the training data, you detect a distribution skew. The data format is correct, but the production data of actual housing prices shows a different price distribution compared to the training data: prices in the serving data are centered around a higher mean. To fix this skew, you must retrain your model on data that reflects the current, higher prices in order to maintain predictive performance.

Similarly, when you compare weekly periods of serving data, you can also detect a drift anomaly. Drift refers to a change in the distribution of values between consecutive time periods, and such changes can also affect performance. For example, suppose that the distribution of house types on the market changes: the share of detached houses offered on the market increases as people who struggle to pay their mortgages put their expensive houses up for sale.

Both data skews and schema skews can negatively impact the predictive performance of your model, and they can be difficult to detect manually. Therefore, a goal is to automatically detect these types of skews.
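
TFDV supports this kind of automated detection through comparator thresholds that you add to the schema. The following sketch reuses the house-price feature names from the example above, and it assumes a schema produced earlier by tfdv.infer_schema as well as a TFDV version that supports the Jensen-Shannon divergence comparator for numeric features.

```python
import tensorflow_data_validation as tfdv

# schema is assumed to come from an earlier call such as tfdv.infer_schema(...).

# Flag training/serving skew on the categorical house_type feature
# using an L-infinity distance threshold.
tfdv.get_feature(schema, 'house_type').skew_comparator.infinity_norm.threshold = 0.01

# Flag period-over-period drift in the numeric price distribution
# using a Jensen-Shannon divergence threshold.
tfdv.get_feature(schema, 'price').drift_comparator.jensen_shannon_divergence.threshold = 0.1
```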

Exploration and validation at big-data scale

The global volume of data is growing at a very fast rate as IT systems capture more and more of what happens in the world. ML has seen rapid adoption because it can help turn big data into valuable insights. Deep neural networks (DNNs) especially thrive on large amounts of data, because DNN performance continues to improve as the size of the data grows, whereas other models plateau.

In fact, working with large data sets is a challenge not just for model training, but for each step of the ML process, starting with data exploration and analysis. Many methods for exploring and validating data rely on loading data into memory. But this approach doesn't scale to terabyte-sized or petabyte-sized datasets.

One solution is to take a small sample of the data and perform your analysis on that subset. However, sampling can cause you to miss information and useful patterns that are found only in the whole dataset. Another approach is to compute statistics for the whole dataset and then use these statistics as a summary that describes the data. Computing statistics like these on a large dataset needs to scale out over multiple machines for distributed processing, instead of just scaling up a single machine. Therefore, the goal is to be able to explore and validate your data for ML at scale.

Applying TensorFlow Data Validation

TFDV is part of TensorFlow Extended (TFX). TFX is a production-scale machine learning platform that provides a component framework and shared libraries for implementing production ML pipelines.

A TFX pipeline defines a data flow through several components, with the goal of implementing a specific ML task, such as building and deploying a regression model. Pipeline components are built on TFX libraries such as TensorFlow Data Validation (TFDV), TensorFlow Transform (TFT), TensorFlow Model Analysis (TFMA), and TensorFlow Serving.

TFDV helps you understand, validate, and monitor ML data at scale. The library is used at Google every day to analyze and validate petabytes of data, and has a proven track record in helping TFX users maintain the health of their ML pipelines. You can use TFDV in these ways:

  • In the exploratory data analysis phase, you can use TFDV to visualize, analyze, and understand the data, as well as to produce the data schema that's expected by the pipeline. This schema acts like a contract between the data and the ML pipeline. If the contract is violated, either the data or the pipeline needs to be fixed.
  • In the pipeline production phase, the schema is used to validate the new data that's used for training, evaluation, and serving so that you can identify any anomalies in the data schema or distribution.

The following figure illustrates TFDV usage in the exploratory data analysis phase and in the pipeline production phase. In the diagram, different shapes (for example, the shapes for Stats, Schema, and Anomalies) represent different artifacts created by the pipeline.

Flow of processing when using TFDV, showing artifacts created at each stage.

TFDV in the exploratory phase

In the exploratory phase, you typically perform the following steps using TFDV, as illustrated by the sketch after this list:

  1. Compute the summary statistics of your dataset. The statistics can be generated from the full set, but they are usually generated from a large sample. Therefore, this type of computation needs to be done at scale.
  2. Explore the computed statistics visually to understand information about the data, such as data types, domains, distributions, ranges, central tendency, spread, missing values, and so on. This helps you identify any data preprocessing and feature engineering that's required before the ML model can be built.
  3. Infer an initial schema from the automatically generated statistics.
  4. Alter, update, and fix the schema manually. The produced schema acts as a contract between the model and the incoming data. If any incoming data violates the schema, anomalies are detected. Updates to the schema include the following:

    • Changing the expected data types for features.
    • Adding vocabulary to categorical features.
    • Specifying whether a feature is required (that is, whether it allows missing values).
    • Defining a domain of a feature. For example, a feature that's a measurement of an angle could be checked for values in the 0-360 range.
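
The following sketch illustrates these four steps. The file path and feature names are placeholders, and the schema edits shown are only examples of the kinds of manual updates you might make.

```python
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

# 1. Compute summary statistics (here from a CSV file; TFRecord files and
#    pandas DataFrames are also supported).
train_stats = tfdv.generate_statistics_from_csv('gs://my-bucket/data/train.csv')

# 2. Explore the statistics interactively in a notebook.
tfdv.visualize_statistics(train_stats)

# 3. Infer an initial schema from the computed statistics.
schema = tfdv.infer_schema(statistics=train_stats)
tfdv.display_schema(schema)

# 4. Manually curate the schema, then persist it as the data contract.
tfdv.get_feature(schema, 'house_type').presence.min_fraction = 1.0  # make the feature required
tfdv.set_domain(schema, 'angle',
                schema_pb2.FloatDomain(name='angle', min=0.0, max=360.0))
tfdv.write_schema_text(schema, 'schema.pbtxt')
```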

TFDV in the production phase

In the production phase, you typically perform the following tasks using TFDV, as shown in the sketch after this list:

  1. Compute statistics over the new data splits: the training, evaluation, and test splits in retraining workflows, and the serving data in prediction monitoring workflows.
  2. Validate the statistics against the predefined schema.
  3. Determine whether there are anomalies, which can include schema skew or data skew.
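
A minimal sketch of these tasks follows. It assumes that the curated schema was saved earlier (for example, with tfdv.write_schema_text) and that statistics from the previous period are available; the paths are placeholders.

```python
import tensorflow_data_validation as tfdv

# Load the curated schema that acts as the data contract.
schema = tfdv.load_schema_text('schema.pbtxt')

# 1. Compute statistics over the new data splits.
new_train_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='gs://my-bucket/data/new_train*.tfrecord')
serving_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='gs://my-bucket/logs/serving*.tfrecord')
previous_stats = tfdv.load_statistics('gs://my-bucket/stats/last_week_stats.tfrecord')

# 2-3. Validate the statistics against the schema, checking for schema skew,
#      training/serving skew, and drift relative to the previous period.
anomalies = tfdv.validate_statistics(
    statistics=new_train_stats,
    schema=schema,
    serving_statistics=serving_stats,
    previous_statistics=previous_stats)
tfdv.display_anomalies(anomalies)
```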

If anomalies are found, you take one of several actions, depending on the type of the anomalies and your use case. The actions can include the following:

  • Fix and update the data schema that's used by the ML pipeline.
  • Update the preprocessing and feature engineering steps so that they handle your data.
  • Update your model architecture to take into account the new or altered features.
  • Retrain your model to handle the changes in the data patterns.

The TFDV core API for computing data statistics is built as an Apache Beam component. Apache Beam is an open-source, unified model and SDK for building batch and streaming data processing pipelines that scale horizontally and run on a variety of execution engines. Building on Apache Beam lets TFDV run anywhere and perform computations at big-data scale.

TFDV core APIs

The TFDV API has built-in functionality to execute statistics computations directly on Dataflow. Dataflow is a managed service that runs Apache Beam data processing pipelines at scale. It provides autoscaling functionality for workers, and it integrates natively with Google Cloud services for machine learning (AI Platform), data warehousing (BigQuery), and data lakes (Cloud Storage).
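
For example, you can submit the statistics computation to Dataflow by passing Apache Beam pipeline options to the TFDV API. In the following sketch, the project ID, region, bucket paths, and package path are placeholders.

```python
import tensorflow_data_validation as tfdv
from apache_beam.options.pipeline_options import (
    GoogleCloudOptions, PipelineOptions, SetupOptions, StandardOptions)

options = PipelineOptions()
options.view_as(StandardOptions).runner = 'DataflowRunner'

gcp_options = options.view_as(GoogleCloudOptions)
gcp_options.project = 'my-project-id'                    # placeholder
gcp_options.region = 'us-central1'                       # placeholder
gcp_options.job_name = 'tfdv-statistics'
gcp_options.staging_location = 'gs://my-bucket/staging'  # placeholder
gcp_options.temp_location = 'gs://my-bucket/tmp'         # placeholder

# Ship the TFDV package to the Dataflow workers.
options.view_as(SetupOptions).extra_packages = [
    '/path/to/tensorflow_data_validation.whl']           # placeholder

# Run the statistics computation on Dataflow and write the result to Cloud Storage.
stats = tfdv.generate_statistics_from_tfrecord(
    data_location='gs://my-bucket/data/train*.tfrecord',
    output_path='gs://my-bucket/stats/train_stats.tfrecord',
    pipeline_options=options)
```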

The following list provides an overview of the TFDV APIs that are used for the functionality discussed previously in the exploration and production phases.

Compute statistics
  Description: Computes descriptive statistics that provide an overview of the data in terms of the features and their value distributions. Data can be loaded directly from an in-memory pandas DataFrame instance, or from a CSV file or TFRecord file that's stored locally or in Cloud Storage.
  Output: Statistics protocol buffer.

Visualize statistics
  Description: Visualizes the statistics generated from the data.
  Output: Facets visualization of the statistics in a notebook environment for interactively exploring the data.

Infer schema
  Description: Infers and proposes an initial schema based on statistics computed over data that's passed to the API.
  Output: Schema protocol buffer. You can display the auto-generated schema using the API, modify the schema as needed, check it into a version control system, and push it explicitly into the pipeline for further validation.

Validate data
  Description: Detects anomalies by comparing statistics computed from new data against the schema. To detect training-serving skew, you pass the training data statistics using the statistics argument and the serving data statistics using the serving_statistics argument. To detect drift, you pass statistics from previous periods using the previous_statistics argument.
  Output: Anomalies protocol buffer. You can display the anomalies from a Jupyter notebook using the API. For information about which anomalies are supported, see the GitHub repository.

Display anomalies
  Description: Visualizes any anomalies that are detected.
  Output: A notebook-based visualization of the detected anomalies.

What's next?