Use a custom container for prediction

To customize how Vertex AI serves online predictions from your custom-trained model, you can specify a custom container instead of a prebuilt container when you create a Model resource. When you use a custom container, Vertex AI runs a Docker container of your choice on each prediction node.

You might want to use a custom container for any of the following reasons:

  • to serve predictions from an ML model trained using a framework that isn't available as a prebuilt container
  • to preprocess prediction requests or postprocess the predictions generated by your model
  • to run a prediction server written in a programming language of your choice
  • to install dependencies that you want to use to customize prediction

This guide describes how to create a model that uses a custom container. It doesn't provide detailed instructions about designing and creating a Docker container image.

Prepare a container image

To create a Model that uses a custom container, you must provide a Docker container image as the basis of that container. This container image must meet the requirements described in Custom container requirements.

If you plan to use an existing container image created by a third party that you trust, you might be able to skip one or both of the following sections.

Create a container image

Design and build a Docker container image that meets the container image requirements.

To learn the basics of designing and building a Docker container image, read the Docker documentation's quickstart.
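The following is a minimal sketch of a prediction server you might package in such an image. It uses only the Python standard library and rests on several assumptions:

- Vertex AI passes the port and routes through the AIP_HTTP_PORT, AIP_HEALTH_ROUTE, and AIP_PREDICT_ROUTE environment variables.
- Health checks arrive as HTTP GET and prediction requests as HTTP POST.
- The predict function is a hypothetical stand-in for a real model.
- The /health and /predict fallbacks are placeholders for local runs only.

```python
import json
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

# Assumed to be set by Vertex AI; the fallbacks are placeholders for local runs only.
PORT = int(os.environ.get("AIP_HTTP_PORT", "8080"))
HEALTH_ROUTE = os.environ.get("AIP_HEALTH_ROUTE", "/health")
PREDICT_ROUTE = os.environ.get("AIP_PREDICT_ROUTE", "/predict")


def predict(instances):
    """Hypothetical stand-in for a real model: returns the sum of each numeric instance."""
    return [sum(instance) for instance in instances]


class PredictionHandler(BaseHTTPRequestHandler):
    def _send_json(self, status, payload):
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        # Health checks: reply 200 while the server is able to serve predictions.
        if self.path == HEALTH_ROUTE:
            self._send_json(200, {"status": "healthy"})
        else:
            self._send_json(404, {"error": "not found"})

    def do_POST(self):
        if self.path != PREDICT_ROUTE:
            self._send_json(404, {"error": "not found"})
            return
        length = int(self.headers.get("Content-Length", 0))
        try:
            request = json.loads(self.rfile.read(length) or b"{}")
            instances = request["instances"]
        except (ValueError, KeyError, TypeError):
            self._send_json(400, {"error": 'expected a JSON body with an "instances" list'})
            return
        self._send_json(200, {"predictions": predict(instances)})


if __name__ == "__main__":
    # Listen on all interfaces so that traffic from outside the container reaches the server.
    ThreadingHTTPServer(("0.0.0.0", PORT), PredictionHandler).serve_forever()
```

You could copy this file into the image and start it from the image's ENTRYPOINT (for example, python server.py; the file name is hypothetical). A production server would more likely use a framework that adds request validation and better concurrency.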

Push the container image to Artifact Registry

Push your container image to an Artifact Registry repository.

Learn how to push a container image to Artifact Registry.

Create a Model

To create a Model that uses a custom container, import the model, for example by using the gcloud ai models upload command or a Vertex AI client library.

The following sections show how to configure the API fields related to custom containers when you create a Model in one of these ways.

Container-related API fields

When you create the Model, make sure to configure the containerSpec field with your custom container details, rather than with a prebuilt container.

You must specify a ModelContainerSpec message in the Model.containerSpec field. Within this message, you can specify the following subfields:

imageUri (required)

The Artifact Registry URI of your container image.

If you are using the gcloud ai models upload command, you can use the --container-image-uri flag to specify this field.
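Artifact Registry Docker image URIs follow the pattern LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG. The sketch below assembles a value for imageUri (or --container-image-uri) from its parts. The helper name is hypothetical, and it does not check that the image actually exists in the repository.

```python
def artifact_registry_image_uri(
    location: str, project: str, repository: str, image: str, tag: str = "latest"
) -> str:
    """Builds LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG from its components."""
    parts = {
        "location": location,
        "project": project,
        "repository": repository,
        "image": image,
        "tag": tag,
    }
    # Reject empty components instead of producing a malformed URI.
    missing = [name for name, value in parts.items() if not value]
    if missing:
        raise ValueError(f"missing URI components: {', '.join(missing)}")
    return f"{location}-docker.pkg.dev/{project}/{repository}/{image}:{tag}"
```

For example, artifact_registry_image_uri("us-central1", "my-project", "my-repo", "my-server", "v1") returns us-central1-docker.pkg.dev/my-project/my-repo/my-server:v1.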

command (optional)

An array of an executable and arguments to override the container's ENTRYPOINT instruction. To learn more about how to format this field and how it interacts with the args field, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, you can use the --container-command flag to specify this field.

args (optional)

An array of an executable and arguments to override the container's CMD instruction. To learn more about how to format this field and how it interacts with the command field, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, you can use the --container-args flag to specify this field.
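The following sketch models how command and args combine with the image's ENTRYPOINT and CMD. It follows our reading of the ModelContainerSpec reference, which matches the Kubernetes command/args pattern:

- command replaces ENTRYPOINT, and the image's CMD is then ignored.
- args alone replace CMD.
- If neither field is set, Docker's defaults apply.

Treat this as an approximation and confirm against the API reference. It assumes exec-form instructions and does not model shell form. The function name is hypothetical.

```python
from typing import List, Optional, Sequence


def effective_argv(
    entrypoint: Sequence[str],
    cmd: Sequence[str],
    command: Optional[Sequence[str]] = None,
    args: Optional[Sequence[str]] = None,
) -> List[str]:
    """Returns the argv that would run when the container starts (exec form only).

    entrypoint and cmd come from the image; command and args come from ModelContainerSpec.
    An empty list is treated the same as an unset field.
    """
    if command:
        # command overrides ENTRYPOINT; the image's CMD is ignored, and args (if any) follow.
        return list(command) + list(args or [])
    if args:
        # args override CMD and are appended to the image's ENTRYPOINT.
        return list(entrypoint) + list(args)
    # Neither field set: the image's ENTRYPOINT and CMD run as Docker would run them.
    return list(entrypoint) + list(cmd)
```

For an image with ENTRYPOINT ["python", "server.py"] and CMD ["--port", "8080"], setting only args to ["--port", "9000"] yields python server.py --port 9000 under these rules.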

ports (optional)

An array of ports; Vertex AI sends liveness checks, health checks, and prediction requests to your container on the first port listed, or 8080 by default. Specifying additional ports has no effect.

If you are using the gcloud ai models upload command, you can use the --container-ports flag to specify this field.
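The port rule above is small enough to state as code: only the first listed port matters, and 8080 applies when none is listed. This is a sketch with a hypothetical function name. We also assume the resolved value reaches the container as the AIP_HTTP_PORT environment variable, so the server should read that variable rather than hard-code a port.

```python
from typing import Sequence

DEFAULT_PORT = 8080


def serving_port(ports: Sequence[int] = ()) -> int:
    """Port that receives liveness checks, health checks, and prediction requests."""
    # Only the first listed port is used; additional ports have no effect.
    return ports[0] if ports else DEFAULT_PORT
```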

env (optional)

An array of environment variables that the container's ENTRYPOINT instruction, as well as the command and args fields, can reference. To learn more about how other fields can reference these environment variables, read the API reference for ModelContainerSpec.

If you are using the gcloud ai models upload command, you can use the --container-env-vars flag to specify this field.
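As we read the ModelContainerSpec reference, command and args can reference entries from env with $(VARIABLE_NAME) syntax. An unresolved reference is left unchanged, and $$(VARIABLE_NAME) escapes the expansion.

The sketch below approximates that expansion so you can preview the resulting arguments locally. It assumes variable names made of letters, digits, and underscores. It also converts a dict into the name/value objects used by the env field in a REST request body. Both function names are hypothetical.

```python
import re
from typing import Dict, List, Mapping

# Matches either "$$" (an escaped dollar sign) or a "$(NAME)" reference.
_REF = re.compile(r"\$(\$|\(([A-Za-z_][A-Za-z0-9_]*)\))")


def expand_env_refs(value: str, env: Mapping[str, str]) -> str:
    """Expands $(NAME) references; unresolved references are kept as written."""

    def replace(match: "re.Match[str]") -> str:
        if match.group(1) == "$":
            # "$$" collapses to "$", so "$$(NAME)" yields the literal text "$(NAME)".
            return "$"
        return env.get(match.group(2), match.group(0))

    return _REF.sub(replace, value)


def to_env_list(env: Mapping[str, str]) -> List[Dict[str, str]]:
    """Converts {"NAME": "value"} into the [{"name": ..., "value": ...}] form of the env field."""
    return [{"name": name, "value": value} for name, value in env.items()]
```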

healthRoute (optional)

The path on your container's HTTP server where you want Vertex AI to send health checks.

If you don't specify this field, it defaults to /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL when you deploy the Model as a DeployedModel to an Endpoint resource. ENDPOINT is the last segment of the Endpoint's name field (following endpoints/), and DEPLOYED_MODEL is the DeployedModel's id field.

If you are using the gcloud ai models upload command, you can use the --container-health-route flag to specify this field.

predictRoute (optional)

The path on your container's HTTP server where you want Vertex AI to forward prediction requests.

If you don't specify this field, it defaults to /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict when you deploy the Model as a DeployedModel to an Endpoint resource. ENDPOINT is the last segment of the Endpoint's name field (following endpoints/), and DEPLOYED_MODEL is the DeployedModel's id field.

If you are using the gcloud ai models upload command, you can use the --container-predict-route flag to specify this field.
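The default health and prediction routes described above differ only by the :predict suffix. The following sketch (function name hypothetical) computes both from an Endpoint resource name and a DeployedModel ID.

```python
from typing import Tuple


def default_routes(endpoint_name: str, deployed_model_id: str) -> Tuple[str, str]:
    """Returns (healthRoute, predictRoute) used when containerSpec leaves them unset."""
    # ENDPOINT is the last segment of the Endpoint's resource name, after "endpoints/".
    endpoint_id = endpoint_name.rstrip("/").rsplit("/", 1)[-1]
    health_route = f"/v1/endpoints/{endpoint_id}/deployedModels/{deployed_model_id}"
    return health_route, health_route + ":predict"
```

Because these defaults depend on IDs that exist only after deployment, a server is better off reading its routes at startup than hard-coding them. We assume the AIP_HEALTH_ROUTE and AIP_PREDICT_ROUTE environment variables carry the resolved values.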

sharedMemorySizeMb (optional)

The amount of VM memory, in megabytes, to reserve in a shared memory volume for the model.

Shared memory is an inter-process communication (IPC) mechanism that allows multiple processes to access and manipulate a common block of memory. How much shared memory your container needs, if any, depends on your container and model; consult your model server's documentation for guidelines.

If you are using the gcloud ai models upload command, you can use the --container-shared-memory-size-mb flag to specify this field.
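To illustrate the IPC mechanism described above, the following sketch uses Python's standard multiprocessing.shared_memory module. It creates a named block and then attaches to it by name, as a second worker process would.

On Linux, such blocks are backed by /dev/shm. We assume that is the volume sharedMemorySizeMb sizes, so blocks that exceed the reserved amount could fail. The function names are hypothetical.

```python
from multiprocessing import shared_memory


def share_bytes(payload: bytes) -> shared_memory.SharedMemory:
    """Copies a non-empty payload into a new named shared memory block and returns it."""
    block = shared_memory.SharedMemory(create=True, size=len(payload))
    block.buf[: len(payload)] = payload
    return block


def read_shared(name: str, size: int) -> bytes:
    """Attaches to an existing block by name, as another process would, and copies out its data."""
    block = shared_memory.SharedMemory(name=name)
    try:
        return bytes(block.buf[:size])
    finally:
        # Detach without destroying the block; only its creator should call unlink().
        block.close()
```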

startupProbe (optional)

Specification for the probe that checks whether the container application has started.

If you are using the gcloud ai models upload command, you can use the --container-startup-probe-exec, --container-startup-probe-period-seconds, and --container-startup-probe-timeout-seconds flags to specify this field.
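An exec-style startup probe runs a command inside the container. As in Kubernetes, we assume exit code 0 means the application has started and any other code means it has not.

The sketch below could serve as such a command. It assumes your server writes a marker file once the model finishes loading. The marker path, the MODEL_READY_MARKER variable, and the script itself are all hypothetical.

```python
import os
import sys

# Hypothetical marker file that the prediction server writes after loading the model.
READY_MARKER = os.environ.get("MODEL_READY_MARKER", "/tmp/model_loaded")


def has_started(marker: str = READY_MARKER) -> bool:
    """Returns True once the server has signaled that startup is complete."""
    return os.path.exists(marker)


def main() -> int:
    # Exit code 0 reports success to the probe; any non-zero code reports failure.
    return 0 if has_started() else 1


if __name__ == "__main__":
    sys.exit(main())
```

Choose the probe period and timeout so that the command comfortably finishes within each timeout.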

healthProbe (optional)

Specification for the probe that checks whether a container is ready to accept traffic.

If you are using the gcloud ai models upload command, you can use the --container-health-probe-exec, --container-health-probe-period-seconds, and --container-health-probe-timeout-seconds flags to specify this field.
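A readiness check can also be written as an exec command that queries the server's own health route and exits 0 only on an HTTP 200 response. The following sketch assumes that convention.

The default URL is a hypothetical placeholder; pass the port and route your server actually uses. Keep the request timeout shorter than the probe timeout you configure.

```python
import sys
import urllib.request


def is_ready(url: str, timeout_seconds: float = 2.0) -> bool:
    """Returns True if the health route answers with HTTP 200 within the timeout."""
    try:
        with urllib.request.urlopen(url, timeout=timeout_seconds) as response:
            return response.status == 200
    except OSError:
        # URLError, HTTPError (non-2xx), refused connections, and timeouts all mean "not ready".
        return False


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "http://127.0.0.1:8080/health"
    sys.exit(0 if is_ready(target) else 1)
```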

In addition to the variables that you set in the Model.containerSpec.env field, Vertex AI sets several other variables based on your configuration. Learn more about using these environment variables in these fields and in the container's ENTRYPOINT instruction.
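The following sketch shows how a server might gather the variables Vertex AI sets into one configuration object at startup. It covers AIP_HTTP_PORT, AIP_HEALTH_ROUTE, AIP_PREDICT_ROUTE, and AIP_STORAGE_URI. We believe AIP_STORAGE_URI is present only when the Model has an artifactUri; check the environment variable reference.

The class name, function name, and fallback values are hypothetical. Passing a plain dict instead of os.environ makes the loader easy to test.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class ServingConfig:
    http_port: int
    health_route: str
    predict_route: str
    storage_uri: Optional[str]


def load_serving_config(env: Mapping[str, str] = os.environ) -> ServingConfig:
    """Reads the AIP_* variables; the fallbacks are placeholders for local runs."""
    return ServingConfig(
        http_port=int(env.get("AIP_HTTP_PORT", "8080")),
        health_route=env.get("AIP_HEALTH_ROUTE", "/health"),
        predict_route=env.get("AIP_PREDICT_ROUTE", "/predict"),
        # Cloud Storage location of the model artifacts, if Vertex AI provided one.
        storage_uri=env.get("AIP_STORAGE_URI"),
    )
```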

Model import examples

The following examples show how to specify container-related API fields when you import a model.

gcloud

The following example uses the gcloud ai models upload command:

gcloud ai models upload \
  --region=LOCATION \
  --display-name=MODEL_NAME \
  --container-image-uri=IMAGE_URI \
  --container-command=COMMAND \
  --container-args=ARGS \
  --container-ports=PORTS \
  --container-env-vars=ENV \
  --container-health-route=HEALTH_ROUTE \
  --container-predict-route=PREDICT_ROUTE \
  --container-shared-memory-size-mb=SHARED_MEMORY_SIZE \
  --container-startup-probe-exec=STARTUP_PROBE_EXEC \
  --container-startup-probe-period-seconds=STARTUP_PROBE_PERIOD \
  --container-startup-probe-timeout-seconds=STARTUP_PROBE_TIMEOUT \
  --container-health-probe-exec=HEALTH_PROBE_EXEC \
  --container-health-probe-period-seconds=HEALTH_PROBE_PERIOD \
  --container-health-probe-timeout-seconds=HEALTH_PROBE_TIMEOUT \
  --artifact-uri=PATH_TO_MODEL_ARTIFACT_DIRECTORY

The --container-image-uri flag is required; all other flags that begin with --container- are optional. To learn about the values for these fields, see the preceding section of this guide.
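If you script model uploads, you can build the gcloud command as an argument list and emit optional --container- flags only when a value is set. The following sketch does that with a hypothetical helper name and uses only the flags listed above.

It assumes list-valued flags take comma-separated values, so individual items must not contain commas. Check gcloud ai models upload --help for the exact formats.

```python
import shlex
from typing import Dict, List, Optional, Sequence


def gcloud_upload_argv(
    region: str,
    display_name: str,
    image_uri: str,
    artifact_uri: Optional[str] = None,
    command: Sequence[str] = (),
    args: Sequence[str] = (),
    ports: Sequence[int] = (),
    env_vars: Optional[Dict[str, str]] = None,
    health_route: Optional[str] = None,
    predict_route: Optional[str] = None,
) -> List[str]:
    """Builds an argv list for gcloud ai models upload, omitting unset optional flags."""
    argv = [
        "gcloud", "ai", "models", "upload",
        f"--region={region}",
        f"--display-name={display_name}",
        f"--container-image-uri={image_uri}",  # the only required --container- flag
    ]
    # Assumption: list-valued flags accept comma-separated values.
    if command:
        argv.append("--container-command=" + ",".join(command))
    if args:
        argv.append("--container-args=" + ",".join(args))
    if ports:
        argv.append("--container-ports=" + ",".join(str(p) for p in ports))
    if env_vars:
        argv.append("--container-env-vars=" + ",".join(f"{k}={v}" for k, v in env_vars.items()))
    if health_route:
        argv.append(f"--container-health-route={health_route}")
    if predict_route:
        argv.append(f"--container-predict-route={predict_route}")
    if artifact_uri:
        argv.append(f"--artifact-uri={artifact_uri}")
    return argv


if __name__ == "__main__":
    # Print a shell-quoted command; pass the list to subprocess.run(argv, check=True) to execute it.
    print(shlex.join(gcloud_upload_argv("us-central1", "my-model", "IMAGE_URI", ports=[7080])))
```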

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.


import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.ModelServiceClient;
import com.google.cloud.aiplatform.v1.ModelServiceSettings;
import com.google.cloud.aiplatform.v1.UploadModelOperationMetadata;
import com.google.cloud.aiplatform.v1.UploadModelResponse;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class UploadModelSample {
  public static void main(String[] args)
      throws InterruptedException, ExecutionException, TimeoutException, IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    String metadataSchemaUri =
        "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml";
    String imageUri = "YOUR_IMAGE_URI";
    String artifactUri = "gs://your-gcs-bucket/artifact_path";
    uploadModel(project, modelDisplayName, metadataSchemaUri, imageUri, artifactUri);
  }

  static void uploadModel(
      String project,
      String modelDisplayName,
      String metadataSchemaUri,
      String imageUri,
      String artifactUri)
      throws IOException, InterruptedException, ExecutionException, TimeoutException {
    ModelServiceSettings modelServiceSettings =
        ModelServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (ModelServiceClient modelServiceClient = ModelServiceClient.create(modelServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);

      ModelContainerSpec modelContainerSpec =
          ModelContainerSpec.newBuilder().setImageUri(imageUri).build();

      Model model =
          Model.newBuilder()
              .setDisplayName(modelDisplayName)
              .setMetadataSchemaUri(metadataSchemaUri)
              .setArtifactUri(artifactUri)
              .setContainerSpec(modelContainerSpec)
              .build();

      OperationFuture<UploadModelResponse, UploadModelOperationMetadata> uploadModelResponseFuture =
          modelServiceClient.uploadModelAsync(locationName, model);
      System.out.format(
          "Operation name: %s\n", uploadModelResponseFuture.getInitialFuture().get().getName());
      System.out.println("Waiting for operation to finish...");
      UploadModelResponse uploadModelResponse = uploadModelResponseFuture.get(5, TimeUnit.MINUTES);

      System.out.println("Upload Model Response");
      System.out.format("Model: %s\n", uploadModelResponse.getModel());
    }
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */

// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const metadataSchemaUri = 'YOUR_METADATA_SCHEMA_URI';
// const imageUri = 'YOUR_IMAGE_URI';
// const artifactUri = 'YOUR_ARTIFACT_URI';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';

// Imports the Google Cloud Model Service Client library
const {ModelServiceClient} = require('@google-cloud/aiplatform');

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: `${location}-aiplatform.googleapis.com`,
};

// Instantiates a client
const modelServiceClient = new ModelServiceClient(clientOptions);

async function uploadModel() {
  // Configure the parent resources
  const parent = `projects/${project}/locations/${location}`;
  // Configure the model resources
  const model = {
    displayName: modelDisplayName,
    metadataSchemaUri: '',
    artifactUri: artifactUri,
    containerSpec: {
      imageUri: imageUri,
      command: [],
      args: [],
      env: [],
      ports: [],
      predictRoute: '',
      healthRoute: '',
    },
  };
  const request = {
    parent,
    model,
  };

  console.log('PARENT AND MODEL');
  console.log(parent, model);
  // Upload Model request
  const [response] = await modelServiceClient.uploadModel(request);
  console.log(`Long running operation : ${response.name}`);

  // Wait for operation to complete
  await response.promise();
  const result = response.result;

  console.log('Upload model response ');
  console.log(`\tModel : ${result.model}`);
}
uploadModel();

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

from typing import Dict, Optional, Sequence

from google.cloud import aiplatform
from google.cloud.aiplatform import explain


def upload_model_sample(
    project: str,
    location: str,
    display_name: str,
    serving_container_image_uri: str,
    artifact_uri: Optional[str] = None,
    serving_container_predict_route: Optional[str] = None,
    serving_container_health_route: Optional[str] = None,
    description: Optional[str] = None,
    serving_container_command: Optional[Sequence[str]] = None,
    serving_container_args: Optional[Sequence[str]] = None,
    serving_container_environment_variables: Optional[Dict[str, str]] = None,
    serving_container_ports: Optional[Sequence[int]] = None,
    instance_schema_uri: Optional[str] = None,
    parameters_schema_uri: Optional[str] = None,
    prediction_schema_uri: Optional[str] = None,
    explanation_metadata: Optional[explain.ExplanationMetadata] = None,
    explanation_parameters: Optional[explain.ExplanationParameters] = None,
    sync: bool = True,
):

    aiplatform.init(project=project, location=location)

    model = aiplatform.Model.upload(
        display_name=display_name,
        artifact_uri=artifact_uri,
        serving_container_image_uri=serving_container_image_uri,
        serving_container_predict_route=serving_container_predict_route,
        serving_container_health_route=serving_container_health_route,
        instance_schema_uri=instance_schema_uri,
        parameters_schema_uri=parameters_schema_uri,
        prediction_schema_uri=prediction_schema_uri,
        description=description,
        serving_container_command=serving_container_command,
        serving_container_args=serving_container_args,
        serving_container_environment_variables=serving_container_environment_variables,
        serving_container_ports=serving_container_ports,
        explanation_metadata=explanation_metadata,
        explanation_parameters=explanation_parameters,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    return model

For more context, read the Model import guide.

Send prediction requests

To send an online prediction request to your Model, follow the instructions at Get predictions from a custom trained model: this process works the same regardless of whether you use a custom container.

Read about predict request and response requirements for custom containers.
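As we understand the documented format, a custom container receives a JSON body of the form {"instances": [...], "parameters": {...}}, where parameters is optional. It must answer with a JSON object that has a predictions field.

The following sketch handles that translation around a model function you supply. The function name and the example model are hypothetical.

```python
import json
from typing import Any, Callable, Dict, List

PredictFn = Callable[[List[Any], Dict[str, Any]], List[Any]]


def handle_predict_body(raw: bytes, predict_fn: PredictFn) -> bytes:
    """Turns a prediction request body into a response body.

    Request:  {"instances": [...], "parameters": {...}}  (parameters optional)
    Response: {"predictions": [...]}
    """
    request = json.loads(raw)
    instances = request.get("instances") if isinstance(request, dict) else None
    if not isinstance(instances, list):
        raise ValueError('request body must contain an "instances" list')
    parameters = request.get("parameters") or {}
    predictions = predict_fn(instances, parameters)
    return json.dumps({"predictions": predictions}).encode("utf-8")
```

For example, with a model function that multiplies each instance by parameters["scale"], the body {"instances": [1, 2], "parameters": {"scale": 10}} produces {"predictions": [10, 20]}.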
