Melatih model dengan TabNet

Halaman ini menunjukkan cara melatih model regresi atau klasifikasi dari set data tabulasi dengan Tabular Workflow untuk TabNet.

Tersedia dua versi Tabular Workflow untuk TabNet:

  • HyperparameterTuningJob menelusuri kumpulan nilai hyperparameter terbaik yang akan digunakan untuk pelatihan model.
  • CustomJob memungkinkan Anda menentukan nilai hyperparameter yang akan digunakan untuk pelatihan model. Jika Anda tahu persis nilai hyperparameter mana yang dibutuhkan, Anda dapat menentukannya tanpa perlu mencarinya dan menghemat resource pelatihan.

Untuk mempelajari akun layanan yang digunakan oleh alur kerja ini, lihat Akun layanan untuk Tabular Workflows.

API Alur Kerja

Alur kerja ini menggunakan API berikut:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Melatih model dengan HyperparameterTuningJob

Kode contoh berikut ini menunjukkan cara untuk menjalankan pipeline HyperparameterTuningJob:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

Parameter service_account opsional di pipeline_job.run() memungkinkan Anda menetapkan akun layanan Vertex AI Pipelines ke akun pilihan Anda.

Pipeline dan nilai parameter ditentukan oleh fungsi berikut. Data pelatihan dapat berupa file CSV di dalam Cloud Storage atau tabel di dalam BigQuery.

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

Berikut adalah subset parameter get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters:

Nama parameter Jenis Definisi
data_source_csv_filenames String URI untuk CSV yang disimpan di Cloud Storage.
data_source_bigquery_table_path String URI untuk tabel BigQuery.
dataflow_service_account String (Opsional) Akun layanan kustom untuk menjalankan tugas Dataflow. Tugas Dataflow dapat dikonfigurasi untuk menggunakan IP pribadi dan subnet VPC tertentu. Parameter ini berfungsi sebagai pengganti untuk akun layanan worker Dataflow default.
study_spec_parameters_override List[Dict[String, Any]] (Opsional) Penggantian untuk penyesuaian hyperparameter. Parameter ini boleh kosong, atau berisi satu atau beberapa hyperparameter yang memungkinkan. Jika nilai hyperparameter tidak ditetapkan, Vertex AI akan menggunakan rentang penyesuaian default untuk hyperparameter tersebut.

Jika ingin mengonfigurasi hyperparameter menggunakan parameter study_spec_parameters_override, Anda dapat menggunakan fungsi bantuan get_tabnet_study_spec_parameters_override milik Vertex AI. Fungsi tersebut memiliki input berikut:

  • dataset_size_bucket: Bucket untuk ukuran set data
    • 'small': < 1 juta baris
    • 'medium': 1 juta - 100 juta baris
    • 'large': > 100 juta baris
  • training_budget_bucket: Bucket untuk anggaran pelatihan
    • 'small': < $600
    • 'medium': $600 - $2.400
    • 'large': > $2.400
  • prediction_type: Jenis prediksi yang diinginkan

Fungsi get_tabnet_study_spec_parameters_override menampilkan daftar hyperparameter dan rentang.

Berikut ini adalah contoh cara penggunaan fungsi get_tabnet_study_spec_parameters_override:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

Melatih model dengan CustomJob

Kode contoh berikut ini menunjukkan cara menjalankan pipeline CustomJob:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

Parameter service_account opsional di pipeline_job.run() memungkinkan Anda menetapkan akun layanan Vertex AI Pipelines ke akun pilihan Anda.

Pipeline dan nilai parameter ditentukan oleh fungsi berikut. Data pelatihan dapat berupa file CSV di dalam Cloud Storage atau tabel di dalam BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

Berikut adalah subset parameter get_tabnet_trainer_pipeline_and_parameters:

Nama parameter Jenis Definisi
data_source_csv_filenames String URI untuk CSV yang disimpan di Cloud Storage.
data_source_bigquery_table_path String URI untuk tabel BigQuery.
dataflow_service_account String (Opsional) Akun layanan kustom untuk menjalankan tugas Dataflow. Tugas Dataflow dapat dikonfigurasi untuk menggunakan IP pribadi dan subnet VPC tertentu. Parameter ini berfungsi sebagai pengganti untuk akun layanan worker Dataflow default.

Langkah selanjutnya

Setelah siap untuk membuat prediksi dengan model klasifikasi atau regresi, Anda memiliki dua opsi: