Use a custom container for prediction

To customize how Vertex AI serves online predictions from your custom-trained model, you can specify a custom container instead of a pre-built container when you create a Model resource. When you use a custom container, Vertex AI runs a Docker container of your choice on each prediction node.

You might want to use a custom container for any of the following reasons:

  • to serve predictions from an ML model trained using a framework other than TensorFlow, scikit-learn, or XGBoost
  • to preprocess prediction requests or postprocess the predictions generated by your model
  • to run a prediction server written in a programming language of your choice
  • to install dependencies that you want to use to customize prediction

This guide describes how to create a Model that uses a custom container. It does not provide detailed instructions about designing and creating a Docker container image.

Prepare a container image

To create a Model that uses a custom container, you must provide a Docker container image as the basis of that container. This container image must meet the requirements described in Custom container requirements.

If you plan to use an existing container image created by a third party that you trust, then you might be able to skip one or both of the following sections.

Create a container image

Design and build a Docker container image that meets the container image requirements.

To learn the basics of designing and building a Docker container image, read the Docker documentation's quickstart.

Push the container image to Artifact Registry or Container Registry

Push your container image to an Artifact Registry repository or a Container Registry repository that meets the container image publishing requirements.

Learn how to push a container image to Artifact Registry or push a container image to Container Registry.

Create a Model

To create a Model that uses a custom container, import your model by using the Google Cloud CLI or the Vertex AI API.

The following sections show how to configure the API fields related to custom containers when you create a Model in one of these ways.

Container-related API fields

When you create the Model, make sure to configure the containerSpec field with your custom container details, rather than with a pre-built container.

You must specify a ModelContainerSpec message in the Model.containerSpec field. Within this message, you can specify the following subfields:

imageUri (required)

The Artifact Registry or Container Registry URI of your container image.

If you are using the gcloud ai models upload command, then you can use the --container-image-uri flag to specify this field.

command (optional)

An array of an executable and arguments to override the container's ENTRYPOINT. To learn more about how to format this field and how it interacts with the args field, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, then you can use the --container-command flag to specify this field.

args (optional)

An array of an executable and arguments to override the container's CMD. To learn more about how to format this field and how it interacts with the command field, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, then you can use the --container-args flag to specify this field.
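As a sketch of how the command and args fields fit together (the image URI and server script below are hypothetical placeholders, not values from this guide): command replaces the image's ENTRYPOINT, args replaces its CMD, and when both are set, args is appended as arguments to command.

```python
# Hypothetical containerSpec fragment illustrating command/args semantics.
container_spec = {
    "imageUri": "us-central1-docker.pkg.dev/my-project/my-repo/my-server",
    # Replaces the image's ENTRYPOINT.
    "command": ["python", "server.py"],
    # Replaces the image's CMD; appended as arguments to command.
    "args": ["--port", "8080"],
}

# The effective container start command:
start_command = container_spec["command"] + container_spec["args"]
print(" ".join(start_command))  # python server.py --port 8080
```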

ports (optional)

An array of ports; Vertex AI sends liveness checks, health checks, and prediction requests to your container on the first port listed, or 8080 by default. Specifying additional ports has no effect.

If you are using the gcloud ai models upload command, then you can use the --container-ports flag to specify this field.

env (optional)

An array of environment variables that the container's entrypoint command, as well as the command and args fields, can reference. To learn more about how other fields can reference these environment variables, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, then you can use the --container-env-vars flag to specify this field.

healthRoute (optional)

The path on your container's HTTP server where you want Vertex AI to send health checks.

If you don't specify this field, then when you deploy the Model as a DeployedModel to an Endpoint resource it defaults to /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL, where ENDPOINT is replaced by the last segment of the Endpoint's name field (following endpoints/) and DEPLOYED_MODEL is replaced by the DeployedModel's id field.

If you are using the gcloud ai models upload command, then you can use the --container-health-route flag to specify this field.

predictRoute (optional)

The path on your container's HTTP server where you want Vertex AI to forward prediction requests.

If you don't specify this field, then when you deploy the Model as a DeployedModel to an Endpoint resource it defaults to /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL, where ENDPOINT is replaced by the last segment of the Endpoint's name field (following endpoints/) and DEPLOYED_MODEL is replaced by the DeployedModel's id field.

If you are using the gcloud ai models upload command, then you can use the --container-predict-route flag to specify this field.
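The healthRoute and predictRoute defaults follow the same pattern. As an illustrative sketch (the endpoint name and deployed model ID below are hypothetical), the default route can be derived like this:

```python
def default_route(endpoint_name: str, deployed_model_id: str) -> str:
    # Take the last segment of the Endpoint's resource name,
    # i.e. everything after the final "endpoints/".
    endpoint_id = endpoint_name.rsplit("endpoints/", 1)[-1]
    return f"/v1/endpoints/{endpoint_id}/deployedModels/{deployed_model_id}"

route = default_route(
    "projects/123/locations/us-central1/endpoints/456", "789"
)
print(route)  # /v1/endpoints/456/deployedModels/789
```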

In addition to the variables that you set in the Model.containerSpec.env field, Vertex AI sets several other variables based on your configuration. Learn more about using these environment variables in these fields and in the container's entrypoint command.
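For example, a prediction server can read the AIP_* environment variables that Vertex AI sets to discover which port and routes to serve on. The fallback values in this sketch are assumptions for local testing only; in a deployed container, Vertex AI supplies the real values.

```python
import os

# Vertex AI sets these variables in the container at deployment time.
# The fallbacks here are only assumptions for running the server locally.
http_port = int(os.environ.get("AIP_HTTP_PORT", "8080"))
health_route = os.environ.get("AIP_HEALTH_ROUTE", "/health")
predict_route = os.environ.get("AIP_PREDICT_ROUTE", "/predict")

print(f"listening on :{http_port}, health={health_route}, predict={predict_route}")
```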

Model import examples

The following examples show how to specify container-related API fields when you import a model.

gcloud

The following example uses the gcloud ai models upload command:

gcloud ai models upload \
  --region=LOCATION \
  --display-name=MODEL_NAME \
  --container-image-uri=IMAGE_URI \
  --container-command=COMMAND \
  --container-args=ARGS \
  --container-ports=PORTS \
  --container-env-vars=ENV \
  --container-health-route=HEALTH_ROUTE \
  --container-predict-route=PREDICT_ROUTE \
  --artifact-uri=PATH_TO_MODEL_ARTIFACT_DIRECTORY

The --container-image-uri flag is required; all other flags that begin with --container- are optional. To learn about the values for these fields, see the preceding section of this guide.

Java


import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.ModelServiceClient;
import com.google.cloud.aiplatform.v1.ModelServiceSettings;
import com.google.cloud.aiplatform.v1.UploadModelOperationMetadata;
import com.google.cloud.aiplatform.v1.UploadModelResponse;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class UploadModelSample {
  public static void main(String[] args)
      throws InterruptedException, ExecutionException, TimeoutException, IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    String metadataSchemaUri =
        "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml";
    String imageUri = "YOUR_IMAGE_URI";
    String artifactUri = "gs://your-gcs-bucket/artifact_path";
    uploadModel(project, modelDisplayName, metadataSchemaUri, imageUri, artifactUri);
  }

  static void uploadModel(
      String project,
      String modelDisplayName,
      String metadataSchemaUri,
      String imageUri,
      String artifactUri)
      throws IOException, InterruptedException, ExecutionException, TimeoutException {
    ModelServiceSettings modelServiceSettings =
        ModelServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);

      ModelContainerSpec modelContainerSpec =
          ModelContainerSpec.newBuilder().setImageUri(imageUri).build();

      Model model =
          Model.newBuilder()
              .setDisplayName(modelDisplayName)
              .setMetadataSchemaUri(metadataSchemaUri)
              .setArtifactUri(artifactUri)
              .setContainerSpec(modelContainerSpec)
              .build();

      OperationFuture<UploadModelResponse, UploadModelOperationMetadata> uploadModelResponseFuture =
          modelServiceClient.uploadModelAsync(locationName, model);
      System.out.format(
          "Operation name: %s\n", uploadModelResponseFuture.getInitialFuture().get().getName());
      System.out.println("Waiting for operation to finish...");
      UploadModelResponse uploadModelResponse = uploadModelResponseFuture.get(5, TimeUnit.MINUTES);

      System.out.println("Upload Model Response");
      System.out.format("Model: %s\n", uploadModelResponse.getModel());
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */

// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const metadataSchemaUri = 'YOUR_METADATA_SCHEMA_URI';
// const imageUri = 'YOUR_IMAGE_URI';
// const artifactUri = 'YOUR_ARTIFACT_URI';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';

// Imports the Google Cloud Model Service Client library
const {ModelServiceClient} = require('@google-cloud/aiplatform');

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const modelServiceClient = new ModelServiceClient(clientOptions);

async function uploadModel() {
  // Configure the parent resources
  const parent = `projects/${project}/locations/${location}`;
  // Configure the model resources
  const model = {
    displayName: modelDisplayName,
    metadataSchemaUri: metadataSchemaUri,
    artifactUri: artifactUri,
    containerSpec: {
      imageUri: imageUri,
      command: [],
      args: [],
      env: [],
      ports: [],
      predictRoute: '',
      healthRoute: '',
    },
  };
  const request = {
    parent,
    model,
  };

  console.log('PARENT AND MODEL');
  console.log(parent, model);
  // Upload Model request
  const [response] = await modelServiceClient.uploadModel(request);
  console.log(`Long running operation : ${response.name}`);

  // Wait for operation to complete
  await response.promise();
  const result = response.result;

  console.log('Upload model response ');
  console.log(`\tModel : ${result.model}`);
}
uploadModel();

Python

from typing import Dict, Optional, Sequence

from google.cloud import aiplatform
from google.cloud.aiplatform import explain


def upload_model_sample(
    project: str,
    location: str,
    display_name: str,
    serving_container_image_uri: str,
    artifact_uri: Optional[str] = None,
    serving_container_predict_route: Optional[str] = None,
    serving_container_health_route: Optional[str] = None,
    description: Optional[str] = None,
    serving_container_command: Optional[Sequence[str]] = None,
    serving_container_args: Optional[Sequence[str]] = None,
    serving_container_environment_variables: Optional[Dict[str, str]] = None,
    serving_container_ports: Optional[Sequence[int]] = None,
    instance_schema_uri: Optional[str] = None,
    parameters_schema_uri: Optional[str] = None,
    prediction_schema_uri: Optional[str] = None,
    explanation_metadata: Optional[explain.ExplanationMetadata] = None,
    explanation_parameters: Optional[explain.ExplanationParameters] = None,
    sync: bool = True,
):

    aiplatform.init(project=project, location=location)

    model = aiplatform.Model.upload(
        display_name=display_name,
        artifact_uri=artifact_uri,
        serving_container_image_uri=serving_container_image_uri,
        serving_container_predict_route=serving_container_predict_route,
        serving_container_health_route=serving_container_health_route,
        instance_schema_uri=instance_schema_uri,
        parameters_schema_uri=parameters_schema_uri,
        prediction_schema_uri=prediction_schema_uri,
        description=description,
        serving_container_command=serving_container_command,
        serving_container_args=serving_container_args,
        serving_container_environment_variables=serving_container_environment_variables,
        serving_container_ports=serving_container_ports,
        explanation_metadata=explanation_metadata,
        explanation_parameters=explanation_parameters,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    return model

For more context, read the Model import guide.

Send prediction requests

To send an online prediction request to your Model, follow the guides to deploy your Model to an Endpoint and to get online predictions from custom-trained models. This process works the same regardless of whether you use a custom container.

Read about predict request and response requirements for custom containers.

What's next