训练预测模型

本页面介绍如何使用 Google Cloud 控制台或 Vertex AI API 通过表式数据集训练预测模型。

准备工作

在训练预测模型之前,您必须先完成以下操作:

训练模型

Google Cloud 控制台

  1. 在 Google Cloud 控制台的 Vertex AI 部分中,前往数据集页面。

    转到“数据集”页面

  2. 点击要用于训练模型的数据集的名称,以打开其详情页面。

  3. 如果您的数据类型使用注释集,请选择要用于此模型的注释集。

  4. 点击训练新模型

  5. 选择其他

  6. 训练方法页面中,按如下方式配置:

    1. 选择模型训练方法。 如需了解详情,请参阅模型训练方法

    2. 点击继续

  7. 模型详情页面中,按如下方式配置:

    1. 输入新模型的显示名。

    2. 选择目标列。

      目标列是模型将预测的值。 详细了解目标列要求

    3. 如果您尚未对数据集设置序列标识符时间戳列,请立即选择它们。

    4. 选择您的数据粒度。如果您想使用节假日效应建模,请选择 Daily了解如何选择数据粒度

    5. 可选:在节假日区域下拉菜单中,选择一个或多个地理区域来启用节假日效应建模。 在训练期间,Vertex AI 会根据时间戳列中的日期和指定的地理区域在模型中创建节假日分类特征。只有当数据粒度设置为 Daily 时,才能选择此选项。默认情况下,节假日效应建模处于停用状态。 如需了解用于节假日效应建模的地理区域,请参阅节假日区域

    6. 输入您的上下文窗口预测范围

      预测范围确定模型预测每行预测数据的目标值的未来时间。预测范围数据粒度为单位指定。

      上下文窗口设置模型在训练期间的回溯时间(用于预测)。换句话说,对于每个数据点,上下文窗口会确定模型查找回溯模式的时间。 上下文窗口数据粒度为单位指定。

      了解详情

    7. 如果您要将测试数据集导出到 BigQuery,请勾选将测试数据集导出到 BigQuery (Export test dataset to BigQuery) 并提供表的名称。

    8. 如果您想要手动控制数据拆分或配置预测窗口,请打开高级选项

    9. 默认数据拆分按时间顺序,标准为 80/10/10 百分比。如果您想手动指定将哪些行分配给哪个拆分,请选择手动并指定数据拆分列。

      详细了解数据拆分

    10. 选择用于生成预测窗口的滚动窗口策略。默认策略是计数

      • 计数:在提供的文本框中设置最大窗口数的值。
      • 步长:在提供的文本框中设置步幅的值。
      • :从提供的下拉菜单中选择适当的列名。

      如需了解详情,请参阅滚动窗口策略

    11. 点击继续

  8. 训练选项页面中,按如下方式配置:

    1. 点击生成统计信息(如果您尚未生成)。

      生成统计信息会填充转换下拉菜单。

    2. 检查您的列列表,并从训练中排除任何不应用于训练模型的列。

      如果您要使用数据拆分列,则应包含该列。

    3. 查看为包含的特征选择的转换,并进行任何所需更新。

      系统会从训练中排除包含所选转换无效数据的行。详细了解转换

    4. 对于包含用于训练的每一列,请指定特征类型以获取该特征与其时序的关系,以及它在预测时是否可用。详细了解特征类型和可用性

    5. 如果您想指定权重列、更改默认的优化目标或启用分层预测,请打开高级选项

    6. 可选。如果您想指定权重列,请从下拉列表中选择。详细了解权重列

    7. 可选。如果您想选择优化目标,请从列表中选择。详细了解优化目标

    8. 可选。如果您想使用分层预测,请选择启用分层预测。您可以从三个分组选项中进行选择:

      • No grouping
      • Group by columns
      • Group all

      您还可以选择设置以下汇总损失权重:

      • Group total weight。只有在选择 Group by columnsGroup all 选项时才能设置此字段。
      • Temporal total weight
      • Group temporal total weight。只有在选择 Group by columnsGroup all 选项时才能设置此字段。

      详细了解分层预测

    9. 点击继续

  9. 计算和价格页面中,配置如下:

    1. 输入模型训练的最大小时数。此设置有助于限制训练费用。实际所用的时间可能超过此值,因为创建新模型涉及其他操作。

      建议的训练时间与预测范围和训练数据的大小有关。下表提供一些预测训练运行示例,以及训练高质量模型所需的训练时间范围。

      特征 预测范围 训练时间
      1200 万 10 6 3-6 小时
      2000 万 50 13 6-12 小时
      1600 万 30 365 24-48 小时

      如需了解训练价格,请参阅价格页面

    2. 点击开始训练

      模型训练可能需要几个小时,具体取决于数据的大小和复杂性,以及训练预算(如果指定)。您可以关闭此标签页,稍后再返回。模型完成训练后,您会收到电子邮件。

API

选择语言或环境标签页:

REST

您可以使用 trainingPipelines.create 命令训练模型。

在使用任何请求数据之前,请先进行以下替换:

  • LOCATION:您的区域。
  • PROJECT:您的项目 ID
  • TRAINING_PIPELINE_DISPLAY_NAME:为此操作创建的训练流水线的显示名称。
  • TRAINING_TASK_DEFINITION:模型训练方法。
    • 时序密集编码器 (TiDE)
      gs://google-cloud-aiplatform/schema/trainingjob/definition/time_series_dense_encoder_forecasting_1.0.0.yaml
    • Temporal Fusion Transformer (TFT)
      gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml
    • AutoML (L2L)
      gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml
    • Seq2Seq+
      gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml
    如需了解详情,请参阅模型训练方法
  • TARGET_COLUMN:您希望此模型预测的列(值)。
  • TIME_COLUMN:时间列。了解详情
  • TIME_SERIES_IDENTIFIER_COLUMN:时序标识符列。了解详情
  • WEIGHT_COLUMN:(可选)权重列。了解详情
  • TRAINING_BUDGET:您希望模型训练的最长时间,以毫节点时为单位(1,000 毫节点时等于一节点时)。
  • GRANULARITY_UNIT:用于训练数据粒度以及预测水平和上下文窗口的单元。可以是 minutehourdayweekmonthyear。 如果您想使用节假日效应建模,请选择 day了解如何选择数据粒度
  • GRANULARITY_QUANTITY:训练数据两次观察所间隔的粒度单位数。对于所有单位(分钟除外),都必须为 1、5、10、15 或 30。了解如何选择数据粒度
  • GROUP_COLUMNS:训练输入表中用于标识层次结构级别分组的列名。这些列必须是“time_series_attributes_columns”。了解详情
  • GROUP_TOTAL_WEIGHT:组汇总损失相对于单个损失的权重。如果设置为“0.0”或未设置,则停用。如果未设置组列,则所有时序将被视为同一组,并按所有时序汇总。了解详情
  • TEMPORAL_TOTAL_WEIGHT:时间汇总损失相对于单个损失的权重。如果设置为“0.0”或未设置,则停用。了解详情
  • GROUP_TEMPORAL_TOTAL_WEIGHT:总(组 x 时间)汇总损失相对于单个损失的权重。如果设置为“0.0”或未设置,则停用。如果未设置组列,则所有时序将被视为同一组,并按所有时序汇总。了解详情
  • HOLIDAY_REGIONS:(可选)您可以选择一个或多个地理区域来启用节假日效应建模。在训练期间,Vertex AI 会根据 TIME_COLUMN 中的日期和指定的地理区域在模型中创建节假日分类特征。如需启用此功能,请将 GRANULARITY_UNIT 设置为 day,并在 HOLIDAY_REGIONS 字段中指定一个或多个区域。默认情况下,节假日效应建模处于停用状态。 如需了解详情,请参阅节假日区域
  • FORECAST_HORIZON:预测范围确定模型预测每行预测数据的目标值的未来时间。预测范围以数据粒度为单位 (GRANULARITY_UNIT) 指定。了解详情
  • CONTEXT_WINDOW:上下文窗口设置了模型在训练期间的回溯时间(用于预测)。换句话说,对于每个数据点,上下文窗口会确定模型查找回溯模式的时间。 上下文窗口以数据粒度为单位 (GRANULARITY_UNIT) 指定。了解详情
  • OPTIMIZATION_OBJECTIVE:默认情况下,Vertex AI 会最大限度降低均方根误差 (RMSE)。如果您要为预测模型使用其他优化目标,请选择预测模型的优化目标中的一个选项。如果您选择最大限度地减少分位数损失,则还必须为 QUANTILES 指定值。
  • PROBABILISTIC_INFERENCE:(可选)如果设置为 true,则 Vertex AI 会对预测的概率分布建模。概率推理可以通过处理噪声数据并量化不确定性来提高模型质量。如果指定了 QUANTILES,则 Vertex AI 还会返回概率分布的分位数。概率推理仅与 Time series Dense Encoder (TiDE) and the AutoML (L2L) training methods. It is incompatible with hierarchical forecasting and the minimize-quantile-loss optimization objective. 兼容
  • QUANTILES: Quantiles to use for the minimize-quantile-loss optimization objective and probabilistic inference. Provide a list of up to five unique numbers between 0 and 1, exclusive.
  • TIME_SERIES_ATTRIBUTE_COL: The name or names of the columns that are time series attributes. Learn more.
  • AVAILABLE_AT_FORECAST_COL: The name or names of the covariate columns whose value is known at forecast time. Learn more.
  • UNAVAILABLE_AT_FORECAST_COL: The name or names of the covariate columns whose value is unknown at forecast time. Learn more.
  • TRANSFORMATION_TYPE: The transformation type is provided for each column used to train the model. Learn more.
  • COLUMN_NAME: The name of the column with the specified transformation type. Every column used to train the model must be specified.
  • MODEL_DISPLAY_NAME: Display name for the newly trained model.
  • DATASET_ID: ID for the training Dataset.
  • You can provide a Split object to control your data split. For information about controlling data split, see Control the data split using REST.
  • You can provide a windowConfig object to configure a rolling window strategy for forecast window generation. For further information, see Configure the rolling window strategy using REST.
  • PROJECT_NUMBER: Your project's automatically generated project number

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Request JSON body:

{
    "displayName": "TRAINING_PIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "TRAINING_TASK_DEFINITION",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "timeColumn": "TIME_COLUMN",
        "timeSeriesIdentifierColumn": "TIME_SERIES_IDENTIFIER_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "dataGranularity": {"unit": "GRANULARITY_UNIT", "quantity": GRANULARITY_QUANTITY},
        "hierarchyConfig": {"groupColumns": GROUP_COLUMNS, "groupTotalWeight": GROUP_TOTAL_WEIGHT, "temporalTotalWeight": TEMPORAL_TOTAL_WEIGHT, "groupTemporalTotalWeight": GROUP_TEMPORAL_TOTAL_WEIGHT}
        "holidayRegions" : ["HOLIDAY_REGIONS_1", "HOLIDAY_REGIONS_2", ...]
        "forecast_horizon": FORECAST_HORIZON,
        "context_window": CONTEXT_WINDOW,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "quantiles": "QUANTILES",
        "enableProbabilisticInference": "PROBABILISTIC_INFERENCE",
        "time_series_attribute_columns": ["TIME_SERIES_ATTRIBUTE_COL_1", "TIME_SERIES_ATTRIBUTE_COL_2", ...]
        "available_at_forecast_columns": ["AVAILABLE_AT_FORECAST_COL_1", "AVAILABLE_AT_FORECAST_COL_2", ...]
        "unavailable_at_forecast_columns": ["UNAVAILABLE_AT_FORECAST_COL_1", "UNAVAILABLE_AT_FORECAST_COL_2", ...]
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

To send your request, expand one of these options:

You should receive a JSON response similar to the following:

{
  "name": "projects/PROJECT_NUMBER/locations/LOCATION/trainingPipelines/TRAINING_PIPELINE_ID",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

def create_training_pipeline_forecasting_time_series_dense_encoder_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    time_column: str = "date",
    time_series_identifier_column: str = "time_series_id",
    unavailable_at_forecast_columns: List[str] = [],
    available_at_forecast_columns: List[str] = [],
    forecast_horizon: int = 1,
    data_granularity_unit: str = "week",
    data_granularity_count: int = 1,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    timestamp_split_column_name: str = "timestamp_split",
    weight_column: str = "weight",
    time_series_attribute_columns: List[str] = [],
    context_window: int = 0,
    export_evaluated_data_items: bool = False,
    export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
    export_evaluated_data_items_override_destination: bool = False,
    quantiles: Optional[List[float]] = None,
    enable_probabilistic_inference: bool = False,
    validation_options: Optional[str] = None,
    predefined_split_column_name: Optional[str] = None,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    # Create training job
    forecasting_tide_job = aiplatform.TimeSeriesDenseEncoderForecastingTrainingJob(
        display_name=display_name,
        optimization_objective="minimize-rmse",
    )

    # Retrieve existing dataset
    dataset = aiplatform.TimeSeriesDataset(dataset_id)

    # Run training job
    model = forecasting_tide_job.run(
        dataset=dataset,
        target_column=target_column,
        time_column=time_column,
        time_series_identifier_column=time_series_identifier_column,
        unavailable_at_forecast_columns=unavailable_at_forecast_columns,
        available_at_forecast_columns=available_at_forecast_columns,
        forecast_horizon=forecast_horizon,
        data_granularity_unit=data_granularity_unit,
        data_granularity_count=data_granularity_count,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        predefined_split_column_name=predefined_split_column_name,
        timestamp_split_column_name=timestamp_split_column_name,
        weight_column=weight_column,
        time_series_attribute_columns=time_series_attribute_columns,
        context_window=context_window,
        export_evaluated_data_items=export_evaluated_data_items,
        export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
        export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
        quantiles=quantiles,
        enable_probabilistic_inference=enable_probabilistic_inference,
        validation_options=validation_options,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

def create_training_pipeline_forecasting_temporal_fusion_transformer_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    time_column: str = "date",
    time_series_identifier_column: str = "time_series_id",
    unavailable_at_forecast_columns: List[str] = [],
    available_at_forecast_columns: List[str] = [],
    forecast_horizon: int = 1,
    data_granularity_unit: str = "week",
    data_granularity_count: int = 1,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    timestamp_split_column_name: str = "timestamp_split",
    weight_column: str = "weight",
    time_series_attribute_columns: List[str] = [],
    context_window: int = 0,
    export_evaluated_data_items: bool = False,
    export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
    export_evaluated_data_items_override_destination: bool = False,
    validation_options: Optional[str] = None,
    predefined_split_column_name: Optional[str] = None,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    # Create training job
    forecasting_tft_job = aiplatform.TemporalFusionTransformerForecastingTrainingJob(
        display_name=display_name,
        optimization_objective="minimize-rmse",
    )

    # Retrieve existing dataset
    dataset = aiplatform.TimeSeriesDataset(dataset_id)

    # Run training job
    model = forecasting_tft_job.run(
        dataset=dataset,
        target_column=target_column,
        time_column=time_column,
        time_series_identifier_column=time_series_identifier_column,
        unavailable_at_forecast_columns=unavailable_at_forecast_columns,
        available_at_forecast_columns=available_at_forecast_columns,
        forecast_horizon=forecast_horizon,
        data_granularity_unit=data_granularity_unit,
        data_granularity_count=data_granularity_count,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        predefined_split_column_name=predefined_split_column_name,
        timestamp_split_column_name=timestamp_split_column_name,
        weight_column=weight_column,
        time_series_attribute_columns=time_series_attribute_columns,
        context_window=context_window,
        export_evaluated_data_items=export_evaluated_data_items,
        export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
        export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
        validation_options=validation_options,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

def create_training_pipeline_forecasting_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    time_column: str = "date",
    time_series_identifier_column: str = "time_series_id",
    unavailable_at_forecast_columns: List[str] = [],
    available_at_forecast_columns: List[str] = [],
    forecast_horizon: int = 1,
    data_granularity_unit: str = "week",
    data_granularity_count: int = 1,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    timestamp_split_column_name: str = "timestamp_split",
    weight_column: str = "weight",
    time_series_attribute_columns: List[str] = [],
    context_window: int = 0,
    export_evaluated_data_items: bool = False,
    export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
    export_evaluated_data_items_override_destination: bool = False,
    quantiles: Optional[List[float]] = None,
    enable_probabilistic_inference: bool = False,
    validation_options: Optional[str] = None,
    predefined_split_column_name: Optional[str] = None,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    # Create training job
    forecasting_job = aiplatform.AutoMLForecastingTrainingJob(
        display_name=display_name, optimization_objective="minimize-rmse"
    )

    # Retrieve existing dataset
    dataset = aiplatform.TimeSeriesDataset(dataset_id)

    # Run training job
    model = forecasting_job.run(
        dataset=dataset,
        target_column=target_column,
        time_column=time_column,
        time_series_identifier_column=time_series_identifier_column,
        unavailable_at_forecast_columns=unavailable_at_forecast_columns,
        available_at_forecast_columns=available_at_forecast_columns,
        forecast_horizon=forecast_horizon,
        data_granularity_unit=data_granularity_unit,
        data_granularity_count=data_granularity_count,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        predefined_split_column_name=predefined_split_column_name,
        timestamp_split_column_name=timestamp_split_column_name,
        weight_column=weight_column,
        time_series_attribute_columns=time_series_attribute_columns,
        context_window=context_window,
        export_evaluated_data_items=export_evaluated_data_items,
        export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
        export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
        quantiles=quantiles,
        enable_probabilistic_inference=enable_probabilistic_inference,
        validation_options=validation_options,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

def create_training_pipeline_forecasting_seq2seq_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    time_column: str = "date",
    time_series_identifier_column: str = "time_series_id",
    unavailable_at_forecast_columns: List[str] = [],
    available_at_forecast_columns: List[str] = [],
    forecast_horizon: int = 1,
    data_granularity_unit: str = "week",
    data_granularity_count: int = 1,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    timestamp_split_column_name: str = "timestamp_split",
    weight_column: str = "weight",
    time_series_attribute_columns: List[str] = [],
    context_window: int = 0,
    export_evaluated_data_items: bool = False,
    export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
    export_evaluated_data_items_override_destination: bool = False,
    validation_options: Optional[str] = None,
    predefined_split_column_name: Optional[str] = None,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    # Create training job
    forecasting_seq2seq_job = aiplatform.SequenceToSequencePlusForecastingTrainingJob(
        display_name=display_name, optimization_objective="minimize-rmse"
    )

    # Retrieve existing dataset
    dataset = aiplatform.TimeSeriesDataset(dataset_id)

    # Run training job
    model = forecasting_seq2seq_job.run(
        dataset=dataset,
        target_column=target_column,
        time_column=time_column,
        time_series_identifier_column=time_series_identifier_column,
        unavailable_at_forecast_columns=unavailable_at_forecast_columns,
        available_at_forecast_columns=available_at_forecast_columns,
        forecast_horizon=forecast_horizon,
        data_granularity_unit=data_granularity_unit,
        data_granularity_count=data_granularity_count,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        predefined_split_column_name=predefined_split_column_name,
        timestamp_split_column_name=timestamp_split_column_name,
        weight_column=weight_column,
        time_series_attribute_columns=time_series_attribute_columns,
        context_window=context_window,
        export_evaluated_data_items=export_evaluated_data_items,
        export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
        export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
        validation_options=validation_options,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Control the data split using REST

You can control how your training data is split between the training, validation, and test sets. Use a split column to manually specify the data split for each row and provide it as part of a PredefinedSplit Split object in the inputDataConfig of the JSON request.

DATA_SPLIT_COLUMN is the column containing the data split values (TRAIN, VALIDATION, TEST).

 "predefinedSplit": {
   "key": DATA_SPLIT_COLUMN
 },

Learn more about data splits.

Configure the rolling window strategy using REST

You can provide a windowConfig object to configure a rolling window strategy for forecast window generation. The default strategy is maxCount.

  • To use the maxCount option, add the following to trainingTaskInputs of the JSON request. MAX_COUNT_VALUE refers to the maximum number of windows.

     "windowConfig": {
       "maxCount": MAX_COUNT_VALUE
     },
     ```
    
  • To use the strideLength option, add the following to trainingTaskInputs of the JSON request. STRIDE_LENGTH_VALUE refers to the value of the stride length.

     "windowConfig": {
       "strideLength": STRIDE_LENGTH_VALUE
     },
     ```
    
  • To use the column option, add the following to trainingTaskInputs of the JSON request. COLUMN_NAME refers to the name of the column with True or False values.

     "windowConfig": {
       "column": "COLUMN_NAME"
     },
     ```
    

To learn more, see Rolling window strategies.

What's next