Analyzing training-serving skew with TensorFlow Data Validation

Last reviewed 2021-03-12 UTC

This document is the third in a series that shows you how to monitor machine learning (ML) models that are deployed to AI Platform Prediction so that you can detect data skew. This guide focuses on the concepts and implementation methods for detecting training-serving data skew by using TensorFlow Data Validation (TFDV).

This guide is for data scientists and ML engineers who want to monitor how serving data changes over time and who want to identify data skews and anomalies. It assumes that you have some experience with Google Cloud, with BigQuery, and with Jupyter notebooks.

The series consists of the following guides:

The code for the process described in this document is incorporated into Jupyter notebooks. The notebooks are in a GitHub repository.

Understanding data drift and concept drift

Monitoring the predictive performance of an ML model in production has emerged as a crucial area of MLOps. Two common causes of decay in a model's predictive performance over time are the following:

  • A skew grows between training data and serving data. This issue is known as data drift.
  • The interpretation of the relationship between the input predictors and the target feature evolves. This issue is known as concept drift.

In data drift, the production data that a model receives for scoring has diverged from the dataset that was used to train, tune, and evaluate the model. The discrepancies between training data and serving data can usually be classified as schema skews or distribution skews:

  • Schema skew occurs when training data and serving data don't conform to the same schema. Schema skew is often caused by faults or changes in the upstream process that generates the serving data. Schema deviations between training and serving data can include the following:

    • Inconsistent features—for example, a new feature is added to the serving data.
    • Inconsistent feature types—for example, a numerical feature that was an integer number in training data is a real number in serving data.
    • Inconsistent feature domains—for example, a value in a categorical feature disappears, or there's a change in the range of numerical features.
  • Distribution skew occurs when the distribution of feature values for training data is significantly different from serving data. This skew can be the result of choosing the wrong training dataset to represent real-world data. This skew can also happen naturally as new trends and patterns emerge in the data due to the changes in the dynamics of the environment. Examples include changes in the prices of real estate or a change in the popularity of fashion items.
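
The three kinds of schema deviation listed above can be illustrated with a small, library-free sketch. It is not how TFDV implements schema validation; the toy schema format, the function names, and the sample rows are all assumptions made for illustration.

```python
from typing import Any, Dict, List


def infer_schema(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Infer a toy schema: for each feature, the Python type name of the
    first value seen and the set of all observed values."""
    schema: Dict[str, Dict[str, Any]] = {}
    for row in rows:
        for name, value in row.items():
            entry = schema.setdefault(
                name, {"type": type(value).__name__, "values": set()})
            entry["values"].add(value)
    return schema


def find_schema_skew(train_schema, serving_schema) -> List[str]:
    """Report features, types, and domains that differ between two schemas."""
    issues = []
    # Inconsistent features: present on one side only.
    for name in sorted(set(serving_schema) - set(train_schema)):
        issues.append(f"new feature: {name}")
    for name in sorted(set(train_schema) - set(serving_schema)):
        issues.append(f"missing feature: {name}")
    for name in sorted(set(train_schema) & set(serving_schema)):
        train, serving = train_schema[name], serving_schema[name]
        if train["type"] != serving["type"]:
            # Inconsistent feature types.
            issues.append(
                f"type changed: {name} ({train['type']} -> {serving['type']})")
        elif train["type"] == "str":
            # Inconsistent categorical domain: values never seen in training.
            unseen = sorted(serving["values"] - train["values"])
            if unseen:
                issues.append(f"unseen values: {name} {unseen}")
        else:
            # Inconsistent numerical domain: values outside the training range.
            low, high = min(train["values"]), max(train["values"])
            if min(serving["values"]) < low or max(serving["values"]) > high:
                issues.append(f"out of range: {name}")
    return issues


# Hypothetical sample rows, loosely modeled on Covertype feature names.
train_rows = [
    {"Elevation": 2500, "Wilderness_Area": "Rawah"},
    {"Elevation": 3100, "Wilderness_Area": "Cache"},
]
serving_rows = [
    {"Elevation": 2800.5, "Wilderness_Area": "Rawah", "Region": "north"},
    {"Elevation": 2900.0, "Wilderness_Area": "Neota", "Region": "south"},
]

issues = find_schema_skew(infer_schema(train_rows), infer_schema(serving_rows))
for issue in issues:
    print(issue)
```

In this sketch, the serving rows trigger one issue of each kind: an added feature, an integer feature that became a real number, and a categorical value that never appeared in training.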

Concept drift means that your interpretation of the data has changed. Often, concept drift implies that the mapping of input features to labels that are used during training is no longer valid, or that a novel class or a range of label values has appeared. Concept drift is often the result of a change in the process you're attempting to model. It can also be an evolution of your understanding of this process.

This document focuses on the data drift scenario—specifically on detecting schema skews and feature distribution skews between the training data and the serving data.

Architecture overview

The following diagram shows the environment that's described in this document.

Architecture for the flow that's created in this tutorial series.

In this architecture, AI Platform Prediction request-response logging logs a sample of online requests into a BigQuery table. You can then parse this data, compute descriptive statistics, and visualize data skew and data drift by using Vertex AI Workbench user-managed notebooks and TFDV.

Capturing serving data

An earlier document in this series, Logging serving requests by using AI Platform Prediction, shows how to use AI Platform Prediction to deploy a Keras model for online prediction. It also shows you how to enable request-response logging, which logs a sample of online prediction requests and responses to a BigQuery table.

The logged request instances and response predictions are stored in a raw form. The preceding document in this series shows how you can parse these raw logs and create a structured view to analyze input features and output predictions. That document also describes how to use Looker Studio to visualize how some feature values change over time. The visualization also helps you spot some outliers in the data that might cause prediction skews.

Detecting serving data skews

It can be difficult to manually identify the potential skews and anomalies in your data simply by looking at the visualization. It's especially difficult if you have a large number of input features. Therefore, you need a scalable and automated approach in production to proactively highlight potential issues in the serving data.

This document describes how to use TFDV to detect, analyze, and visualize data skews. TFDV analyzes the serving data logs against the expected schema and against the statistics generated from the training data in order to identify anomalies and to detect training-serving skews. As shown in the earlier architecture diagram, you use user-managed notebooks to run the TFDV tools interactively.

Using TFDV for detecting training-serving skew

The following diagram illustrates the workflow for how to use TFDV to detect and analyze skews in request-response serving logs in BigQuery.

Workflow for detecting training-serving skew.

The workflow consists of the following phases:

  1. Generating baseline statistics and a reference schema from the training data. You can then use the baseline statistics and reference schema to validate the serving data.
  2. Detecting serving data skew. This process produces serving data statistics and identifies any anomalies that are detected when validating the serving data against the reference schema.
  3. Analyzing and visualizing validation output. The anomalies produced are then visualized and analyzed, and the statistics are visualized to show distribution skews.

Objectives

  • Generate baseline statistics and a reference schema.
  • Detect serving data skew.
  • Analyze and visualize validation outputs.

Costs

In this document, you use the following billable components of Google Cloud:

To generate a cost estimate based on your projected usage, use the pricing calculator. New Google Cloud users might be eligible for a free trial.

Before you begin

Before you begin, you must complete part one and part two of this series.

After you complete these parts, you have the following:

  • A Vertex AI Workbench user-managed notebooks instance that uses TensorFlow 2.3.
  • A clone of the GitHub repository that has the Jupyter notebook that you need for this guide.
  • A BigQuery table that contains request-response logs and a view that parses the raw request and response data points.

The Jupyter notebook for this scenario

An end-to-end process for this workflow is coded into a Jupyter notebook on the GitHub repository that's associated with this document. The steps in the notebook are based on the Covertype dataset from UCI Machine Learning Repository, which is the same dataset that was used for sample data in the previous documents in this series.

Configuring notebook settings

In this section, you prepare the Python environment and set variables that you need in order to run the code for the scenario.

  1. If you don't already have the user-managed notebooks instance from part one open, do the following:

    1. In the Google Cloud console, go to the Notebooks page.

    2. On the User-managed notebooks tab, select the notebook, and then click Open Jupyterlab. The JupyterLab environment opens in your browser.

    3. In the file browser, open the mlops-on-gcp directory, and then navigate to the skew-detection directory.

  2. Open the 03-covertype-drift-detection_tfdv.ipynb notebook.

  3. In the notebook, under Setup, run the Install packages and dependencies cell to install the required Python packages and configure the environment variables.

  4. Under Configure Google Cloud environment settings, set the following variables:

    • PROJECT_ID: The ID of the Google Cloud project where the BigQuery dataset for the request-response data is logged.
    • BUCKET: The name of the Cloud Storage bucket where produced artifacts are stored.
    • BQ_DATASET_NAME: The name of the BigQuery dataset where the request-response logs are stored.
    • BQ_VIEW_NAME: The name of the BigQuery view that you created in part two of the series.
    • MODEL_NAME: The name of the model that's deployed to AI Platform Prediction.
    • VERSION_NAME: The version name of the model that's deployed to AI Platform Prediction. The version is in the format vN; for example, v1.
  5. Run the remaining cells under Setup to finish configuring the environment:

    1. Authenticate your GCP account
    2. Import libraries
    3. Create a local workspace
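
As a rough illustration of how these settings fit together, the following sketch builds a query against the parsed-logs view. The project ID and dataset name are hypothetical placeholders, the helper function is not part of the notebook, and the filter assumes that the view exposes a time column that can be compared with timestamps.

```python
# Placeholder values; replace them with the settings for your environment.
PROJECT_ID = "my-project"  # hypothetical project ID
BQ_DATASET_NAME = "prediction_logs"  # hypothetical dataset name
BQ_VIEW_NAME = "vw_covertype_classifier_logs_v1"  # view read later in this guide


def build_serving_query(project_id, dataset, view, start_time, end_time):
    """Return a SQL string that selects logged rows in a time window.

    Assumes that the view has a `time` column of a timestamp-compatible type.
    """
    return (
        "SELECT *\n"
        f"FROM `{project_id}.{dataset}.{view}`\n"
        f"WHERE time BETWEEN TIMESTAMP('{start_time}') "
        f"AND TIMESTAMP('{end_time}')"
    )


query = build_serving_query(
    PROJECT_ID, BQ_DATASET_NAME, BQ_VIEW_NAME,
    "2020-05-01 00:00:00", "2020-05-02 00:00:00")
print(query)
```

String formatting keeps the sketch short. For values that come from users, query parameters are the safer choice.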

Generating baseline statistics and a reference schema

Run the tasks in the first section of the notebook to generate baseline statistics and a reference schema from the training data. Download the data as CSV files, and then use the tfdv.generate_statistics_from_csv method to compute the baseline statistics and store them in the baseline_stats variable. The code then uses the tfdv.infer_schema method to infer the reference schema of the training data and store it in the reference_schema variable.

You can modify the inferred schema, as shown in the following code snippet:

import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

# Compute the baseline statistics from the training data. The reference
# schema doesn't exist yet, so no schema is passed at this point.
baseline_stats = tfdv.generate_statistics_from_csv(
   data_location=TRAIN_DATA,
   stats_options=tfdv.StatsOptions(sample_count=10000))

# Infer the reference schema from the baseline statistics
reference_schema = tfdv.infer_schema(baseline_stats)

# Set Soil_Type to be categorical
tfdv.set_domain(reference_schema, 'Soil_Type', schema_pb2.IntDomain(
   name='Soil_Type', is_categorical=True))

# Set Cover_Type to be categorical
tfdv.set_domain(reference_schema, 'Cover_Type', schema_pb2.IntDomain(
   name='Cover_Type', is_categorical=True))

# Set max and min values for Elevation
tfdv.set_domain(reference_schema,
   'Elevation',
   schema_pb2.IntDomain(min=1000, max=5000))

# Allow no missing values
tfdv.get_feature(reference_schema,
   'Slope').presence.min_fraction = 1.0

# Set distribution skew detector for Wilderness_Area
tfdv.get_feature(reference_schema,
   'Wilderness_Area').skew_comparator.infinity_norm.threshold = 0.05

You display the generated reference schema by using the tfdv.display_schema method. When you do, you see a listing similar to the following:

Listing of the generated schema.
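
The infinity_norm threshold that the snippet sets for Wilderness_Area refers to the L-infinity distance between two categorical distributions, that is, the largest absolute difference in relative frequency for any single value. The following sketch shows that computation in plain Python. It is an illustration rather than TFDV's implementation, and the sample counts are made up.

```python
from collections import Counter
from typing import Dict, Iterable


def value_distribution(values: Iterable[str]) -> Dict[str, float]:
    """Return the relative frequency of each distinct value."""
    counts = Counter(values)
    total = sum(counts.values())
    return {value: count / total for value, count in counts.items()}


def l_infinity_distance(p: Dict[str, float], q: Dict[str, float]) -> float:
    """Return the largest absolute frequency difference over all values."""
    return max(abs(p.get(key, 0.0) - q.get(key, 0.0)) for key in set(p) | set(q))


# Hypothetical training and serving samples for a categorical feature.
train_values = ["Rawah"] * 50 + ["Commanche"] * 30 + ["Cache"] * 20
serving_values = ["Rawah"] * 40 + ["Commanche"] * 30 + ["Cache"] * 30

THRESHOLD = 0.05  # same value as the threshold in the schema snippet

distance = l_infinity_distance(
    value_distribution(train_values), value_distribution(serving_values))
skew_detected = distance > THRESHOLD
print(f"L-infinity distance: {distance:.3f}, skew detected: {skew_detected}")
```

Because the distance looks only at the single largest difference, a shift of 10 percentage points in one category is enough to exceed a 0.05 threshold, no matter how the other categories behave.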

Detecting serving data skew

In the next section of the notebook, you run tasks to detect serving data skews. The process consists of the following steps:

  1. The request-response serving data is read from the vw_covertype_classifier_logs_v1 view in BigQuery. The view presents parsed feature values and predictions that are logged in the covertype_classifier_logs table. The table contains raw instance requests and prediction responses from the covertype_classifier model that was deployed earlier to AI Platform Prediction.
  2. The logged serving data is saved to a CSV file for use by TFDV.
  3. TFDV computes serving statistics from the CSV data file for different slices of the serving data by using the tfdv.experimental_get_feature_value_slicer method, as shown in the following code snippet:

    # Slice the serving data on each distinct value of the time feature.
    slice_fn = tfdv.experimental_get_feature_value_slicer(features={'time': None})
    
    serving_stats_list = tfdv.generate_statistics_from_csv(
        data_location=serving_data_file,
        stats_options=tfdv.StatsOptions(
            slice_functions=[slice_fn],
            schema=reference_schema))
    
  4. TFDV validates each serving_stats slice against the reference schema (in the reference_schema variable) by using the tfdv.validate_statistics method. This process generates anomalies, as shown in the following code snippet:

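    # Assumption: each dataset in the statistics list is one slice, and its
    # name is the slice key.
    slice_keys = [dataset.name for dataset in serving_stats_list.datasets]
    anomalies_list = []
    
    # Skip the first key, which is assumed to be the slice for all examples.
    # Note: previous_statistics drives each feature's drift_comparator;
    # a skew_comparator is evaluated against serving_statistics.
    for slice_key in slice_keys[1:]:
      serving_stats = tfdv.get_slice_stats(serving_stats_list, slice_key)
      anomalies = tfdv.validate_statistics(
           serving_stats, schema=reference_schema,
           previous_statistics=baseline_stats)
      anomalies_list.append(anomalies)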
    

TFDV checks for anomalies by comparing a schema with the statistics protocol buffer. For information about the anomaly types that TFDV can detect, see TensorFlow Data Validation Anomalies Reference. That reference lists the schema and statistics fields that are used to detect each anomaly type and the conditions under which each anomaly type is detected.
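
To make the idea of slicing in step 3 more concrete, the following sketch groups rows by the value of one feature and computes a statistic for each group. It only mimics the concept in plain Python; the slice key format, function names, and sample rows are assumptions and don't reflect how TFDV represents slices internally.

```python
from collections import defaultdict
from statistics import mean
from typing import Any, Dict, List


def slice_by_feature(rows: List[Dict[str, Any]],
                     feature: str) -> Dict[str, List[Dict[str, Any]]]:
    """Group rows into slices, one slice for each distinct feature value."""
    slices = defaultdict(list)
    for row in rows:
        # Hypothetical key format: "<feature>_<value>".
        slices[f"{feature}_{row[feature]}"].append(row)
    return dict(slices)


def mean_per_slice(slices: Dict[str, List[Dict[str, Any]]],
                   feature: str) -> Dict[str, float]:
    """Compute the mean of a numerical feature within each slice."""
    return {key: mean(row[feature] for row in rows)
            for key, rows in slices.items()}


# Hypothetical logged rows with a time value and one numerical feature.
serving_rows = [
    {"time": "2020-05-01", "Elevation": 2500.0},
    {"time": "2020-05-01", "Elevation": 2700.0},
    {"time": "2020-05-02", "Elevation": 3000.0},
    {"time": "2020-05-02", "Elevation": 3400.0},
]

slices = slice_by_feature(serving_rows, "time")
elevation_means = mean_per_slice(slices, "Elevation")
print(elevation_means)
```

Validating each slice separately, as the notebook does, makes it possible to see when an anomaly first appears instead of only learning that it exists somewhere in the logs.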

Analyzing and visualizing validation outputs

In the third section of the notebook, you visualize anomalies. The code uses the tfdv.visualize_statistics method to visualize a serving_stats slice against the training data baseline stats in order to highlight distribution skews.

The following screenshot shows an example of a visualization for distributions for the Elevation and Aspect numerical features.

Graphs showing anomalies for elevation and aspect.

The following screenshots show examples of visualizations for the Wilderness_Area and Cover_Type categorical features. Notice that in the categorical features visualization, Cover_Type is the target feature, and that the distribution shows prediction skew.

Wilderness area visualization.

Cover type visualization.

To inspect the generated anomalies in each serving_stats slice, you can call the tfdv.display_anomalies method. The following listing shows an example of the detected anomalies.

Listing of anomalies found by the display_anomalies method.
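
If you want to work with the anomalies in code rather than only display them, one possible approach is to read the anomaly_info map of each anomalies object. The sketch below uses stand-in objects so that it runs without TFDV installed. The anomaly_info and short_description field names follow the anomalies protocol buffer, while the helper functions and the sample descriptions are hypothetical.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_anomalies(anomalies):
    """Return (feature name, short description) pairs, sorted by feature name."""
    return [
        (name, info.short_description)
        for name, info in sorted(
            anomalies.anomaly_info.items(), key=lambda item: item[0])
    ]


def count_anomalous_slices(anomalies_list):
    """Count, for each feature, how many slices reported an anomaly for it."""
    counts = Counter()
    for anomalies in anomalies_list:
        counts.update(anomalies.anomaly_info.keys())
    return counts


def fake_anomalies(**descriptions):
    """Build a stand-in that mimics the two fields used by the helpers."""
    return SimpleNamespace(anomaly_info={
        name: SimpleNamespace(short_description=text)
        for name, text in descriptions.items()
    })


# Hypothetical results for two time slices.
example_list = [
    fake_anomalies(Elevation="Out-of-range values"),
    fake_anomalies(Elevation="Out-of-range values", Slope="Column dropped"),
]

summary = summarize_anomalies(example_list[1])
slice_counts = count_anomalous_slices(example_list)
print(summary)
print(slice_counts)
```

In the notebook, you would pass the entries of anomalies_list to these helpers in place of the stand-ins.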

In addition to creating visualizations by using the TFDV visualization API, you can use a Python plotting library to visualize the statistics in your notebook.

The following plots show how the mean values of the numerical features in the serving statistics drift across time slices, and how the drift compares to the mean values in the baseline statistics:

Graph of mean values for elevation.

Graph of mean values for aspect.

Graph of mean values for slope.
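
The comparison that these plots make can also be expressed numerically. The following sketch computes how far the mean of each time slice deviates from the baseline mean, and lists the slices that exceed a tolerance. The mean values and the tolerance are invented for illustration, and the helpers are not part of TFDV or the notebook.

```python
from typing import Dict, List


def relative_mean_drift(baseline_mean: float,
                        slice_means: Dict[str, float]) -> Dict[str, float]:
    """Return (slice mean - baseline mean) / baseline mean for each slice."""
    if baseline_mean == 0:
        raise ValueError("Relative drift is undefined for a baseline mean of 0.")
    return {key: (value - baseline_mean) / baseline_mean
            for key, value in slice_means.items()}


def slices_over_tolerance(drift: Dict[str, float],
                          tolerance: float) -> List[str]:
    """Return the slice keys whose absolute relative drift exceeds tolerance."""
    return sorted(key for key, value in drift.items() if abs(value) > tolerance)


# Hypothetical values: the baseline mean of Elevation and its daily means.
BASELINE_ELEVATION_MEAN = 2950.0
elevation_slice_means = {
    "2020-05-01": 2950.0,
    "2020-05-02": 3097.5,
    "2020-05-03": 3245.0,
}
TOLERANCE = 0.08  # hypothetical tolerance of 8%

drift = relative_mean_drift(BASELINE_ELEVATION_MEAN, elevation_slice_means)
flagged = slices_over_tolerance(drift, TOLERANCE)
print(drift)
print(flagged)
```

A check on the mean is coarse, because two very different distributions can share the same mean. It works best as a complement to the distribution visualizations, not as a replacement for them.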

The following plot shows how the value distribution of the categorical feature in the serving statistics drifts across time slices:

Bar graphs showing value distribution of features across time.

Clean up

If you plan to continue with the rest of this series, keep the resources that you've already created. Otherwise, delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the project

  1. In the Google Cloud console, go to the Manage resources page.

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

What's next