Crea una canalización de entrenamiento para la previsión tabular
Crea una canalización de entrenamiento para la previsión tabular con el método create_training_pipeline.
Explora más
Para obtener documentación detallada en la que se incluye esta muestra de código, consulta lo siguiente:
Muestra de código
Si deseas obtener información para instalar y usar la biblioteca cliente de Vertex AI, consulta las bibliotecas cliente de Vertex AI.
Si deseas obtener más información, consulta la documentación de referencia de la API de Vertex AI para Python.
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
def create_training_pipeline_tabular_forecasting_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    target_column: str,
    time_series_identifier_column: str,
    time_column: str,
    time_series_attribute_columns: str,
    unavailable_at_forecast: str,
    available_at_forecast: str,
    forecast_horizon: int,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    """Submit a tabular forecasting training pipeline and print the response.

    Builds a training pipeline request that points at the
    ``automl_forecasting_1.0.0.yaml`` training task definition, sends it with
    ``PipelineServiceClient.create_training_pipeline`` and prints whatever the
    service returns. Nothing is returned to the caller.

    Args:
        project: Project identifier used in the ``projects/...`` parent path.
        display_name: Display name given to the training pipeline.
        dataset_id: Dataset identifier placed in ``input_data_config``.
        model_display_name: Display name for the model to upload.
        target_column: Forwarded as the ``targetColumn`` task input.
        time_series_identifier_column: Forwarded as
            ``timeSeriesIdentifierColumn``.
        time_column: Forwarded as ``timeColumn``.
        time_series_attribute_columns: Forwarded unchanged as
            ``timeSeriesAttributeColumns``.
        unavailable_at_forecast: Forwarded unchanged as
            ``unavailableAtForecast``.
        available_at_forecast: Forwarded unchanged as ``availableAtForecast``.
        forecast_horizon: Forwarded as ``forecastHorizon``.
        location: Location segment of the parent resource path.
        api_endpoint: Endpoint handed to the client through ``client_options``.

    NOTE(review): the three ``*_columns`` / ``*_at_forecast`` parameters are
    annotated ``str`` but are presumably meant to carry lists of column names;
    confirm against the training task definition schema.
    """
    # The AI Platform services require regional API endpoints, so the endpoint
    # is passed explicitly. The client can be reused for multiple requests.
    pipeline_client = aiplatform.gapic.PipelineServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )

    # Columns used for training; each one gets an "auto" transformation entry.
    feature_columns = (
        "date",
        "state_name",
        "county_fips_code",
        "confirmed_cases",
        "deaths",
    )
    column_transformations = [
        {"auto": {"column_name": column}} for column in feature_columns
    ]

    # Keys follow the camelCase names of the training task definition YAML.
    task_inputs = {
        "targetColumn": target_column,
        "timeSeriesIdentifierColumn": time_series_identifier_column,
        "timeColumn": time_column,
        "transformations": column_transformations,
        "dataGranularity": {"unit": "day", "quantity": 1},
        "optimizationObjective": "minimize-rmse",
        "trainBudgetMilliNodeHours": 8000,
        "timeSeriesAttributeColumns": time_series_attribute_columns,
        "unavailableAtForecast": unavailable_at_forecast,
        "availableAtForecast": available_at_forecast,
        "forecastHorizon": forecast_horizon,
    }

    pipeline_spec = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
        # The task inputs are converted from a plain dict to a protobuf Value.
        "training_task_inputs": json_format.ParseDict(task_inputs, Value()),
        "input_data_config": {
            "dataset_id": dataset_id,
            # 80/10/10 split between training, validation and test data.
            "fraction_split": {
                "training_fraction": 0.8,
                "validation_fraction": 0.1,
                "test_fraction": 0.1,
            },
        },
        "model_to_upload": {"display_name": model_display_name},
    }

    response = pipeline_client.create_training_pipeline(
        parent=f"projects/{project}/locations/{location}",
        training_pipeline=pipeline_spec,
    )
    print("response:", response)
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
[{
"type": "thumb-down",
"id": "hardToUnderstand",
"label":"Difícil de entender"
},{
"type": "thumb-down",
"id": "incorrectInformationOrSampleCode",
"label":"Información o código de muestra incorrectos"
},{
"type": "thumb-down",
"id": "missingTheInformationSamplesINeed",
"label":"Faltan la información o los ejemplos que necesito"
},{
"type": "thumb-down",
"id": "translationIssue",
"label":"Problema de traducción"
},{
"type": "thumb-down",
"id": "otherDown",
"label":"Otro"
}]
[{
"type": "thumb-up",
"id": "easyToUnderstand",
"label":"Fácil de comprender"
},{
"type": "thumb-up",
"id": "solvedMyProblem",
"label":"Resolvió mi problema"
},{
"type": "thumb-up",
"id": "otherUp",
"label":"Otro"
}]