使用超参数调节来提高模型性能


本教程将介绍如何在 BigQuery ML 中使用超参数调节来调整机器学习模型并提升其性能。

您可以通过指定 CREATE MODEL 语句的 NUM_TRIALS 选项以及其他特定于模型的选项来执行超参数调节。设置这些选项后,BigQuery ML 会训练多个版本(即实验)的模型,每个版本的参数略有不同,并返回效果最佳的实验。

本教程使用公共 tlc_yellow_trips_2018 示例表,其中包含 2018 年纽约市出租车行程的相关信息。

目标

本教程将指导您完成以下任务:

  • 使用 CREATE MODEL 语句创建基准线性回归模型。
  • 使用 ML.EVALUATE 函数评估基准模型。
  • CREATE MODEL 语句与超参数调节选项结合使用,训练 20 次线性回归模型试验。
  • 使用 ML.TRIAL_INFO 函数查看实验。
  • 使用 ML.EVALUATE 函数评估试验。
  • 使用 ML.PREDICT 函数从试验中的最佳模型获取出租车行程的预测结果。

费用

本教程使用 Google Cloud的以下计费组件:

  • BigQuery
  • BigQuery ML

如需详细了解 BigQuery 费用,请参阅 BigQuery 价格页面。

准备工作

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. 新项目会自动启用 BigQuery。如需在现有项目中激活 BigQuery,请前往

    Enable the BigQuery API.

    Enable the API

所需权限

  • 如需创建数据集,您需要拥有 bigquery.datasets.create IAM 权限。
  • 如需创建连接资源,您需要以下权限:

    • bigquery.connections.create
    • bigquery.connections.get
  • 如需创建模型,您需要以下权限:

    • bigquery.jobs.create
    • bigquery.models.create
    • bigquery.models.getData
    • bigquery.models.updateData
    • bigquery.connections.delegate
  • 如需运行推理,您需要以下权限:

    • bigquery.models.getData
    • bigquery.jobs.create

如需详细了解 BigQuery 中的 IAM 角色和权限,请参阅 IAM 简介

创建数据集

创建 BigQuery 数据集以存储您的机器学习模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery 页面

  2. 探索器窗格中,点击您的项目名称。

  3. 点击 查看操作 > 创建数据集

    创建数据集。

  4. 创建数据集页面上,执行以下操作:

    • 数据集 ID 部分,输入 bqml_tutorial

    • 位置类型部分,选择多区域,然后选择 US (multiple regions in United States)(美国[美国的多个区域])。

      公共数据集存储在 US 多区域中。为简单起见,请将数据集存储在同一位置。

    • 保持其余默认设置不变,然后点击创建数据集

      创建数据集页面。

创建训练数据表

根据 tlc_yellow_trips_2018 表数据的一部分创建训练数据表。

请按照以下步骤创建表格:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE TABLE `bqml_tutorial.taxi_tip_input`
    AS
    SELECT * EXCEPT (tip_amount), tip_amount AS label
    FROM
      `bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2018`
    WHERE
      tip_amount IS NOT NULL
    LIMIT 100000;

创建基准线性回归模型

创建一个不进行超参数调节的线性回归模型,并基于 taxi_tip_input 表数据对其进行训练。

请按照以下步骤创建模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE MODEL `bqml_tutorial.baseline_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG'
      )
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    查询大约需要 2 分钟才能完成。

评估基准模型

使用 ML.EVALUATE 函数评估模型的性能。ML.EVALUATE 函数根据模型训练期间计算得出的评估指标评估模型返回的内容评分预测结果。

请按照以下步骤评估模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.baseline_taxi_tip_model`);

    结果类似于以下内容:

    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score       | explained_variance  |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    |  2.5853895559690323 | 23760.416358496139 |   0.017392406523370374 | 0.0044248227819481123 | -1934.5450533482465 | -1934.3513857946277 |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    

基准模型的 r2_score 值为负,表示模型与数据的拟合度较差;R2 得分越接近 1,模型拟合度就越高。

创建具有超参数调节的线性回归模型

创建具有超参数调节的线性回归模型,并使用 taxi_tip_input 表数据对其进行训练。

您可以在 CREATE MODEL 语句中使用以下超参数调节选项:

  • NUM_TRIALS 选项,用于将试验次数设置为 20。
  • MAX_PARALLEL_TRIALS 选项,用于在每个训练作业中运行两次试验,总共十个作业和二十次试验。这样可以缩短训练时间。不过,两个并发试验无法从彼此的训练结果中受益。
  • L1_REG 选项,用于在不同的试验中尝试不同的 L1 正则化值。L1 正则化会从模型中移除不相关的特征,有助于防止出现过拟合

模型支持的其他超参数调节选项使用其默认值,如下所示:

  • L1_REG0
  • HPARAM_TUNING_ALGORITHM'VIZIER_DEFAULT'
  • HPARAM_TUNING_OBJECTIVES['R2_SCORE']

请按照以下步骤创建模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE MODEL `bqml_tutorial.hp_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG',
        NUM_TRIALS = 20,
        MAX_PARALLEL_TRIALS = 2,
        L1_REG = HPARAM_RANGE(0, 5))
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    查询大约需要 20 分钟才能完成。

获取有关训练试用版的信息

使用 ML.TRIAL_INFO 函数获取有关所有试验的信息,包括其超参数值、目标和状态。此函数还会根据这些信息返回有关哪个实验效果最佳的信息。

如需获取试用信息,请按以下步骤操作:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.TRIAL_INFO(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY is_optimal DESC;

    结果类似于以下内容:

    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    | trial_id |           hyperparameters           | hparam_tuning_evaluation_metrics  |   training_loss    |     eval_loss      |  status   | error_message | is_optimal |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    |        7 |      {"l1_reg":"4.999999999999985"} |  {"r2_score":"0.653653627638174"} | 4.4677841296238165 |  4.478469742512195 | SUCCEEDED | NULL          |       true |
    |        2 |  {"l1_reg":"2.402163664510254E-11"} | {"r2_score":"0.6532493667964732"} |  4.457692508421795 |  4.483697081650438 | SUCCEEDED | NULL          |      false |
    |        3 |  {"l1_reg":"1.2929452948742316E-7"} |  {"r2_score":"0.653249366811995"} |   4.45769250849513 |  4.483697081449748 | SUCCEEDED | NULL          |      false |
    |        4 |  {"l1_reg":"2.5787102060628228E-5"} | {"r2_score":"0.6532493698925899"} |  4.457692523040582 |  4.483697041615808 | SUCCEEDED | NULL          |      false |
    |      ... |                             ...     |                           ...     |              ...   |             ...    |       ... |          ...  |        ... |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    

    is_optimal 列值表示试验 7 是调优返回的最优模型。

评估调优后的模型试用

使用 ML.EVALUATE 函数评估实验的效果。ML.EVALUATE 函数根据所有试验在训练期间计算得出的评估指标,评估模型返回的内容评分预测结果。

请按照以下步骤评估模型试用:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY r2_score DESC;

    结果类似于以下内容:

    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    | trial_id | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score      | explained_variance |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    |        7 |   1.151814398002232 |  4.109811493266523 |     0.4918733252641176 |    0.5736103414025084 | 0.6652110305659145 | 0.6652144696114834 |
    |       19 |  1.1518143358927102 |  4.109811921460791 |     0.4918672150119582 |    0.5736106106914161 | 0.6652109956848206 | 0.6652144346901685 |
    |        8 |   1.152747850702547 |  4.123625876152422 |     0.4897808307399327 |    0.5731702310239184 | 0.6640856984144734 |  0.664088410199906 |
    |        5 |   1.152895108945439 |  4.125775524878872 |    0.48939088205957937 |    0.5723300569616766 | 0.6639105860807425 | 0.6639132416838652 |
    |      ... |                ...  |                ... |                    ... |                   ... |                ... |                ... |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    

    最优模型(即第 7 次试验)的 r2_score 值为 0.66521103056591446,比基准模型有了显著改进。

您可以通过在 ML.EVALUATE 函数中指定 TRIAL_ID 参数来评估特定试验。

如需详细了解 ML.TRIAL_INFO 目标与 ML.EVALUATE 评估指标之间的区别,请参阅模型部署函数

使用调整后的模型预测出租车小费

使用调优返回的最佳模型来预测不同出租车行程的小费。除非您通过指定 TRIAL_ID 参数选择其他试验,否则 ML.PREDICT 函数会自动使用最佳模型。预测结果会返回在 predicted_label 列中。

请按照以下步骤获取预测结果:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.PREDICT(
        MODEL `bqml_tutorial.hp_taxi_tip_model`,
        (
          SELECT
            *
          FROM
            `bqml_tutorial.taxi_tip_input`
          LIMIT 5
        ));

    结果类似于以下内容:

    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    | trial_id |  predicted_label   | vendor_id |   pickup_datetime   |  dropoff_datetime   | passenger_count | trip_distance | rate_code | store_and_fwd_flag | payment_type | fare_amount | extra | mta_tax | tolls_amount | imp_surcharge | total_amount | pickup_location_id | dropoff_location_id | data_file_year | data_file_month | label |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    |        7 |  1.343367839584448 | 2         | 2018-01-15 18:55:15 | 2018-01-15 18:56:18 |               1 |             0 | 1         | N                  | 1            |           0 |     0 |       0 |            0 |             0 |            0 | 193                | 193                 |           2018 |               1 |     0 |
    |        7 | -1.176072791783461 | 1         | 2018-01-08 10:26:24 | 2018-01-08 10:26:37 |               1 |             0 | 5         | N                  | 3            |        0.01 |     0 |       0 |            0 |           0.3 |         0.31 | 158                | 158                 |           2018 |               1 |     0 |
    |        7 |  3.839580104168765 | 1         | 2018-01-22 10:58:02 | 2018-01-22 12:01:11 |               1 |          16.1 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 140                | 91                  |           2018 |               1 |     0 |
    |        7 |  4.677393985230036 | 1         | 2018-01-16 10:14:35 | 2018-01-16 11:07:28 |               1 |            18 | 1         | N                  | 2            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 138                | 67                  |           2018 |               1 |     0 |
    |        7 |  7.938988937253062 | 2         | 2018-01-16 07:05:15 | 2018-01-16 08:06:31 |               1 |          17.8 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |        66.36 | 132                | 255                 |           2018 |               1 | 11.06 |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  • 删除您在教程中创建的项目。
  • 或者,保留项目但删除数据集。

删除数据集

删除项目也将删除项目中的所有数据集和所有表。如果您希望重复使用该项目,则可以删除在本教程中创建的数据集:

  1. 如有必要,请在Google Cloud 控制台中打开 BigQuery 页面。

    转到 BigQuery 页面

  2. 在导航面板中,点击您创建的 bqml_tutorial 数据集。

  3. 点击窗口右侧的删除数据集。此操作会删除相关数据集、表和所有数据。

  4. 删除数据集对话框中,通过输入数据集的名称 (bqml_tutorial) 来确认该删除命令,然后点击删除

删除项目

要删除项目,请执行以下操作:

  1. In the Google Cloud console, go to the Manage resources page.

    Go to Manage resources

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

后续步骤