TabNet을 사용한 모델 학습

이 페이지에서는 TabNet용 테이블 형식 워크플로를 사용하여 테이블 형식의 데이터 세트에서 분류 또는 회귀 모델을 학습시키는 방법을 보여줍니다.

TabNet용 테이블 형식 워크플로에는 두 가지 버전이 있습니다.

  • HyperparameterTuningJob은 모델 학습에 사용할 최적의 초매개변수 값 집합을 검색합니다.
  • CustomJob을 사용하면 모델 학습에 사용할 초매개변수 값을 지정할 수 있습니다. 필요한 초매개변수 값을 정확히 알고 있으면 해당 값을 검색하지 않고 지정하고 학습 리소스에 저장할 수 있습니다.

이 워크플로에 사용되는 서비스 계정에 대한 자세한 내용은 테이블 형식 워크플로의 서비스 계정을 참조하세요.

워크플로 API

이 워크플로에서는 다음 API를 사용합니다.

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

HyperparameterTuningJob을 사용한 모델 학습

다음 샘플 코드에서는 HyperparameterTuningJob 파이프라인을 실행하는 방법을 보여줍니다.

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

pipeline_job.run()의 선택사항인 service_account 매개변수를 사용하면 Vertex AI Pipelines 서비스 계정을 원하는 계정으로 설정할 수 있습니다.

파이프라인과 매개변수 값은 다음 함수로 정의됩니다. 학습 데이터는 Cloud Storage의 CSV 파일이거나 BigQuery의 테이블일 수 있습니다.

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

다음은 get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters 매개변수의 하위 집합입니다.

매개변수 이름 유형 정의
data_source_csv_filenames 문자열 Cloud Storage에 저장된 CSV의 URI입니다.
data_source_bigquery_table_path 문자열 BigQuery 테이블의 URI입니다.
dataflow_service_account 문자열 (선택사항) Dataflow 작업을 실행하기 위한 커스텀 서비스 계정입니다. 비공개 IP와 특정 VPC 서브넷을 사용하도록 Dataflow 작업은 비공개 IP와 특정 VPC 서브넷을 사용하도록 구성할 수 있습니다. 이 매개변수는 기본 Dataflow 작업자 서비스 계정을 재정의하는 역할을 합니다.
study_spec_parameters_override List[Dict[문자열, 무관]] (선택사항) 초매개변수 조정을 재정의합니다. 이 매개변수는 비어 있거나 가능한 초매개변수를 하나 이상 포함할 수 있습니다. 초매개변수 값이 설정되지 않은 경우 Vertex AI는 초매개변수에 기본 조정 범위를 사용합니다.

study_spec_parameters_override 매개변수를 사용하여 초매개변수를 구성하려면 Vertex AI의 도우미 함수 get_tabnet_study_spec_parameters_override를 사용할 수 있습니다. 이 함수의 입력은 다음과 같습니다.

  • dataset_size_bucket: 데이터 세트 크기의 버킷
    • 'small': 행 100만개 미만
    • 'medium': 행 100만~1억 개
    • 'large': 행 1억개 초과
  • training_budget_bucket: 학습 예산의 버킷
    • 'small': $600 미만
    • 'medium': $600~$2,400
    • 'large': $2,400 초과
  • prediction_type: 원하는 예측 유형

get_tabnet_study_spec_parameters_override 함수는 초매개변수와 범위의 목록을 반환합니다.

다음은 get_tabnet_study_spec_parameters_override 함수를 사용하는 방법의 예시입니다.

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

CustomJob을 사용한 모델 학습

다음 샘플 코드에서는 CustomJob 파이프라인을 실행하는 방법을 보여줍니다.

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

pipeline_job.run()의 선택사항인 service_account 매개변수를 사용하면 Vertex AI Pipelines 서비스 계정을 원하는 계정으로 설정할 수 있습니다.

파이프라인과 매개변수 값은 다음 함수로 정의됩니다. 학습 데이터는 Cloud Storage의 CSV 파일이거나 BigQuery의 테이블일 수 있습니다.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

다음은 get_tabnet_trainer_pipeline_and_parameters 매개변수의 하위 집합입니다.

매개변수 이름 유형 정의
data_source_csv_filenames 문자열 Cloud Storage에 저장된 CSV의 URI입니다.
data_source_bigquery_table_path 문자열 BigQuery 테이블의 URI입니다.
dataflow_service_account 문자열 (선택사항) Dataflow 작업을 실행하기 위한 커스텀 서비스 계정입니다. 비공개 IP와 특정 VPC 서브넷을 사용하도록 Dataflow 작업은 비공개 IP와 특정 VPC 서브넷을 사용하도록 구성할 수 있습니다. 이 매개변수는 기본 Dataflow 작업자 서비스 계정을 재정의하는 역할을 합니다.

다음 단계

분류 또는 회귀 모델을 사용하여 예측할 준비가 되면 다음 두 가지 옵션이 있습니다.