使用入门:使用 Keras 进行训练和预测

Colab 徽标 在 Colab 中以笔记本的形式运行本教程 GitHub 徽标在 GitHub 上查看笔记本

本教程介绍如何使用 Keras Sequential API 在 AI Platform 上训练神经网络,以及如何通过该模型执行预测。

Keras 是用于构建和训练深度学习模型的高级 API。tf.keras 是 TensorFlow 对此 API 的实现。

本教程的前两部分介绍如何使用预先编写的 Keras 代码在 AI Platform 上训练模型、将经过训练的模型部署到 AI Platform,以及通过部署的模型执行在线预测。

本教程的最后一部分深入介绍了用于此模型的训练代码,并确保这些代码与 AI Platform 兼容。如需详细了解如何在 Keras 中以更常规的方式构建机器学习模型,请阅读 TensorFlow 的 Keras 教程

数据集

本教程使用加利福尼亚大学欧文分校机器学习存储库提供的美国人口普查收入数据集。该数据集包含 1994 年人口普查数据库中人员的相关信息,包括年龄、教育程度、婚姻状况、职业以及年收入是否超过 5 万美元。

目标

目标是使用 Keras 训练深度神经网络 (DNN),以根据某人的其他人口普查信息(特征)预测其年收入是否超过 $50000(目标标签)。

本教程重点介绍如何将此模型与 AI Platform 结合使用,而不是仅仅介绍模型本身的设计。但在构建机器学习系统时,请务必始终考虑到潜在问题和意外后果。请参阅有关公平性的机器学习速成课程练习,详细了解人口普查数据集中的偏差来源及机器学习公平性。

费用

本教程使用 Google Cloud (Google Cloud)的以下收费组件:

  • AI Platform Training
  • AI Platform Prediction
  • Cloud Storage

了解 AI Platform Training 价格AI Platform Prediction 价格Cloud Storage 价格,并使用价格计算器根据您的预计使用情况来估算费用。

准备工作

在 AI Platform 中训练和部署模型之前,您必须先完成以下事项:

  • 设置本地开发环境。
  • 设置 Google Cloud 项目,并启用结算功能和必要的 API。
  • 创建 Cloud Storage 存储分区以存储您的训练软件包和经过训练的模型。

设置本地开发环境

您需要以下资源才能完成本教程:

  • Git
  • Python 3
  • virtualenv
  • Google Cloud SDK

有关设置 Python 开发环境的 Google Cloud 指南详细说明了如何满足这些要求。以下步骤提供了一系列简要的说明:

  1. 安装 Python 3

  2. 安装 virtualenv 并创建一个使用 Python 3 的虚拟环境。

  3. 激活该环境。

  4. 完成以下部分中的步骤以安装 Google Cloud SDK。

设置您的 Google Cloud 项目

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. Enable the AI Platform Training & Prediction and Compute Engine APIs.

    Enable the APIs

  5. Install the Google Cloud CLI.
  6. To initialize the gcloud CLI, run the following command:

    gcloud init
  7. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  8. Make sure that billing is enabled for your Google Cloud project.

  9. Enable the AI Platform Training & Prediction and Compute Engine APIs.

    Enable the APIs

  10. Install the Google Cloud CLI.
  11. To initialize the gcloud CLI, run the following command:

    gcloud init

验证您的 GCP 账号

要设置身份验证,您需要创建服务账号密钥并为服务账号密钥的文件路径设置环境变量。

  1. 创建服务账号:

    1. 在 Google Cloud 控制台中,前往创建服务账号页面。

      转到“创建服务账号”

    2. 服务账号名称字段中,输入一个名称。
    3. 可选:在服务账号说明字段中,输入说明。
    4. 点击创建
    5. 点击选择角色字段。在所有角色下,选择 AI Platform > AI Platform Admin
    6. 点击添加其他角色
    7. 点击选择角色字段。在所有角色下,选择存储 > Storage Object Admin

    8. 点击完成以创建服务账号。

      不要关闭浏览器窗口。您将在下一步骤中用到它。

  2. 创建用于身份验证的服务账号密钥:

    1. 在 Google Cloud 控制台中,点击您创建的服务账号的电子邮件地址。
    2. 点击密钥
    3. 依次点击添加密钥创建新密钥
    4. 点击创建。JSON 密钥文件将下载到您的计算机上。
    5. 点击关闭
  3. 将环境变量 GOOGLE_APPLICATION_CREDENTIALS 设置为包含服务账号密钥的 JSON 文件的文件路径。此变量仅适用于当前的 shell 会话,因此,如果您打开新的会话,请重新设置该变量。

创建 Cloud Storage 存储分区

如果您使用 Cloud SDK 提交训练作业,请将包含训练代码的 Python 软件包上传到 Cloud Storage 存储分区。AI Platform 运行此软件包中的代码。在本教程中,AI Platform 还会将您的作业生成的训练模型保存在同一个存储分区中。然后,您可以基于此输出创建 AI Platform 模型版本,以执行在线预测。

将 Cloud Storage 存储分区的名称设置为环境变量。它在所有 Cloud Storage 存储分区中必须是唯一的:

BUCKET_NAME="your-bucket-name"

选择一个可以使用 AI Platform TrainingAI Platform Prediction 的区域,然后另外创建一个环境变量。例如:

REGION="us-central1"

在您选择的区域中创建 Cloud Storage 存储分区,稍后使用同一区域进行训练和预测。如果存储分区尚不存在,请运行以下命令创建一个:

gsutil mb -l $REGION gs://$BUCKET_NAME

快速入门:在 AI Platform 中进行训练

本教程的这一部分将逐步介绍如何将训练作业提交到 AI Platform。此作业运行的示例代码使用 Keras 根据美国人口普查数据训练深度神经网络。此作业会在 Cloud Storage 存储桶中以 TensorFlow SavedModel 目录的形式输出经过训练的模型。

获取训练代码和依赖项

首先,下载训练代码并更改工作目录:

# Clone the repository of AI Platform samples
git clone --depth 1 https://github.com/GoogleCloudPlatform/cloudml-samples

# Set the working directory to the sample code directory
cd cloudml-samples/census/tf-keras

注意,训练代码设计为 trainer/ 子目录中的 Python 程序包:

# `ls` shows the working directory's contents. The `p` flag adds trailing
# slashes to subdirectory names. The `R` flag lists subdirectories recursively.
ls -pR
.:
README.md  requirements.txt  trainer/

./trainer:
__init__.py  model.py  task.py  util.py

接下来,安装在本地训练模型所需的 Python 依赖项:

pip install -r requirements.txt

当您在 AI Platform 中运行训练作业时,系统将根据您选择的运行时版本预安装依赖项。

在本地训练您的模型

在 AI Platform 上训练之前,请在本地训练作业以验证文件结构和封装是正确的。

对于复杂的或资源密集型作业,建议您根据数据集的一小部分样本在本地进行训练,以验证代码。然后,您可以在 AI Platform 上运行作业以根据整个数据集进行训练。

此样本根据小型数据集运行较为简单快捷的作业,因此本地训练和 AI Platform 作业根据相同的数据运行同样的代码。

运行以下命令以在本地训练模型:

# This is similar to `python -m trainer.task --job-dir local-training-output`
# but it better replicates the AI Platform environment, especially
# for distributed training (not applicable here).
gcloud ai-platform local train \
  --package-path trainer \
  --module-name trainer.task \
  --job-dir local-training-output

在 shell 中观察训练进度。最后,训练应用会导出经过训练的模型并输出如下消息:

Model exported to:  local-training-output/keras_export/1553709223

使用 AI Platform 训练您的模型

接下来,向 AI Platform 提交训练作业。这将在云中运行训练模块,并将经过训练的模型导出到 Cloud Storage。

首先,为您的训练作业命名,然后选择 Cloud Storage 存储分区中的目录以保存中间文件和输出文件。将它们设置为环境变量。例如:

JOB_NAME="my_first_keras_job"
JOB_DIR="gs://$BUCKET_NAME/keras-job-dir"

运行以下命令来封装 trainer/ 目录,将其上传到指定的 --job-dir,并指示 AI Platform 运行该软件包中的 trainer.task 模块。

使用 --stream-logs 标志,您可以在 shell 中查看训练日志。您还可以在 Google Cloud 控制台中查看日志和其他作业详细信息。

gcloud ai-platform jobs submit training $JOB_NAME \
  --package-path trainer/ \
  --module-name trainer.task \
  --region $REGION \
  --python-version 3.7 \
  --runtime-version 1.15 \
  --job-dir $JOB_DIR \
  --stream-logs

这可能比本地训练花费的时间更长,但您可以在 shell 中以类似的方式观察训练进度。最后,训练作业将经过训练的模型导出到您的 Cloud Storage 存储分区并输出如下消息:

INFO    2019-03-27 17:57:11 +0000   master-replica-0        Model exported to:  gs://your-bucket-name/keras-job-dir/keras_export/1553709421
INFO    2019-03-27 17:57:11 +0000   master-replica-0        Module completed; cleaning up.
INFO    2019-03-27 17:57:11 +0000   master-replica-0        Clean up finished.
INFO    2019-03-27 17:57:11 +0000   master-replica-0        Task completed successfully.

超参数调节

您可以使用随附的 hptuning_config.yaml 配置文件选择性地执行超参数调节。此文件指示 AI Platform 调节批次大小和学习速率,以便通过多次试验进行训练,从而最大限度地提高准确率。

在此示例中,训练代码使用 TensorBoard 回调,该回调在训练期间创建 TensorFlow Summary Event。AI Platform 使用这些事件跟踪要优化的指标。详细了解 AI Platform 训练中的超参数调节

gcloud ai-platform jobs submit training ${JOB_NAME}_hpt \
  --config hptuning_config.yaml \
  --package-path trainer/ \
  --module-name trainer.task \
  --region $REGION \
  --python-version 3.7 \
  --runtime-version 1.15 \
  --job-dir $JOB_DIR \
  --stream-logs

有关在 AI Platform 中进行在线预测的快速入门

本部分介绍如何使用 AI Platform 和上一部分中经过训练的模型来通过某人的其他人口普查信息预测其收入等级。

在 AI Platform 中创建模型和版本资源

如需使用您在训练快速入门中训练和导出的模型执行在线预测,请在 AI Platform 中创建模型资源和版本资源。版本资源是实际使用您训练的模型执行预测的资源。此结构允许您多次调整和重新训练模型,并在 AI Platform 中组织所有版本。详细了解模型和版本

首先,命名并创建模型资源:

MODEL_NAME="my_first_keras_model"

gcloud ai-platform models create $MODEL_NAME \
  --regions $REGION
Created ml engine model [projects/your-project-id/models/my_first_keras_model].

接下来,创建模型版本。训练快速入门中的训练作业将带有时间戳的 TensorFlow SavedModel 目录导出到您的 Cloud Storage 存储桶。AI Platform 会使用此目录创建模型版本。详细了解 SavedModel 和 AI Platform

您可以在训练作业的日志中找到此目录的路径。查找如下所示的行:

Model exported to:  gs://your-bucket-name/keras-job-dir/keras_export/1545439782

执行以下命令以标识 SavedModel 目录并使用它来创建模型版本资源:

MODEL_VERSION="v1"

# Get a list of directories in the `keras_export` parent directory. Then pick
# the directory with the latest timestamp, in case you've trained multiple
# times.
SAVED_MODEL_PATH=$(gsutil ls $JOB_DIR/keras_export | head -n 1)

# Create model version based on that SavedModel directory
gcloud ai-platform versions create $MODEL_VERSION \
  --model $MODEL_NAME \
  --region $REGION \
  --runtime-version 1.15 \
  --python-version 3.7 \
  --framework tensorflow \
  --origin $SAVED_MODEL_PATH

准备用于预测的输入数据

为了接收有效且有用的预测结果,您必须使用与预处理训练数据相同的方式预处理用于预测的输入数据。在生产系统中,建议您创建一个预处理流水线,这样您可以在训练和预测时使用同一流水线。

在本练习中,使用训练程序包的数据加载代码从评估数据中选择随机样本。这些数据的形式与在每个训练周期后评估准确率所用的形式相同,因此这些数据无需进一步的预处理即可用于发送测试预测。

从当前工作目录打开 Python 解释器 (python),以运行下面几个代码段:

from trainer import util

_, _, eval_x, eval_y = util.load_data()

prediction_input = eval_x.sample(20)
prediction_targets = eval_y[prediction_input.index]

prediction_input
age workclass education_num marital_status occupation relationship race capital_gain capital_loss hours_per_week native_country
1979 0.901213 1 1.525542 2 9 0 4 -0.144792 -0.217132 -0.437544 38
2430 -0.922154 3 -0.419265 4 2 3 4 -0.144792 -0.217132 -0.034039 38
4214 -1.213893 3 -0.030304 4 10 1 4 -0.144792 -0.217132 1.579979 38
10389 -0.630415 3 0.358658 4 0 3 4 -0.144792 -0.217132 -0.679647 38
14525 -1.505632 3 -1.586149 4 7 3 0 -0.144792 -0.217132 -0.034039 38
15040 -0.119873 5 0.358658 2 2 0 4 -0.144792 -0.217132 -0.841048 38
8409 0.244801 3 1.525542 2 9 0 4 -0.144792 -0.217132 1.176475 6
10628 0.098931 1 1.525542 2 9 0 4 0.886847 -0.217132 -0.034039 38
10942 0.390670 5 -0.030304 2 4 0 4 -0.144792 -0.217132 4.727315 38
5129 1.120017 3 1.136580 2 12 0 4 -0.144792 -0.217132 -0.034039 38
2096 -1.286827 3 -0.030304 4 11 3 4 -0.144792 -0.217132 -1.648058 38
12463 -0.703350 3 -0.419265 2 7 5 4 -0.144792 4.502280 -0.437544 38
8528 0.536539 3 1.525542 4 3 4 4 -0.144792 -0.217132 -0.034039 38
7093 -1.359762 3 -0.419265 4 6 3 2 -0.144792 -0.217132 -0.034039 38
12565 0.536539 3 1.136580 0 11 2 2 -0.144792 -0.217132 -0.034039 38
5655 1.338821 3 -0.419265 2 2 0 4 -0.144792 -0.217132 -0.034039 38
2322 0.682409 3 1.136580 0 12 3 4 -0.144792 -0.217132 -0.034039 38
12652 0.025997 3 1.136580 2 11 0 4 -0.144792 -0.217132 0.369465 38
4755 -0.411611 3 -0.419265 2 11 0 4 -0.144792 -0.217132 1.176475 38
4413 0.390670 6 1.136580 4 4 1 4 -0.144792 -0.217132 -0.034039 38

注意,分类字段(如 occupation)已转换为整数(使用的映射与训练所用的映射相同)。数值字段(如 age)已调整为 z-score。某些字段已从原始数据中删除。将同一示例的预测输入数据与原始数据进行比较:

import pandas as pd

_, eval_file_path = util.download(util.DATA_DIR)
raw_eval_data = pd.read_csv(eval_file_path,
                            names=util._CSV_COLUMNS,
                            na_values='?')

