이미지 분류 모델 학습

이 페이지에서는 Google Cloud 콘솔 또는 Vertex AI API를 사용하여 이미지 데이터 세트에서 AutoML 분류 모델을 학습시키는 방법을 보여줍니다.

AutoML 모델 학습

Google Cloud 콘솔

  1. Google Cloud 콘솔의 Vertex AI 섹션에서 데이터 세트 페이지로 이동합니다.

    데이터 세트 페이지로 이동

  2. 모델을 학습시키는 데 사용할 데이터 세트의 이름을 클릭하여 세부정보 페이지를 엽니다.

  3. 새 모델 학습을 클릭합니다.

  4. 학습 메서드로 AutoML을 선택합니다.

  5. 계속을 클릭합니다.

  6. 모델의 이름을 입력합니다.

  7. 학습 데이터 분할 방법을 수동으로 설정하려면 고급 옵션을 펼치고 데이터 분할 옵션을 선택합니다. 자세히 알아보기

  8. 학습 시작을 클릭합니다.

    데이터의 규모 및 복잡성과 학습 예산(지정한 경우)에 따라 모델 학습에 많은 시간이 소요될 수 있습니다. 탭을 닫았다가 나중에 다시 돌아와도 됩니다. 모델 학습이 완료되면 이메일이 전송됩니다.

API

아래에서 목표에 대한 탭을 선택합니다.

분류

아래에서 언어 또는 환경에 대한 탭을 선택하세요.

REST

요청 데이터를 사용하기 전에 다음을 바꿉니다.

  • LOCATION: 데이터 세트가 있고 모델이 생성된 리전입니다. 예를 들면 us-central1입니다.
  • PROJECT: 프로젝트 ID
  • TRAININGPIPELINE_DISPLAYNAME: 필수. trainingPipeline의 표시 이름입니다.
  • DATASET_ID: 학습에 사용할 데이터 세트의 ID 번호입니다.
  • fractionSplit: 선택사항. 가능한 여러 ML 중 하나가 데이터에 분할 옵션을 사용합니다. fractionSplit의 경우 값은 합계가 1이어야 합니다. 예를 들면 다음과 같습니다.
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: TrainingPipeline에서 업로드(생성)한 모델의 표시 이름입니다.
  • MODEL_DESCRIPTION*: 모델에 대한 설명입니다.
  • modelToUpload.labels*: 모델을 구성할 모든 키-값 쌍 세트입니다. 예를 들면 다음과 같습니다.
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE: 학습시킬 클라우드 호스팅 모델의 유형입니다. 옵션은 다음과 같습니다.
    • CLOUD(기본)
  • NODE_HOUR_BUDGET: 실제 학습 비용은 이 값보다 작거나 같습니다. Cloud 모델의 경우 예산은 8,000~800,000밀리 노드 시간이어야 합니다(8,000, 800,000 포함). 기본값은 실제 경과 시간으로 1일을 나타내는 192,000이며, 8개의 노드가 사용되었음을 의미합니다.
  • PROJECT_NUMBER: 프로젝트의 자동으로 생성된 프로젝트 번호

HTTP 메서드 및 URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON 요청 본문:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "false",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

요청을 보내려면 다음 옵션 중 하나를 선택합니다.

curl

요청 본문을 request.json 파일에 저장하고 다음 명령어를 실행합니다.

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

요청 본문을 request.json 파일에 저장하고 다음 명령어를 실행합니다.

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

응답에는 사양 및 TRAININGPIPELINE_ID에 대한 정보가 포함됩니다.

Java

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Java 설정 안내를 따르세요. 자세한 내용은 Vertex AI Java API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Node.js 설정 안내를 따르세요. 자세한 내용은 Vertex AI Node.js API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Python을 설치하거나 업데이트하는 방법은 Python용 Vertex AI SDK 설치를 참조하세요. 자세한 내용은 Python API 참고 문서를 참조하세요.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

분류

아래에서 언어 또는 환경에 대한 탭을 선택하세요.

REST

요청 데이터를 사용하기 전에 다음을 바꿉니다.

  • LOCATION: 데이터 세트가 있고 모델이 생성된 리전입니다. 예를 들면 us-central1입니다.
  • PROJECT: 프로젝트 ID
  • TRAININGPIPELINE_DISPLAYNAME: 필수. trainingPipeline의 표시 이름입니다.
  • DATASET_ID: 학습에 사용할 데이터 세트의 ID 번호입니다.
  • fractionSplit: 선택사항. 가능한 여러 ML 중 하나가 데이터에 분할 옵션을 사용합니다. fractionSplit의 경우 값은 합계가 1이어야 합니다. 예를 들면 다음과 같습니다.
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: TrainingPipeline에서 업로드(생성)한 모델의 표시 이름입니다.
  • MODEL_DESCRIPTION*: 모델에 대한 설명입니다.
  • modelToUpload.labels*: 모델을 구성할 모든 키-값 쌍 세트입니다. 예를 들면 다음과 같습니다.
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE: 학습시킬 클라우드 호스팅 모델의 유형입니다. 옵션은 다음과 같습니다.
    • CLOUD(기본)
  • NODE_HOUR_BUDGET: 실제 학습 비용은 이 값보다 작거나 같습니다. Cloud 모델의 경우 예산은 8,000~800,000밀리 노드 시간이어야 합니다(8,000, 800,000 포함). 기본값은 실제 경과 시간으로 1일을 나타내는 192,000이며, 8개의 노드가 사용되었음을 의미합니다.
  • PROJECT_NUMBER: 프로젝트의 자동으로 생성된 프로젝트 번호

HTTP 메서드 및 URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON 요청 본문:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "true",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

요청을 보내려면 다음 옵션 중 하나를 선택합니다.

curl

요청 본문을 request.json 파일에 저장하고 다음 명령어를 실행합니다.

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

요청 본문을 request.json 파일에 저장하고 다음 명령어를 실행합니다.

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

응답에는 사양 및 TRAININGPIPELINE_ID에 대한 정보가 포함됩니다.

Java

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Java 설정 안내를 따르세요. 자세한 내용은 Vertex AI Java API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Node.js 설정 안내를 따르세요. 자세한 내용은 Vertex AI Node.js API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Python을 설치하거나 업데이트하는 방법은 Python용 Vertex AI SDK 설치를 참조하세요. 자세한 내용은 Python API 참고 문서를 참조하세요.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

REST를 사용하여 데이터 분할 제어

학습 데이터가 학습, 검증, 테스트 세트 간에 분할되는 방식을 제어할 수 있습니다. Vertex AI API를 사용할 경우 Split 객체를 사용하여 데이터 분할을 결정합니다. Split 객체는 InputConfig 객체에 여러 객체 유형 중 하나로 포함되어 있으며, 각 객체는 학습 데이터를 분할하는 다른 방법을 제공합니다. 하나의 방법만 선택할 수 있습니다.

  • FractionSplit:
    • TRAINING_FRACTION: 학습 세트에 사용할 학습 데이터의 비율입니다.
    • VALIDATION_FRACTION: 검증 세트에 사용할 학습 데이터의 비율입니다. 동영상 데이터에는 사용되지 않습니다.
    • TEST_FRACTION: 테스트 세트에 사용할 학습 데이터의 비율입니다.

    비율 중 하나라도 지정된 경우 모두 지정해야 합니다. 비율의 합은 1.0이 되어야 합니다. 비율에 대한 기본값은 데이터 유형에 따라 다릅니다. 자세히 알아보기

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit:
    • TRAINING_FILTER: 이 필터와 일치하는 데이터 항목은 학습 세트에 사용됩니다.
    • VALIDATION_FILTER: 이 필터와 일치하는 데이터 항목은 검증 세트에 사용됩니다. 동영상 데이터는 '-'여야 합니다.
    • TEST_FILTER: 이 필터와 일치하는 데이터 항목은 테스트 세트에 사용됩니다.

    이 필터는 ml_use 라벨 또는 데이터에 적용하는 모든 라벨과 함께 사용할 수 있습니다. ml-use 라벨기타 라벨을 사용하여 데이터를 필터링하는 방법을 자세히 알아보세요.

    다음 예시에서는 검증 세트가 포함된 ml_use 라벨과 함께 filterSplit 객체를 사용하는 방법을 보여줍니다.

    "filterSplit": {
    "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }