概览
TabNet 是一种用于表格(结构化)数据的可解释深度学习架构。它融合了两方面的优势:既可以解释较简单的基于树的模型,同时可实现黑盒模型和集成学习的高准确率。这使得 TabNet 非常适合各种表格数据任务,如金融资产价格预测、欺诈/网络攻击/犯罪检测、零售需求预测、来自医疗保健记录的诊断、产品推荐和其他应用。
TabNet 在其架构中采用专门设计的层:依序关注,旨在从模型的每个步骤中选择要推理的模型特征。这种机制可以解释模型如何实现其预测,并帮助其学习更准确的模型。得益于这种设计,TabNet 的表现不仅优于其他神经网络和决策树,同时还提供了可解释的特征归因。
输入数据
TabNet 要求采用以下格式之一的表格输入:
- 训练数据:用于训练模型的标签数据。支持的文件格式如下:
- 输入架构:
- 如果输入为 csv,则第一列将是目标变量。
- 如果输入为 BigQuery,则您需要指定
target_column
参数。
准备 CSV 文件
输入数据可以是采用 UTF-8 编码的 CSV 文件。如果训练数据仅包含分类值和数值,则可以使用我们的预处理模块填充缺失的数值、拆分数据集以及移除缺失值超过 10% 的行。否则,您可以在未启用自动预处理的情况下运行训练。 在 CSV 文件中,第一列是目标变量。如果 CSV 文件包含标题,则必须指定 has_header
参数。
准备 BigQuery 数据集
输入可以是 BigQuery 数据集。您可以通过多种方法将输入数据加载到 BigQuery。
培训
如需使用 TabNet 执行单节点训练,请使用以下命令。此命令会创建一个使用单个 CPU 机器的 CustomJob
资源。如需了解可在训练期间使用的标志,请参阅本页面上的标志部分。
使用 CSV 输入进行训练
以下是使用 CSV 作为输入格式的示例。成功训练模型后,建议您优化训练期间使用的超参数,以提高模型的准确性和性能。教程笔记本提供了超参数训练作业的示例。
# URI of the TabNet Docker image. LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2' # The region to run the job in. REGION='us-central1' # Your project. PROJECT_ID="[your-project-id]" # Set the training data DATASET_NAME="petfinder" # Change to your dataset name. IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv" MODEL_TYPE="classification" # Give a unique name to your training job. DATE="$(date '+%Y%m%d_%H%M%S')" # Set a unique name for the job to run. JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}" echo $JOB_NAME # Define your bucket. YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name # Copy the csv to your bucket. TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv" gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH # Set a location for the output. OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/" echo $OUTPUT_DIR echo $JOB_NAME gcloud ai custom-jobs create \ --region=${REGION} \ --display-name=${JOB_NAME} \ --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \ --args=--preprocess \ --args=--model_type=${MODEL_TYPE} \ --args=--data_has_header \ --args=--training_data_path=${TRAINING_DATA_PATH} \ --args=--job-dir=${OUTPUT_DIR} \ --args=--max_steps=2000 \ --args=--batch_size=4096 \ --args=--learning_rate=0.01
使用 BigQuery 输入进行训练
以下是使用 BigQuery 作为输入的示例。
# URI of the TabNet Docker image. LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2' # The region to run the job in. REGION='us-central1' # Your project. PROJECT_ID="[your-project-id]" # Set the training data DATASET_NAME="petfinder" # Change to your dataset name. IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv" # Give a unique name to your training job. DATE="$(date '+%Y%m%d_%H%M%S')" # Set a unique name for the job to run. JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}" echo $JOB_NAME # Define your bucket. YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name # Copy the csv to your bucket. TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv" gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH # Create BigQuery dataset. bq --location=${REGION} mk --dataset ${PROJECT_ID}:${DATASET_NAME} # Create BigQuery table using CSV file. TABLE_NAME="train" bq --location=${REGION} load --source_format=CSV --autodetect ${PROJECT_ID}:${DATASET_NAME}.${TABLE_NAME} ${YOUR_BUCKET_NAME}/data/petfinder/train.csv # Set a location for the output. OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/" echo $OUTPUT_DIR echo $JOB_NAME gcloud ai custom-jobs create \ --region=${REGION} \ --display-name=${JOB_NAME} \ --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \ --args=--preprocess \ --args=--input_type=bigquery \ --args=--model_type=classification \ --args=--stream_inputs \ --args=--bq_project=${PROJECT_ID} \ --args=--dataset_name=${DATASET_NAME} \ --args=--table_name=${TABLE_NAME} \ --args=--target_column=Adopted \ --args=--num_parallel_reads=2 \ --args=--optimizer_type=adam \ --args=--data_cache=disk \ --args=--deterministic_data=False \ --args=--loss_function_type=weighted_cross_entropy \ --args=--replace_transformed_features=True \ --args=--apply_quantile_transform=True \ --args=--apply_log_transform=True \ --args=--max_steps=2000 \ --args=--batch_size=4096 \ --args=--learning_rate=0.01 \ --args=--job-dir=${OUTPUT_DIR}
了解作业目录
训练作业成功完成后,TabNet 训练会在 Cloud Storage 存储桶中创建经过训练的模型,以及其他一些工件。您可以在 JOB_DIR
中找到以下目录结构:
- artifacts/
- metadata.json
- model/(亦包含
deployment_config.yaml
文件的 TensorFlow SavedModel 目录)- saved_model.pb
- deployment_config.yaml
- processed_data/
- test.csv
- training.csv
- validation.csv
作业目录还包含位于“实验”目录中的各种模型检查点文件。 您可以使用 TensorBoard 来直观呈现指标。最终指标也包含在 deployment_config.yaml
中。
确认您的 JOB_DIR
中的目录结构与上述结构相同:
gsutil ls -a $JOB_DIR/*
教程笔记本
Colab 中有一个示例笔记本,用于启动 TabNet。此笔记本还向您展示了如何:
使用 BigQuery 输入进行训练。
使用 GPU 进行分布式训练。
使用超参数调节。
标志
训练模型时,请使用以下通用训练标志和特定于 TabNet 的训练标志。
通用训练标志
以下自定义训练标志最为常用。如需了解详情,请参阅创建自定义训练作业。
worker-pool-spec
:自定义作业使用的工作器池配置。要创建具有多个工作器池的自定义作业,请指定多个worker-pool-spec
配置。worker-pool-spec
可能包含以下字段,这些字段在 WorkerPoolSpec API 消息中与相应字段一起列出。machine-type
:该池的机器类型。如需查看受支持的机器列表,请参阅机器类型。replica-count
:池中的机器副本数。container-image-uri
:要在每个工作器上运行的 Docker 映像。如需使用 TabNet 内置算法,必须将 Docker 映像设置为us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2:latest
。
display-name
:作业的名称。region
:您希望在其中运行作业的区域。
特定于 TabNet 的训练标志
下表展示了您可以在 TabNet 训练作业中设置的运行时参数:
参数 | 数据类型 | 说明 | 必填 |
---|---|---|---|
preprocess |
布尔值参数 | 指定此参数可启用自动预处理。 | 否 |
job_dir |
字符串 | 用于存储模型输出文件的 Cloud Storage 目录。 | 有 |
input_metadata_path |
字符串 | 指向训练数据集特定于 TabNet 的元数据的 Cloud Storage 路径。请参阅上文,了解如何创建该元数据。 | 否 |
training_data_path |
字符串 | 用于存储训练数据的 Cloud Storage 模式。 | 有 |
validation_data_path |
字符串 | 用于存储评估数据的 Cloud Storage 模式。 | 否 |
test_data_path |
字符串 | 用于存储测试数据的 Cloud Storage 模式。 | 有 |
input_type |
字符串 | “bigquery”或“csv”- 输入表格数据的类型。如果提及到 csv,则第一列被视为目标。如果 CSV 文件包含标题,还应传递“data_has_header”标志。如果使用“bigquery”,则可以提供训练/验证数据路径或提供 BigQuery 项目、数据集和表名称进行预处理,以生成训练和验证数据集。 | 否 - 默认值为“csv”。 |
model_type |
字符串 | 分类或回归等学习任务。 | 有 |
split_column |
字符串 | 用于创建训练、验证和测试拆分的列名称。列值(又称为 table['split_column'])应为“TRAIN”、“VALIDATE”或“TEST”。“TEST”是可选的。仅适用于 bigquery 输入。 | 否 |
train_batch_size |
整数 | 用于训练的批次大小。 | 否 - 默认值为 1024。 |
eval_split |
float | 如果未提供 validation_data_path ,则此参数表示用于评估数据集的拆分比例。 |
否 - 默认值为 0.2 |
learning_rate |
float | 训练的学习速率。 | 否 - 默认为指定的优化器的默认学习速率。 |
eval_frequency_secs |
整数 | 评估和检查点的执行频率,默认值为 600。 | 否 |
num_parallel_reads |
整数 | 用于读取输入文件的线程数。在大多数情况下,我们建议将其设置为等于或略低于机器的 CPU 数量,以实现最高性能。例如,每个 GPU 6 个是一个不错的默认选择。 | 有 |
data_cache |
字符串 | 选择将数据缓存到“memory”、“disk”或“no_cache”。对于大型数据集,将数据缓存到内存中会抛出内存不足错误,因此我们建议选择“disk”。您可以在配置文件中指定磁盘大小(如以下示例所示)。对于大型(B 规模)数据集,请务必请求足够大(例如 TB 大小)的磁盘来写入数据。 | 否 - 默认值为“memory”。 |
bq_project |
字符串 | BigQuery 项目的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。 | 否 |
dataset_name |
字符串 | BigQuery 数据集的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。 | 否 |
table_name |
字符串 | BigQuery 表的名称。如果 input_type=bigquery 并使用标志“预处理”,则此参数是必需的。这是指定训练、验证和测试数据路径的替代方案。 | 否 |
loss_function_type |
字符串 | TabNet 中有多种损失函数类型。对于回归:mse/mae 包含在内。对于分类:cross_entropy/weighted_cross_entropy/focal_loss 包含在内。 | 否 - 对于回归,默认值为“mse”;对于分类,默认值为“cross_entropy”。 |
deterministic_data |
布尔值参数 | 从表格数据中读取数据的确定性。默认值设置为 False。设置为 True 时,实验具有确定性。为了在大型数据集上进行快速训练,我们建议 deterministic_data=False 设置,尽管结果存在随机性(在大型数据集中可以忽略不计)。请注意,分布式训练仍不能保证确定性,因为 map-reduce 会因精确率有限的代数运算顺序而导致随机性。但是,这一随机性在实践中可以忽略不计,尤其是在大型数据集上。如果需要 100% 确定性,则除了 deterministic_data=True 设置之外,我们还建议使用单个 GPU(例如,使用 MACHINE_TYPE="n1-highmem-8")进行训练。 | 否 - 默认值为 False。 |
stream_inputs |
布尔值参数 | 从 Cloud Storage 流式传输输入数据,而不是在本地下载 - 建议使用此选项以实现快速运行时。 | 否 |
large_category_dim |
整数 | 嵌入的维度 - 如果分类列的不同类别数量大于 large_category_thresh,则使用 large_category_dim 维度嵌入(而不是一维嵌入)。默认值为 1。如果提高准确率是主要目标(而不是计算效率和可解释性),则建议增加该值(例如,在大多数情况下,约增加到 5;如果数据集中的类别数量非常大,则甚至约增加到 10)。 | 否 - 默认值为 1。 |
large_category_thresh |
整数 | 分类列基数的阈值 - 如果分类列的不同类别数量大于 large_category_thresh,则使用 large_category_dim 维度嵌入(而不是一维嵌入)。默认值为 300。如果提高准确率是主要目标(而不是计算效率和可解释性),则建议降低该值(例如,约降低到 10)。 | 否 - 默认值为 300。 |
yeo_johnson_transform |
布尔值参数 | 启用可训练的 Yeo-Johnson 幂转换(默认为停用状态)。如需详细了解 Yeo-Johnson 幂转换,请查看以下链接:https://www.stat.umn.edu/arc/yjpower.pdf。借助我们的实现,转换参数可与 TabNet 一起学习,并以端到端方式进行训练。 | 否 |
apply_log_transform |
布尔值参数 | 如果元数据中包含日志转换统计信息,并且此标志为 true,则输入特征将进行日志转换。使用 false 表示不使用转换,使用 true(默认值)表示使用转换。尤其是对于数值分布存在偏差的数据集而言,日志转换非常有用。 | 否 |
apply_quantile_transform |
布尔值参数 | 如果元数据中包含分位数统计信息,并且此标志为 true,则输入特征将进行分位数转换。使用 false 表示不使用转换,使用 true(默认值)表示使用转换。尤其是对于数值分布存在偏差的数据集而言,分位数转换非常有用。目前支持 BigQuery input_type。 | 否 |