raw_eval_data.iloc[prediction_input.index]
age workclass fnlwgt education education_num marital_status occupation relationship race gender capital_gain capital_loss hours_per_week native_country income_bracket
1979 51 Local-gov 99064 Masters 14 Married-civ-spouse Prof-specialty Husband White Male 0 0 35 United-States <=50K
2430 26 Private 197967 HS-grad 9 Never-married Craft-repair Own-child White Male 0 0 40 United-States <=50K
4214 22 Private 221694 Some-college 10 Never-married Protective-serv Not-in-family White Male 0 0 60 United-States <=50K
10389 30 Private 96480 Assoc-voc 11 Never-married Adm-clerical Own-child White Female 0 0 32 United-States <=50K
14525 18 Private 146225 10th 6 Never-married Other-service Own-child Amer-Indian-Eskimo Female 0 0 40 United-States <=50K
15040 37 Self-emp-not-inc 50096 Assoc-voc 11 Married-civ-spouse Craft-repair Husband White Male 0 0 30 United-States <=50K
8409 42 Private 102988 Masters 14 Married-civ-spouse Prof-specialty Husband White Male 0 0 55 Ecuador >50K
10628 40 Local-gov 284086 Masters 14 Married-civ-spouse Prof-specialty Husband White Male 7688 0 40 United-States >50K
10942 44 Self-emp-not-inc 52505 Some-college 10 Married-civ-spouse Farming-fishing Husband White Male 0 0 99 United-States <=50K
5129 54 Private 106728 Bachelors 13 Married-civ-spouse Tech-support Husband White Male 0 0 40 United-States <=50K
2096 21 Private 190916 Some-college 10 Never-married Sales Own-child White Female 0 0 20 United-States <=50K
12463 29 Private 197565 HS-grad 9 Married-civ-spouse Other-service Wife White Female 0 1902 35 United-States >50K
8528 46 Private 193188 Masters 14 Never-married Exec-managerial Unmarried White Male 0 0 40 United-States <=50K
7093 20 Private 273147 HS-grad 9 Never-married Machine-op-inspct Own-child Black Male 0 0 40 United-States <=50K
12565 46 Private 203653 Bachelors 13 Divorced Sales Other-relative Black Male 0 0 40 United-States <=50K
5655 57 Private 174662 HS-grad 9 Married-civ-spouse Craft-repair Husband White Male 0 0 40 United-States <=50K
2322 48 Private 232149 Bachelors 13 Divorced Tech-support Own-child White Female 0 0 40 United-States <=50K
12652 39 Private 82521 Bachelors 13 Married-civ-spouse Sales Husband White Male 0 0 45 United-States >50K
4755 33 Private 330715 HS-grad 9 Married-civ-spouse Sales Husband White Male 0 0 55 United-States <=50K
4413 44 State-gov 128586 Bachelors 13 Never-married Farming-fishing Not-in-family White Male 0 0 40 United-States <=50K

将预测输入数据导出为以换行符分隔的 JSON 文件:

import json

with open('prediction_input.json', 'w') as json_file:
  for row in prediction_input.values.tolist():
    json.dump(row, json_file)
    json_file.write('\n')

退出 Python 解释器 (exit())。在 Shell 中检查 prediction_input.json

