Starting on September 15, 2024, you can customize classification, entity extraction, and sentiment analysis objectives only by moving to Vertex AI Gemini prompts and tuning. Training or updating models for the Vertex AI AutoML for Text classification, entity extraction, and sentiment analysis objectives is no longer available. You can continue using existing Vertex AI AutoML Text models until June 15, 2025. For a comparison of AutoML text and Gemini, see Gemini for AutoML text users. For more information about how Gemini offers an enhanced user experience through improved prompting capabilities, see Introduction to tuning. To get started with tuning, see Model tuning for Gemini text models.
This page shows you how to train an AutoML classification model from a text
dataset using either the Google Cloud console or the Vertex AI API.
Train an AutoML model
Google Cloud console
In the Google Cloud console, in the Vertex AI section, go to
the Datasets page.
Go to the Datasets page
Click the name of the dataset you want to use to train your model to open
its details page.
Click Train new model.
For the training method, select AutoML.
Click Continue.
Enter a name for the model.
If you want to manually set how your training data is split, expand Advanced options and select a data split option. Learn more.
Click Start Training.
Model training can take many hours, depending on the size and complexity of your
data and your training budget, if you specified one. You can close this tab and
return to it later. You will receive an email when your model has completed
training.
API
Select a tab for your language or environment:
REST
Create a TrainingPipeline
object to train a model.
Before using any of the request data,
make the following replacements:
LOCATION: The region where the model will be created, such as us-central1
PROJECT: Your project ID
MODEL_DISPLAY_NAME: Name for the model as it appears in the user interface
MULTI-LABEL: A Boolean value that indicates whether Vertex AI trains a multi-label model; the default is false (single-label model)
DATASET_ID: The ID for the dataset
PROJECT_NUMBER: Your project's automatically generated project number
HTTP method and URL:
POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines
Request JSON body:
{
  "displayName": "MODEL_DISPLAY_NAME",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": MULTI-LABEL
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  }
}
To send your request, expand one of these options:
curl (Linux, macOS, or Cloud Shell)
Note: The following command assumes that you have logged in to the gcloud CLI with your user account by running gcloud init or gcloud auth login, or by using Cloud Shell, which automatically logs you into the gcloud CLI. You can check the currently active account by running gcloud auth list.
Save the request body in a file named request.json, and execute the following command:
curl -X POST \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  -H "Content-Type: application/json; charset=utf-8" \
  -d @request.json \
  "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"
PowerShell (Windows)
Save the request body in a file named request.json, and execute the following command:
$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
  -Method POST `
  -Headers $headers `
  -ContentType: "application/json; charset=utf-8" `
  -InFile request.json `
  -Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content
You should receive a JSON response similar to the following:
{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID",
  "displayName": "MODEL_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": MULTI-LABEL
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}
Java
Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.
To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.
import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTextClassificationInputs;
import com.google.rpc.Status;
import java.io.IOException;
public class CreateTrainingPipelineTextClassificationSample {
public static void main(String[] args) throws IOException {
// TODO(developer): Replace these variables before running the sample.
String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
String project = "YOUR_PROJECT_ID";
String datasetId = "YOUR_DATASET_ID";
String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
createTrainingPipelineTextClassificationSample(
project, trainingPipelineDisplayName, datasetId, modelDisplayName);
}
static void createTrainingPipelineTextClassificationSample(
String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
throws IOException {
PipelineServiceSettings pipelineServiceSettings =
PipelineServiceSettings.newBuilder()
.setEndpoint("us-central1-aiplatform.googleapis.com:443")
.build();
// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests. After completing all of your requests, call
// the "close" method on the client to safely clean up any remaining background resources.
try (PipelineServiceClient pipelineServiceClient =
PipelineServiceClient.create(pipelineServiceSettings)) {
String location = "us-central1";
String trainingTaskDefinition =
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
+ "automl_text_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);
AutoMlTextClassificationInputs trainingTaskInputs =
AutoMlTextClassificationInputs.newBuilder().setMultiLabel(false).build();
InputDataConfig trainingInputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
TrainingPipeline trainingPipeline =
TrainingPipeline.newBuilder()
.setDisplayName(trainingPipelineDisplayName)
.setTrainingTaskDefinition(trainingTaskDefinition)
.setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
.setInputDataConfig(trainingInputDataConfig)
.setModelToUpload(model)
.build();
TrainingPipeline trainingPipelineResponse =
pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
System.out.println("Create Training Pipeline Text Classification Response");
System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
System.out.format(
"\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
System.out.format(
"\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
System.out.format(
"\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
System.out.format("State: %s\n", trainingPipelineResponse.getState());
System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());
InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
System.out.println("\tInput Data Config");
System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());
FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
System.out.println("\t\tFraction Split");
System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());
FilterSplit filterSplit = inputDataConfig.getFilterSplit();
System.out.println("\t\tFilter Split");
System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());
PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
System.out.println("\t\tPredefined Split");
System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());
TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
System.out.println("\t\tTimestamp Split");
System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());
Model modelResponse = trainingPipelineResponse.getModelToUpload();
System.out.println("\tModel To Upload");
System.out.format("\t\tName: %s\n", modelResponse.getName());
System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
System.out.format(
"\t\tSupported Deployment Resources Types: %s\n",
modelResponse.getSupportedDeploymentResourcesTypesList());
System.out.format(
"\t\tSupported Input Storage Formats: %s\n",
modelResponse.getSupportedInputStorageFormatsList());
System.out.format(
"\t\tSupported Output Storage Formats: %s\n",
modelResponse.getSupportedOutputStorageFormatsList());
System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());
PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
System.out.println("\t\tPredict Schemata");
System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
System.out.format(
"\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
System.out.format(
"\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());
for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
System.out.println("\t\tSupported Export Format");
System.out.format("\t\t\tId: %s\n", exportFormat.getId());
}
ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
System.out.println("\t\tContainer Spec");
System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());
for (EnvVar envVar : modelContainerSpec.getEnvList()) {
System.out.println("\t\t\tEnv");
System.out.format("\t\t\t\tName: %s\n", envVar.getName());
System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
}
for (Port port : modelContainerSpec.getPortsList()) {
System.out.println("\t\t\tPort");
System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
}
for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
System.out.println("\t\tDeployed Model");
System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
}
Status status = trainingPipelineResponse.getError();
System.out.println("\tError");
System.out.format("\t\tCode: %s\n", status.getCode());
System.out.format("\t\tMessage: %s\n", status.getMessage());
}
}
}
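The sample trains a single-label model by calling setMultiLabel(false); to mirror the MULTI-LABEL placeholder in the REST body, pass true instead.
Training runs asynchronously, so createTrainingPipeline returns while the pipeline is still in PIPELINE_STATE_PENDING. As a minimal sketch (not an official sample), you can check progress later by fetching the pipeline by the name returned in the response; the class and variable names here are illustrative:
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import java.io.IOException;

public class GetTrainingPipelineStateSketch {
  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace with the "name" value from the create response, e.g.
    // "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID".
    String trainingPipelineName = "YOUR_TRAINING_PIPELINE_NAME";

    PipelineServiceSettings settings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
      // Fetch the pipeline and print its current state, for example
      // PIPELINE_STATE_RUNNING or PIPELINE_STATE_SUCCEEDED.
      TrainingPipeline pipeline = client.getTrainingPipeline(trainingPipelineName);
      System.out.format("State: %s%n", pipeline.getState());
    }
  }
}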
Control the data split using REST
You can control how your training data is split between the training, validation, and test sets. When using the Vertex AI API, use the Split object to determine your data split. The Split object can be included in the InputDataConfig object as one of several object types, each of which provides a different way to split the training data. You can select one method only.
FractionSplit:
TRAINING_FRACTION: The fraction of the training data to be used for the training set.
VALIDATION_FRACTION: The fraction of the training data to be used for the validation set. Not used for video data.
TEST_FRACTION: The fraction of the training data to be used for the test set.
If any of the fractions are specified, all must be specified. The fractions must add up to 1.0. The default values for the fractions differ depending on your data type. Learn more.
"fractionSplit": {
"trainingFraction": TRAINING_FRACTION ,
"validationFraction": VALIDATION_FRACTION ,
"testFraction": TEST_FRACTION
},
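If you are building the request with the Java client library from the sample above instead of raw JSON, the equivalent split can be set on InputDataConfig. This is a minimal sketch; the 80/10/10 fractions and the helper name withFractionSplit are illustrative:
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;

public class FractionSplitSketch {
  static InputDataConfig withFractionSplit(String datasetId) {
    // An 80/10/10 split; if any fraction is set, all must be set,
    // and the three fractions must sum to 1.0.
    FractionSplit fractionSplit =
        FractionSplit.newBuilder()
            .setTrainingFraction(0.8)
            .setValidationFraction(0.1)
            .setTestFraction(0.1)
            .build();
    return InputDataConfig.newBuilder()
        .setDatasetId(datasetId)
        .setFractionSplit(fractionSplit)
        .build();
  }
}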
FilterSplit:
TRAINING_FILTER: Data items that match this filter are used for the training set.
VALIDATION_FILTER: Data items that match this filter are used for the validation set. Must be "-" for video data.
TEST_FILTER: Data items that match this filter are used for the test set.
These filters can be used with the ml_use label, or with any labels you apply to your data. Learn more about using the ml_use label and other labels to filter your data.
The following example shows how to use the filterSplit object with the ml_use label, with the validation set included:
"filterSplit": {
"trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
"validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
"testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
}
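The same filter-based split can be expressed with the Java client library; a minimal sketch mirroring the JSON above (the helper name withFilterSplit is illustrative):
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;

public class FilterSplitSketch {
  static InputDataConfig withFilterSplit(String datasetId) {
    // Route data items to sets based on their ml_use label,
    // mirroring the filterSplit JSON example above.
    FilterSplit filterSplit =
        FilterSplit.newBuilder()
            .setTrainingFilter("labels.aiplatform.googleapis.com/ml_use=training")
            .setValidationFilter("labels.aiplatform.googleapis.com/ml_use=validation")
            .setTestFilter("labels.aiplatform.googleapis.com/ml_use=test")
            .build();
    return InputDataConfig.newBuilder()
        .setDatasetId(datasetId)
        .setFilterSplit(filterSplit)
        .build();
  }
}
Pass the resulting InputDataConfig to TrainingPipeline.newBuilder().setInputDataConfig(...) as in the training sample above.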