VertexAI 中的 TabNet 内置算法使用入门。

概览

TabNet 是一种用于表格(结构化)数据的可解释深度学习架构。它融合了两方面的优势:既可以解释较简单的基于树的模型,同时可实现黑盒模型和集成学习的高准确率。这使得 TabNet 非常适合各种表格数据任务,如金融资产价格预测、欺诈/网络攻击/犯罪检测、零售需求预测、来自医疗保健记录的诊断、产品推荐和其他应用。

TabNet 在其架构中采用专门设计的层:依序关注,旨在从模型的每个步骤中选择要推理的模型特征。这种机制可以解释模型如何实现其预测,并帮助其学习更准确的模型。得益于这种设计,TabNet 的表现不仅优于其他神经网络和决策树,同时还提供了可解释的特征归因。

输入数据

TabNet 要求采用以下格式之一的表格输入:

  • 训练数据:用于训练模型的标签数据。支持的文件格式如下:
  • 输入架构
    • 如果输入为 csv,则第一列将是目标变量。
    • 如果输入为 BigQuery,则您需要指定 target_column 参数。

准备 CSV 文件

输入数据可以是采用 UTF-8 编码的 CSV 文件。如果训练数据仅包含分类值和数值,则可以使用我们的预处理模块填充缺失的数值、拆分数据集以及移除缺失值超过 10% 的行。否则,您可以在未启用自动预处理的情况下运行训练。 在 CSV 文件中,第一列是目标变量。如果 CSV 文件包含标题,则必须指定 has_header 参数。

准备 BigQuery 数据集

输入可以是 BigQuery 数据集。您可以通过多种方法将输入数据加载到 BigQuery。

培训

如需使用 TabNet 执行单节点训练,请使用以下命令。此命令会创建一个使用单个 CPU 机器的 CustomJob 资源。如需了解可在训练期间使用的标志,请参阅本页面上的标志部分。

使用 CSV 输入进行训练

以下是使用 CSV 作为输入格式的示例。成功训练模型后,建议您优化训练期间使用的超参数,以提高模型的准确性和性能。教程笔记本提供了超参数训练作业的示例。

# URI of the TabNet Docker image.
LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2'

# The region to run the job in.
REGION='us-central1'

# Your project.
PROJECT_ID="[your-project-id]"

# Set the training data
DATASET_NAME="petfinder"  # Change to your dataset name.
IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv"
MODEL_TYPE="classification"

# Give a unique name to your training job.
DATE="$(date '+%Y%m%d_%H%M%S')"

# Set a unique name for the job to run.
JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}"
echo $JOB_NAME

# Define your bucket.
YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name

# Copy the csv to your bucket.
TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv"
gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH

# Set a location for the output.
OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/"

echo $OUTPUT_DIR
echo $JOB_NAME

gcloud ai custom-jobs create \
  --region=${REGION} \
  --display-name=${JOB_NAME} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \
  --args=--preprocess \
  --args=--model_type=${MODEL_TYPE} \
  --args=--data_has_header \
  --args=--training_data_path=${TRAINING_DATA_PATH} \
  --args=--job-dir=${OUTPUT_DIR} \
  --args=--max_steps=2000 \
  --args=--batch_size=4096 \
  --args=--learning_rate=0.01

使用 BigQuery 输入进行训练

以下是使用 BigQuery 作为输入的示例。

# URI of the TabNet Docker image.
LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2'

# The region to run the job in.
REGION='us-central1'

# Your project.
PROJECT_ID="[your-project-id]"

# Set the training data
DATASET_NAME="petfinder"  # Change to your dataset name.
IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv"

# Give a unique name to your training job.
DATE="$(date '+%Y%m%d_%H%M%S')"

# Set a unique name for the job to run.
JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}"
echo $JOB_NAME

# Define your bucket.
YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name

# Copy the csv to your bucket.
TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv"
gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH

# Create BigQuery dataset.
bq --location=${REGION} mk --dataset ${PROJECT_ID}:${DATASET_NAME}

# Create BigQuery table using CSV file.
TABLE_NAME="train"
bq --location=${REGION} load --source_format=CSV --autodetect ${PROJECT_ID}:${DATASET_NAME}.${TABLE_NAME} ${YOUR_BUCKET_NAME}/data/petfinder/train.csv

# Set a location for the output.
OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/"
echo $OUTPUT_DIR
echo $JOB_NAME

gcloud ai custom-jobs create \
  --region=${REGION} \
  --display-name=${JOB_NAME} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \
  --args=--preprocess \
  --args=--input_type=bigquery \
  --args=--model_type=classification \
  --args=--stream_inputs \
  --args=--bq_project=${PROJECT_ID} \
  --args=--dataset_name=${DATASET_NAME} \
  --args=--table_name=${TABLE_NAME} \
  --args=--target_column=Adopted \
  --args=--num_parallel_reads=2 \
  --args=--optimizer_type=adam \
  --args=--data_cache=disk \
  --args=--deterministic_data=False \
  --args=--loss_function_type=weighted_cross_entropy \
  --args=--replace_transformed_features=True \
  --args=--apply_quantile_transform=True \
  --args=--apply_log_transform=True \
  --args=--max_steps=2000 \
  --args=--batch_size=4096 \
  --args=--learning_rate=0.01 \
  --args=--job-dir=${OUTPUT_DIR}

了解作业目录

训练作业成功完成后,TabNet 训练会在 Cloud Storage 存储桶中创建经过训练的模型,以及其他一些工件。您可以在 JOB_DIR 中找到以下目录结构:

  • artifacts/
    • metadata.json
  • model/(亦包含 deployment_config.yaml 文件的 TensorFlow SavedModel 目录
    • saved_model.pb
    • deployment_config.yaml
  • processed_data/
    • test.csv
    • training.csv
    • validation.csv

作业目录还包含位于“实验”目录中的各种模型检查点文件。 您可以使用 TensorBoard 来直观呈现指标。最终指标也包含在 deployment_config.yaml 中。

确认您的 JOB_DIR 中的目录结构与上述结构相同:

gsutil ls -a $JOB_DIR/*

教程笔记本

Colab 中有一个示例笔记本,用于启动 TabNet。此笔记本还向您展示了如何:

  • 使用 BigQuery 输入进行训练。

  • 使用 GPU 进行分布式训练。

  • 使用超参数调节。

标志

训练模型时,请使用以下通用训练标志和特定于 TabNet 的训练标志。

通用训练标志

以下自定义训练标志最为常用。如需了解详情,请参阅创建自定义训练作业

  • worker-pool-spec:自定义作业使用的工作器池配置。要创建具有多个工作器池的自定义作业,请指定多个 worker-pool-spec 配置。

    worker-pool-spec 可能包含以下字段,这些字段在 WorkerPoolSpec API 消息中与相应字段一起列出。

    • machine-type:该池的机器类型。如需查看受支持的机器列表,请参阅机器类型
    • replica-count:池中的机器副本数。
    • container-image-uri:要在每个工作器上运行的 Docker 映像。如需使用 TabNet 内置算法,必须将 Docker 映像设置为 us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2:latest
  • display-name:作业的名称。

  • region:您希望在其中运行作业的区域。

特定于 TabNet 的训练标志

下表展示了您可以在 TabNet 训练作业中设置的运行时参数:

参数 数据类型 说明 必填
preprocess 布尔值参数 指定此参数可启用自动预处理。
job_dir 字符串 用于存储模型输出文件的 Cloud Storage 目录。
input_metadata_path 字符串 指向训练数据集特定于 TabNet 的元数据的 Cloud Storage 路径。请参阅上文,了解如何创建该元数据。
training_data_path 字符串 用于存储训练数据的 Cloud Storage 模式。
validation_data_path 字符串 用于存储评估数据的 Cloud Storage 模式。
test_data_path 字符串 用于存储测试数据的 Cloud Storage 模式。
input_type 字符串 “bigquery”或“csv”- 输入表格数据的类型。如果提及到 csv,则第一列被视为目标。如果 CSV 文件包含标题,还应传递“data_has_header”标志。如果使用“bigquery”,则可以提供训练/验证数据路径或提供 BigQuery 项目、数据集和表名称进行预处理,以生成训练和验证数据集。 否 - 默认值为“csv”。
model_type 字符串 分类或回归等学习任务。
split_column 字符串 用于创建训练、验证和测试拆分的列名称。列值(又称为 table['split_column'])应为“TRAIN”、“VALIDATE”或“TEST”。“TEST”是可选的。仅适用于 bigquery 输入。
train_batch_size 整数 用于训练的批次大小。 否 - 默认值为 1024。
eval_split float 如果未提供 validation_data_path,则此参数表示用于评估数据集的拆分比例。 否 - 默认值为 0.2
learning_rate float 训练的学习速率。 否 - 默认为指定的优化器的默认学习速率。
eval_frequency_secs 整数 评估和检查点的执行频率,默认值为 600。
num_parallel_reads 整数 用于读取输入文件的线程数。在大多数情况下,我们建议将其设置为等于或略低于机器的 CPU 数量,以实现最高性能。例如,每个 GPU 6 个是一个不错的默认选择。
data_cache 字符串 选择将数据缓存到“memory”、“disk”或“no_cache”。对于大型数据集,将数据缓存到内存中会抛出内存不足错误,因此我们建议选择“disk”。您可以在配置文件中指定磁盘大小(如以下示例所示)。对于大型(B 规模)数据集,请务必请求足够大(例如 TB 大小)的磁盘来写入数据。 否 - 默认值为“memory”。
bq_project 字符串 BigQuery 项目的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。
dataset_name 字符串 BigQuery 数据集的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。
table_name 字符串 BigQuery 表的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。
loss_function_type 字符串 TabNet 中有多种损失函数类型。对于回归:mse/mae 包含在内。对于分类:cross_entropy/weighted_cross_entropy/focal_loss 包含在内。 否 - 对于回归,默认值为“mse”;对于分类,默认值为“cross_entropy”。
deterministic_data 布尔值参数 从表格数据中读取数据的确定性。默认值设置为 False。设置为 True 时,实验具有确定性。为了在大型数据集上进行快速训练,我们建议 deterministic_data=False 设置,尽管结果存在随机性(在大型数据集中可以忽略不计)。请注意,分布式训练仍不能保证确定性,因为 map-reduce 会因精确率有限的代数运算顺序而导致随机性。但是,这一随机性在实践中可以忽略不计,尤其是在大型数据集上。如果需要 100% 确定性,则除了 deterministic_data=True 设置之外,我们还建议使用单个 GPU(例如,使用 MACHINE_TYPE="n1-highmem-8")进行训练。 否 - 默认值为 False。
stream_inputs 布尔值参数 从 Cloud Storage 流式传输输入数据,而不是在本地下载 - 建议使用此选项以实现快速运行时。
large_category_dim 整数 嵌入的维度 - 如果分类列的不同类别数量大于 large_category_thresh,则使用 large_category_dim 维度嵌入(而不是一维嵌入)。默认值为 1。如果提高准确率是主要目标(而不是计算效率和可解释性),则建议增加该值(例如,在大多数情况下,约增加到 5;如果数据集中的类别数量非常大,则甚至约增加到 10)。 否 - 默认值为 1。
large_category_thresh 整数 分类列基数的阈值 - 如果分类列的不同类别数量大于 large_category_thresh,则使用 large_category_dim 维度嵌入(而不是一维嵌入)。默认值为 300。如果提高准确率是主要目标(而不是计算效率和可解释性),则建议降低该值(例如,约降低到 10)。 否 - 默认值为 300。
yeo_johnson_transform 布尔值参数 启用可训练的 Yeo-Johnson 幂转换(默认为停用状态)。如需详细了解 Yeo-Johnson 幂转换,请查看以下链接:https://www.stat.umn.edu/arc/yjpower.pdf。借助我们的实现,转换参数可与 TabNet 一起学习,并以端到端方式进行训练。
apply_log_transform 布尔值参数 如果元数据中包含日志转换统计信息,并且此标志为 true,则输入特征将进行日志转换。使用 false 表示不使用转换,使用 true(默认值)表示使用转换。尤其是对于数值分布存在偏差的数据集而言,日志转换非常有用。
apply_quantile_transform 布尔值参数 如果元数据中包含分位数统计信息,并且此标志为 true,则输入特征将进行分位数转换。使用 false 表示不使用转换,使用 true(默认值)表示使用转换。尤其是对于数值分布存在偏差的数据集而言,分位数转换非常有用。目前支持 BigQuery input_type。

后续步骤