cat prediction_input.json
[0.9012127751273994, 1.0, 1.525541514460902, 2.0, 9.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.43754385253479555, 38.0]
[-0.9221541171760282, 3.0, -0.4192650914017433, 4.0, 2.0, 3.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[-1.2138928199445767, 3.0, -0.030303770229214273, 4.0, 10.0, 1.0, 4.0, -0.14479173735784842, -0.21713186390175285, 1.5799792247041626, 38.0]
[-0.6304154144074798, 3.0, 0.35865755094331475, 4.0, 0.0, 3.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.6796466218034705, 38.0]
[-1.5056315227131252, 3.0, -1.5861490549193304, 4.0, 7.0, 3.0, 0.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[-0.11987268456252011, 5.0, 0.35865755094331475, 2.0, 2.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.8410484679825871, 38.0]
[0.24480069389816542, 3.0, 1.525541514460902, 2.0, 9.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, 1.176474609256371, 6.0]
[0.0989313425138912, 1.0, 1.525541514460902, 2.0, 9.0, 0.0, 4.0, 0.8868473744801746, -0.21713186390175285, -0.03403923708700391, 38.0]
[0.39067004528243965, 5.0, -0.030303770229214273, 2.0, 4.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, 4.7273152251969375, 38.0]
[1.1200168022038106, 3.0, 1.1365801932883728, 2.0, 12.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[-1.2868274956367138, 3.0, -0.030303770229214273, 4.0, 11.0, 3.0, 4.0, -0.14479173735784842, -0.21713186390175285, -1.6480576988781703, 38.0]
[-0.7033500900996169, 3.0, -0.4192650914017433, 2.0, 7.0, 5.0, 4.0, -0.14479173735784842, 4.5022796885373735, -0.43754385253479555, 38.0]
[0.5365393966667138, 3.0, 1.525541514460902, 4.0, 3.0, 4.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[-1.3597621713288508, 3.0, -0.4192650914017433, 4.0, 6.0, 3.0, 2.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[0.5365393966667138, 3.0, 1.1365801932883728, 0.0, 11.0, 2.0, 2.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[1.338820829280222, 3.0, -0.4192650914017433, 2.0, 2.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[0.6824087480509881, 3.0, 1.1365801932883728, 0.0, 12.0, 3.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]
[0.0259966668217541, 3.0, 1.1365801932883728, 2.0, 11.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, 0.3694653783607877, 38.0]
[-0.4116113873310685, 3.0, -0.4192650914017433, 2.0, 11.0, 0.0, 4.0, -0.14479173735784842, -0.21713186390175285, 1.176474609256371, 38.0]
[0.39067004528243965, 6.0, 1.1365801932883728, 4.0, 4.0, 1.0, 4.0, -0.14479173735784842, -0.21713186390175285, -0.03403923708700391, 38.0]

gcloud 命令行工具接受将采用换行符分隔的 JSON 用于在线预测,这一特定的 Keras 模型需要每个输入示例的数字的扁平列表。

当您在不使用 gcloud 工具的情况下向 REST API 发出在线预测请求时,AI Platform 需要使用其他格式。构建模型的方式也可能会改变您设置数据格式以进行预测所必须采用的方式。 详细了解设置数据格式以进行在线预测

提交在线预测请求

使用 gcloud 提交您的在线预测请求:

gcloud ai-platform predict \
  --model $MODEL_NAME \
  --region $REGION \
  --version $MODEL_VERSION \
  --json-instances prediction_input.json
DENSE_4
[0.6854287385940552]
[0.011786997318267822]
[0.037236183881759644]
[0.016223609447479248]
[0.0012015104293823242]
[0.23621389269828796]
[0.6174039244651794]
[0.9822691679000854]
[0.3815768361091614]
[0.6715215444564819]
[0.001094043254852295]
[0.43077391386032104]
[0.22132840752601624]
[0.004075437784194946]
[0.22736871242523193]
[0.4111979305744171]
[0.27328649163246155]
[0.6981356143951416]
[0.3309604525566101]
[0.20807647705078125]

由于模型的最后一层使用 sigmoid 函数进行激活,因此 0 到 0.5 之间的输出表示负例预测(“<=50K”),0.5 到 1 之间的输出表示正例预测(“>50K”)。

从头开始开发 Keras 模型

现在您已在 AI Platform 上训练完机器学习模型,将经过训练的模型部署为 AI Platform 上的版本资源,并接收了部署的在线预测结果。下一部分将逐步介绍如何重新创建用于训练模型的 Keras 代码。 它涵盖了有关开发用于 AI Platform 的机器学习模型的以下内容:

  • 下载和预处理数据
  • 设计和训练模型
  • 直观呈现如何训练模型和导出经过训练的模型

虽然本部分深入介绍了前几部分中完成的任务,但要详细了解如何使用 tf.keras,请参阅 TensorFlow 的 Keras 指南。如需详细了解如何以 AI Platform 的训练软件包的形式设计代码结构,请参阅打包训练应用并参考完整训练代码,后者采用 Python 软件包结构。

导入库并定义常量

首先,导入训练所需的 Python 库:

import os
from six.moves import urllib
import tempfile

import numpy as np
import pandas as pd
import tensorflow as tf

# Examine software versions
print(__import__('sys').version)
print(tf.__version__)
print(tf.keras.__version__)

然后,定义一些有用的常量:

  • 有关下载训练和评估数据的信息
  • Pandas 解析数据并将分类字段转换为数值特征所需的信息
  • 用于训练的超参数(例如,学习速率和批次大小)
### For downloading data ###

# Storage directory
DATA_DIR = os.path.join(tempfile.gettempdir(), 'census_data')

# Download options.
DATA_URL = 'https://storage.googleapis.com/cloud-samples-data/ai-platform' \
           '/census/data'
TRAINING_FILE = 'adult.data.csv'
EVAL_FILE = 'adult.test.csv'
TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)

### For interpreting data ###

# These are the features in the dataset.
# Dataset information: https://archive.ics.uci.edu/ml/datasets/census+income
_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

_CATEGORICAL_TYPES = {
  'workclass': pd.api.types.CategoricalDtype(categories=[
    'Federal-gov', 'Local-gov', 'Never-worked', 'Private', 'Self-emp-inc',
    'Self-emp-not-inc', 'State-gov', 'Without-pay'
  ]),
  'marital_status': pd.api.types.CategoricalDtype(categories=[
    'Divorced', 'Married-AF-spouse', 'Married-civ-spouse',
    'Married-spouse-absent', 'Never-married', 'Separated', 'Widowed'
  ]),
  'occupation': pd.api.types.CategoricalDtype([
    'Adm-clerical', 'Armed-Forces', 'Craft-repair', 'Exec-managerial',
    'Farming-fishing', 'Handlers-cleaners', 'Machine-op-inspct',
    'Other-service', 'Priv-house-serv', 'Prof-specialty', 'Protective-serv',
    'Sales', 'Tech-support', 'Transport-moving'
  ]),
  'relationship': pd.api.types.CategoricalDtype(categories=[
    'Husband', 'Not-in-family', 'Other-relative', 'Own-child', 'Unmarried',
    'Wife'
  ]),
  'race': pd.api.types.CategoricalDtype(categories=[
    'Amer-Indian-Eskimo', 'Asian-Pac-Islander', 'Black', 'Other', 'White'
  ]),
  'native_country': pd.api.types.CategoricalDtype(categories=[
    'Cambodia', 'Canada', 'China', 'Columbia', 'Cuba', 'Dominican-Republic',
    'Ecuador', 'El-Salvador', 'England', 'France', 'Germany', 'Greece',
    'Guatemala', 'Haiti', 'Holand-Netherlands', 'Honduras', 'Hong', 'Hungary',
    'India', 'Iran', 'Ireland', 'Italy', 'Jamaica', 'Japan', 'Laos', 'Mexico',
    'Nicaragua', 'Outlying-US(Guam-USVI-etc)', 'Peru', 'Philippines', 'Poland',
    'Portugal', 'Puerto-Rico', 'Scotland', 'South', 'Taiwan', 'Thailand',
    'Trinadad&Tobago', 'United-States', 'Vietnam', 'Yugoslavia'
  ]),
  'income_bracket': pd.api.types.CategoricalDtype(categories=[
    '<=50K', '>50K'
  ])
}

# This is the label (target) we want to predict.
_LABEL_COLUMN = 'income_bracket'

### Hyperparameters for training ###

# This the training batch size
BATCH_SIZE = 128

# This is the number of epochs (passes over the full training data)
NUM_EPOCHS = 20

# Define learning rate.
LEARNING_RATE = .01

下载和预处理数据

下载数据

接下来,定义用于下载训练和评估数据的函数。这些函数还可以修复数据格式中轻微的不规则问题。

def _download_and_clean_file(filename, url):
  """Downloads data from url, and makes changes to match the CSV format.

  The CSVs may use spaces after the comma delimters (non-standard) or include
  rows which do not represent well-formed examples. This function strips out
  some of these problems.

  Args:
    filename: filename to save url to
    url: URL of resource to download
  """
  temp_file, _ = urllib.request.urlretrieve(url)
  with tf.gfile.Open(temp_file, 'r') as temp_file_object:
    with tf.gfile.Open(filename, 'w') as file_object:
      for line in temp_file_object:
        line = line.strip()
        line = line.replace(', ', ',')
        if not line or ',' not in line:
          continue
        if line[-1] == '.':
          line = line[:-1]
        line += '\n'
        file_object.write(line)
  tf.gfile.Remove(temp_file)


def download(data_dir):
  """Downloads census data if it is not already present.

  Args:
    data_dir: directory where we will access/save the census data
  """
  tf.gfile.MakeDirs(data_dir)

  training_file_path = os.path.join(data_dir, TRAINING_FILE)
  if not tf.gfile.Exists(training_file_path):
    _download_and_clean_file(training_file_path, TRAINING_URL)

  eval_file_path = os.path.join(data_dir, EVAL_FILE)
  if not tf.gfile.Exists(eval_file_path):
    _download_and_clean_file(eval_file_path, EVAL_URL)

  return training_file_path, eval_file_path

使用这些函数下载用于训练的数据,并验证您拥有用于训练和评估的 CSV 文件:

training_file_path, eval_file_path = download(DATA_DIR)

接下来,使用 Pandas 加载这些文件并检查数据:

# This census data uses the value '?' for fields (column) that are missing data.
# We use na_values to find ? and set it to NaN values.
# https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html

train_df = pd.read_csv(training_file_path, names=_CSV_COLUMNS, na_values='?')
eval_df = pd.read_csv(eval_file_path, names=_CSV_COLUMNS, na_values='?')

下表显示了进行预处理之前的数据摘录 (train_df.head()):

age workclass fnlwgt education education_num marital_status occupation relationship race gender capital_gain capital_loss hours_per_week native_country income_bracket
0 39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
1 50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
2 38 Private 215646 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
3 53 Private 234721 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
4 28 Private 338409 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K

预处理数据

第一个预处理步骤从数据中移除某些特征,并将分类特征转换为数值,以便用于 Keras。

详细了解特征工程数据偏差

UNUSED_COLUMNS = ['fnlwgt', 'education', 'gender']


def preprocess(dataframe):
  """Converts categorical features to numeric. Removes unused columns.

  Args:
    dataframe: Pandas dataframe with raw data

  Returns:
    Dataframe with preprocessed data
  """
  dataframe = dataframe.drop(columns=UNUSED_COLUMNS)

  # Convert integer valued (numeric) columns to floating point
  numeric_columns = dataframe.select_dtypes(['int64']).columns
  dataframe[numeric_columns] = dataframe[numeric_columns].astype('float32')

  # Convert categorical columns to numeric
  cat_columns = dataframe.select_dtypes(['object']).columns
  dataframe[cat_columns] = dataframe[cat_columns].apply(lambda x: x.astype(
    _CATEGORICAL_TYPES[x.name]))
  dataframe[cat_columns] = dataframe[cat_columns].apply(lambda x: x.cat.codes)
  return dataframe

prepped_train_df = preprocess(train_df)
prepped_eval_df = preprocess(eval_df)

下表 (prepped_train_df.head()) 显示如何处理已更改的数据。请特别注意,您训练模型以预测的标签 income_bracket 已从 <=50K>50K 更改为 01

age workclass education_num marital_status occupation relationship race capital_gain capital_loss hours_per_week native_country income_bracket
0 39.0 6 13.0 4 0 1 4 2174.0 0.0 40.0 38 0
1 50.0 5 13.0 2 3 0 4 0.0 0.0 13.0 38 0
2 38.0 3 9.0 0 5 1 4 0.0 0.0 40.0 38 0
3 53.0 3 7.0 2 5 0 2 0.0 0.0 40.0 38 0
4 28.0 3 13.0 2 9 5 2 0.0 0.0 40.0 4 0

接下来,将数据分成特征(“x”)和标签(“y”),并将标签数组转换为稍后可用于 tf.data.Dataset 的格式。

# Split train and test data with labels.
# The pop() method will extract (copy) and remove the label column from the dataframe
train_x, train_y = prepped_train_df, prepped_train_df.pop(_LABEL_COLUMN)
eval_x, eval_y = prepped_eval_df, prepped_eval_df.pop(_LABEL_COLUMN)

# Reshape label columns for use with tf.data.Dataset
train_y = np.asarray(train_y).astype('float32').reshape((-1, 1))
eval_y = np.asarray(eval_y).astype('float32').reshape((-1, 1))

缩放训练数据,使每个数值特征列的平均值为 0,标准差为 1,以便改善您的模型

在生产系统中,建议您保存训练集的平均值和标准差,在预测时使用它们对测试数据执行相同的转换。为方便起见,在本练习中请暂时将训练和评估数据合并在一起进行缩放:

def standardize(dataframe):
  """Scales numerical columns using their means and standard deviation to get
  z-scores: the mean of each numerical column becomes 0, and the standard
  deviation becomes 1. This can help the model converge during training.

  Args:
    dataframe: Pandas dataframe

  Returns:
    Input dataframe with the numerical columns scaled to z-scores
  """
  dtypes = list(zip(dataframe.dtypes.index, map(str, dataframe.dtypes)))
  # Normalize numeric columns.
  for column, dtype in dtypes:
      if dtype == 'float32':
          dataframe[column] -= dataframe[column].mean()
          dataframe[column] /= dataframe[column].std()
  return dataframe


# Join train_x and eval_x to normalize on overall means and standard
# deviations. Then separate them again.
all_x = pd.concat([train_x, eval_x], keys=['train', 'eval'])
all_x = standardize(all_x)
train_x, eval_x = all_x.xs('train'), all_x.xs('eval')

完全预处理后的数据如下表 (train_x.head()) 中所示:

age workclass education_num marital_status occupation relationship race capital_gain capital_loss hours_per_week native_country
0 0.025997 6 1.136580 4 0 1 4 0.146933 -0.217132 -0.034039 38
1 0.828278 5 1.136580 2 3 0 4 -0.144792 -0.217132 -2.212964 38
2 -0.046938 3 -0.419265 0 5 1 4 -0.144792 -0.217132 -0.034039 38
3 1.047082 3 -1.197188 2 5 0 2 -0.144792 -0.217132 -0.034039 38
4 -0.776285 3 1.136580 2 9 5 2 -0.144792 -0.217132 -0.034039 4

设计和训练模型

创建训练数据集和验证数据集

创建输入函数,将特征和标签转换为用于训练或评估的 tf.data.Dataset

def input_fn(features, labels, shuffle, num_epochs, batch_size):
  """Generates an input function to be used for model training.

  Args:
    features: numpy array of features used for training or inference
    labels: numpy array of labels for each example
    shuffle: boolean for whether to shuffle the data or not (set True for
      training, False for evaluation)
    num_epochs: number of epochs to provide the data for
    batch_size: batch size for training

  Returns:
    A tf.data.Dataset that can provide data to the Keras model for training or
      evaluation
  """
  if labels is None:
    inputs = features
  else:
    inputs = (features, labels)
  dataset = tf.data.Dataset.from_tensor_slices(inputs)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(features))

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset

接下来,创建这些训练和评估数据集。使用先前定义的 NUM_EPOCHSBATCH_SIZE 超参数来定义训练数据集在训练期间如何为模型提供示例。设置验证数据集在一个批次中提供其所有示例,以便在每个训练周期结束时进行一个验证步骤。

# Pass a numpy array by using DataFrame.values
training_dataset = input_fn(features=train_x.values,
                    labels=train_y,
                    shuffle=True,
                    num_epochs=NUM_EPOCHS,
                    batch_size=BATCH_SIZE)

num_eval_examples = eval_x.shape[0]

# Pass a numpy array by using DataFrame.values
validation_dataset = input_fn(features=eval_x.values,
                    labels=eval_y,
                    shuffle=False,
                    num_epochs=NUM_EPOCHS,
                    batch_size=num_eval_examples)

设计 Keras 模型

使用 Keras Sequential API 设计您的神经网络。

这种深度神经网络 (DNN) 具有多个隐藏层,最后一层使用 sigmoid 激活函数输出一个 0 到 1 之间的值:

  • 输入层有 100 个单元使用 ReLU 激活函数。
  • 隐藏层有 75 个单元使用 ReLU 激活函数。
  • 隐藏层有 50 个单元使用 ReLU 激活函数。
  • 隐藏层有 25 个单元使用 ReLU 激活函数。
  • 输出层有 1 个单元使用 sigmoid 激活函数。
  • 优化器使用二进制交叉熵损失函数,该函数适用于与此类似的二元分类问题。

您可以随意更改这些层,以尝试改进模型:

def create_keras_model(input_dim, learning_rate):
  """Creates Keras Model for Binary Classification.

  Args:
    input_dim: How many features the input has
    learning_rate: Learning rate for training

  Returns:
    The compiled Keras model (still needs to be trained)
  """
  Dense = tf.keras.layers.Dense
  model = tf.keras.Sequential(
    [
        Dense(100, activation=tf.nn.relu, kernel_initializer='uniform',
                input_shape=(input_dim,)),
        Dense(75, activation=tf.nn.relu),
        Dense(50, activation=tf.nn.relu),
        Dense(25, activation=tf.nn.relu),
        Dense(1, activation=tf.nn.sigmoid)
    ])

  # Custom Optimizer:
  # https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
  optimizer = tf.keras.optimizers.RMSprop(
      lr=learning_rate)

  # Compile Keras model
  model.compile(
      loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
  return model

接下来,创建 Keras 模型对象:

num_train_examples, input_dim = train_x.shape
print('Number of features: {}'.format(input_dim))
print('Number of examples: {}'.format(num_train_examples))

keras_model = create_keras_model(
    input_dim=input_dim,
    learning_rate=LEARNING_RATE)

使用 keras_model.summary() 检查模型,它应该返回如下内容:

Number of features: 11
Number of examples: 32561
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 100)               1200      
_________________________________________________________________
dense_1 (Dense)              (None, 75)                7575      
_________________________________________________________________
dense_2 (Dense)              (None, 50)                3800      
_________________________________________________________________
dense_3 (Dense)              (None, 25)                1275      
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 26        
=================================================================
Total params: 13,876
Trainable params: 13,876
Non-trainable params: 0
_________________________________________________________________

训练和评估模型

定义学习速率衰减,帮助模型参数在训练进行过程中做出较小的更改:

# Setup Learning Rate decay.
lr_decay_cb = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE + 0.02 * (0.5 ** (1 + epoch)),
    verbose=True)

# Setup TensorBoard callback.
JOB_DIR = os.getenv('JOB_DIR')
tensorboard_cb = tf.keras.callbacks.TensorBoard(
      os.path.join(JOB_DIR, 'keras_tensorboard'),
      histogram_freq=1)

最后,进入训练模型阶段。为模型提供适当的 steps_per_epoch,使模型在每个周期内根据整个训练数据集进行训练(每个步骤都提供有 BATCH_SIZE 示例)。指示模型在每个周期结束时使用一个大的验证批次来计算验证准确率。

history = keras_model.fit(training_dataset,
                          epochs=NUM_EPOCHS,
                          steps_per_epoch=int(num_train_examples/BATCH_SIZE),
                          validation_data=validation_dataset,
                          validation_steps=1,
                          callbacks=[lr_decay_cb, tensorboard_cb],
                          verbose=1)

训练进度可能如下所示:

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.

Epoch 00001: LearningRateScheduler reducing learning rate to 0.02.
Epoch 1/20
254/254 [==============================] - 1s 5ms/step - loss: 0.6986 - acc: 0.7893 - val_loss: 0.3894 - val_acc: 0.8329

Epoch 00002: LearningRateScheduler reducing learning rate to 0.015.
Epoch 2/20
254/254 [==============================] - 1s 4ms/step - loss: 0.3574 - acc: 0.8335 - val_loss: 0.3861 - val_acc: 0.8131

...

Epoch 00019: LearningRateScheduler reducing learning rate to 0.010000038146972657.
Epoch 19/20
254/254 [==============================] - 1s 4ms/step - loss: 0.3239 - acc: 0.8512 - val_loss: 0.3334 - val_acc: 0.8496

Epoch 00020: LearningRateScheduler reducing learning rate to 0.010000019073486329.
Epoch 20/20
254/254 [==============================] - 1s 4ms/step - loss: 0.3279 - acc: 0.8504 - val_loss: 0.3174 - val_acc: 0.8523

直观呈现训练过程和导出经过训练的模型

直观呈现训练过程

导入 matplotlib 以直观呈现模型在训练期间如何进行学习。(如有必要,请先运行 pip install matplotlib 安装该库。)

from matplotlib import pyplot as plt

根据每个训练周期结束时的测量结果,绘制模型的损失(二元交叉熵)和准确率:

# Visualize History for Loss.
plt.title('Keras model loss')
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['training', 'validation'], loc='upper right')
plt.show()

# Visualize History for Accuracy.
plt.title('Keras model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.legend(['training', 'validation'], loc='lower right')
plt.show()

随着时间的推移,损失降低,准确率升高。但它们会收敛到一个稳定的水平吗?训练和验证指标之间是否存在很大的差异(过拟合迹象)?

了解如何改进机器学习模型。 然后,随意调整超参数或模型架构并再次训练。

导出模型以提供支持

使用 tf.contrib.saved_model.save_keras_model 导出 TensorFlow SavedModel 目录。这是创建模型版本资源时 AI Platform 所需的格式。

由于某些优化器不能导出为 SavedModel 格式,因此您可能会在导出过程中看到警告。只要您成功导出执行图表,AI Platform 就可以使用 SavedModel 执行预测。

# Export the model to a local SavedModel directory
export_path = tf.contrib.saved_model.save_keras_model(keras_model, 'keras_export')
print("Model exported to: ", export_path)
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.RMSprop object at 0x7fc198c4e400>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.
WARNING:tensorflow:Model was compiled with an optimizer, but the optimizer is not from `tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving graph was exported. The train and evaluate graphs were not added to the SavedModel.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:205: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: keras_export/1553710367/saved_model.pb
Model exported to:  b'keras_export/1553710367'

只要您具备所需的权限,就可以将 SavedModel 目录导出到本地文件系统或 Cloud Storage。在当前环境中,您通过验证 Google Cloud 账号并设置 GOOGLE_APPLICATION_CREDENTIALS 环境变量获得了对 Cloud Storage 的访问权限。AI Platform 训练作业还可以直接导出到 Cloud Storage,因为 AI Platform 服务账号有权访问自己项目中的 Cloud Storage 存储分区

尝试直接导出到 Cloud Storage:

JOB_DIR = os.getenv('JOB_DIR')

# Export the model to a SavedModel directory in Cloud Storage
export_path = tf.contrib.saved_model.save_keras_model(keras_model, JOB_DIR + '/keras_export')
print("Model exported to: ", export_path)
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.RMSprop object at 0x7fc198c4e400>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
WARNING:tensorflow:Model was compiled with an optimizer, but the optimizer is not from `tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving graph was exported. The train and evaluate graphs were not added to the SavedModel.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: gs://your-bucket-name/keras-job-dir/keras_export/1553710379/saved_model.pb
Model exported to:  b'gs://your-bucket-name/keras-job-dir/keras_export/1553710379'

现在,您可以将此模型部署到 AI Platform,并按照预测快速入门中的步骤执行预测。

清理

如需清理此项目中使用的所有 Google Cloud 资源,您可以删除用于本教程的 Google Cloud项目

您也可以运行以下命令,清理各个资源:

# Delete model version resource
gcloud ai-platform versions delete $MODEL_VERSION --quiet --model $MODEL_NAME

# Delete model resource
gcloud ai-platform models delete $MODEL_NAME --quiet

# Delete Cloud Storage objects that were created
gsutil -m rm -r $JOB_DIR

# If training job is still running, cancel it
gcloud ai-platform jobs cancel $JOB_NAME --quiet

如果您的 Cloud Storage 存储分区中没有任何其他对象,且您想要删除该存储分区,请运行 gsutil rm -r gs://$BUCKET_NAME

后续步骤

  • 查看本指南中使用的完整训练代码,这些代码的结构可接受将自定义超参数作为命令行标志。
  • 了解如何为 AI Platform 训练作业封装代码
  • 了解如何部署模型以执行预测。