使用内置 TabNet 算法进行训练

使用 AI Platform Training 中的内置算法进行训练时,您无需编写任何训练代码即可提交数据集并训练模型。本页面介绍了内置 TabNet 算法的工作原理及使用方法。

概览

此内置算法会执行预处理和训练:

  1. 预处理:AI Platform Training 将混合的分类数据和数值数据处理为纯数值数据集,以便为训练做好准备。
  2. 训练:AI Platform Training 使用您提供的数据集和模型参数,通过 Tensorflow 的自定义 Estimator 运行训练。

限制

使用内置 TabNet 算法进行训练时不支持以下功能

支持的机器类型

以下是系统支持的 AI Platform Training 容量层级和机器类型

设置输入数据的格式

数据集的每一行表示一个实例,数据集的每一列表示一个特征值。目标列表示您要预测的值

准备 CSV 文件

您的输入数据必须是采用 UTF-8 编码的 CSV 文件。如果您的训练数据仅包含分类值和数值,则可以使用我们的预处理模块填充缺失的数值、拆分数据集以及移除缺失值超过 10% 的行。否则,您可以在未启用自动预处理的情况下运行训练。

您在准备输入 CSV 文件时须满足以下要求:

  • 移除标题行。标题行包含每列的标签。移除标题行,避免将其作为训练数据的一部分与其余数据实例一起提交。
  • 确保目标列为第一列。目标列包含您尝试预测的值。对于分类算法,目标列中的所有值均为类或类别。对于回归算法,目标列中的所有值均为数值。

处理整数值

整数值的含义可能不明确,这会导致整数值的列在自动预处理过程中成为问题。AI Platform Training 会自动确定如何处理整数值。默认情况下:

  • 如果每个整数值都是唯一的,则该列会被视为实例键。
  • 如果只有几个唯一的整数值,则该列会被视为分类列。
  • 否则,该列中的值将转换为浮点数并被视为数值。

要覆盖这些默认判定,请执行以下操作:

  • 如果数据应被视为数值,请将列中的所有整数值转换为浮点数,例如,{101.0, 102.0, 103.0}
  • 如果数据应被视为分类,请在列中所有整数值的前面加上非数值前缀,例如,{code_101, code_102, code_103}

提交 TabNet 训练作业

本部分介绍如何使用内置 TabNet 算法提交训练作业。

您可以在 Google Cloud 控制台中找到每个超参数的简要说明,可在内置 TabNet 算法参考中找到更全面的说明。

控制台

  1. 转到 Google Cloud 控制台中的 AI Platform Training“作业”页面:

    AI Platform Training“作业”页面

  2. 点击新建训练作业按钮。从下方显示的选项中,点击内置算法训练

  3. 创建新的训练作业页面上,选择 TabNet,然后点击下一步

  4. 如需详细了解所有可用的参数,请点击 Google Cloud 控制台中的链接,查看内置 TabNet 参考文档了解详情。

gcloud

  1. 为您的作业设置环境变量,用您的项目的相应值替代 [VALUES-IN-BRACKETS] 内的值:

       # Specify the name of the Cloud Storage bucket where you want your
       # training outputs to be stored, and the Docker container for
       # your built-in algorithm selection.
       BUCKET_NAME='[YOUR-BUCKET-NAME]'
       IMAGE_URI='gcr.io/cloud-ml-algos/tab_net:latest'
    
       # Specify the Cloud Storage path to your training input data.
       TRAINING_DATA='gs://[YOUR_BUCKET_NAME]/[YOUR_FILE_NAME].csv'
    
       DATASET_NAME='census'
       ALGORITHM='tabnet'
       MODEL_TYPE='classification'
    
       DATE='date '+%Y%m%d_%H%M%S''
       MODEL_NAME="${DATASET_NAME}_${ALGORITHM}_${MODEL_TYPE}"
       JOB_ID="${MODEL_NAME}_${DATE}"
    
       JOB_DIR="gs://${BUCKET_NAME}/algorithm_training/${MODEL_NAME}/${DATE}"
    
  2. 使用 gcloud ai-platform jobs training submit 提交训练作业:

    gcloud ai-platform jobs submit training $JOB_ID \
      --master-image-uri=$IMAGE_URI --scale-tier=BASIC \
      --job-dir=$JOB_DIR \
      -- \
      --max_steps=2000 \
      --preprocess \
      --model_type=$MODEL_TYPE \
      --batch_size=4096 \
      --learning_rate=0.01 \
      --training_data_path=$TRAINING_DATA_PATH
    

  3. 使用 gcloud 查看日志,从而监控训练作业的状态。请参阅 gcloud ai-platform jobs describegcloud ai-platform jobs stream-logs

       gcloud ai-platform jobs describe ${JOB_ID}
       gcloud ai-platform jobs stream-logs ${JOB_ID}
    

预处理的工作原理

自动预处理适用于分类数据和数值数据。预处理例程先分析您的数据,然后再对其进行转换

分析

首先,AI Platform Training 会自动检测每个列的数据类型、确定应如何处理每一列,并计算该列中数据的部分统计信息。此类信息将捕获到 metadata.json 文件中。

AI Platform Training 会分析目标列的类型,以此确定特定数据集是用于回归还是分类。如果此分析与您对 model_type 的选择冲突,则会导致错误。通过在不确定的情况下清楚地设置数据格式,明确应如何处理目标列。

  • 类型:列可以是数值列,也可以是分类列

  • 处理方式:AI Platform Training 会根据如下原则确定如何处理每个列:

    • 如果列在所有行中包含单个值,则将其视为常量
    • 如果列是分类列,并且所有行中都包含唯一值,则将其视为 row_identifier
    • 如果列是带有浮点值的数值列,或者该列是带整数值的数值列且包含许多唯一值,则将该列视为数值
    • 如果列是带有整数值的数值列,并且包含的唯一值非常少,则将该列视为分类列,其中整数值是身份词汇
      • 如果列中唯一值的数量小于输入数据集中行数的 20%,则认为该列具有很少的唯一值
    • 如果列是具有高基数的分类列,则对该列进行哈希处理,其中哈希分区数等于列中唯一值数量的平方根。
      • 如果唯一值的数量大于数据集中行数的平方根,则认为该分类列具有高基数
    • 如果列是分类列,并且唯一值的数量小于或等于数据集行数的平方根,则将该列视为包含词汇的普通分类列。
  • 统计信息:AI Platform Training 会根据已确定的列类型和处理方式计算以下统计信息,以便在后续阶段中用于转换列。

    • 如果列是数值列,则计算均值和方差值。
    • 如果列是分类列,并且处理方式是身份或词汇,则从列中提取不同的值。
    • 如果列是分类列,并且处理方式是哈希技术,则按照列的基数计算哈希分区的数量。

转换

完成数据集的初始分析后,AI Platform Training 会根据应用于数据集的类型、处理方式和统计信息来转换您的数据。AI Platform Training 会按以下顺序执行转换:

  1. 将训练数据集拆分为验证和测试数据集(如果您指定拆分百分比)。
  2. 移除特征缺少 10% 以上的行。
  3. 使用列的平均值填写缺失的数值。

转换示例

超过 10% 的值缺失的行会被删除。在以下示例中,假设该行具有 10 个值。为简单起见,每个示例行都经过精简。

行问题 原始值 转换后的值 说明
没有缺失值的示例行 [3, 0.45, ...,
'fruits', 0, 1]
[3, 0.45, ...,
1, 0, 0, 0, 1]
字符串 'fruits' 在独热编码中转换为值“1, 0, 0”。此转换稍后在 TensorFlow 图中发生。
缺失值过多 [3, 0.45, ...,
'fruits', __, __]
行已移除 行中超过 10% 的值缺失。
缺少数值 [3, 0.45, ...,
'fruits', 0, __]
[3, 0.45, ...,
1, 0, 0, 0, 0.54]
  • 该列的平均值替换了缺失的数值。在本示例中,平均值是 0.54。
  • 字符串 'fruits' 在独热编码中转换为值“1, 0, 0”。此转换稍后在 TensorFlow 图中发生。
缺少分类值 [3, 0.45, ...,
__, 0, 1]
[3, 0.45, ...,
0, 0, 0, 0, 1]
  • 缺失的分类值在独热编码中转换为值“0, 0, 0”。此转换稍后在 TensorFlow 图中发生。

特征列

转换期间不会处理这些列。分析期间生成的元数据会传递到 AI Platform Training,以相应地创建特征列:

列类型 列的处理方式 生成的特征列
数值 (所有列处理类型) tf.feature_column.numeric_column

均值和方差值用于执行值的标准化操作:
new_value = (input_value - mean) / sqrt(variance)

分类 身份 tf.feature_column.categorical_column_with_identity
分类 词汇 tf.feature_column.categorical_column_with_vocabulary_list
分类 哈希技术 tf.feature_column.categorical_column_with_hash_bucket
分类 常量或行标识符 忽略。未创建任何特征列。

自动预处理完成后,AI Platform Training 会将处理后的数据集上传回 Cloud Storage 存储分区,并存放在您在作业请求内指定的目录中。

深入学习资源