Batch text prediction with Gemini model

Performs batch text prediction using a Gemini model and returns the output location.

Explore further

For detailed documentation that includes this code sample, see the following:

Code sample

Go

Before trying this sample, follow the Go setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Go API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import (
	"context"
	"fmt"
	"io"
	"time"

	aiplatform "cloud.google.com/go/aiplatform/apiv1"
	aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/api/option"
	"google.golang.org/protobuf/types/known/structpb"
)

// batchPredictGCS submits a batch prediction job that reads its prompts from
// Cloud Storage and writes its predictions back to Cloud Storage, then polls
// the job every five seconds until it reports an end time.
//
// Progress messages are written to w. projectID and location select the parent
// resource and the regional API endpoint, inputURIs lists the input files, and
// outputURI is the Cloud Storage prefix under which results are written.
//
// It returns an error if the client cannot be created, the model parameters
// cannot be converted, or the create/get job calls fail. A job that ends in a
// non-successful state is reported on w but does not produce an error.
func batchPredictGCS(w io.Writer, projectID, location string, inputURIs []string, outputURI string) error {
	// location := "us-central1"
	// inputURIs := []string{"gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl"}
	// outputURI := "gs://<cloud-bucket-name>/<prefix-name>"
	modelName := "gemini-1.5-pro-002"
	jobName := "batch-predict-gcs-test-001"

	ctx := context.Background()
	// The job client is pointed at the endpoint derived from the requested location.
	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint))
	if err != nil {
		return fmt.Errorf("unable to create aiplatform client: %w", err)
	}
	defer client.Close()

	// Model parameters must be passed as a protobuf Value, so convert the Go map.
	modelParameters, err := structpb.NewValue(map[string]interface{}{
		"temperature":     0.2,
		"maxOutputTokens": 200,
	})
	if err != nil {
		return fmt.Errorf("unable to convert model parameters to protobuf value: %w", err)
	}

	req := &aiplatformpb.CreateBatchPredictionJobRequest{
		Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location),
		BatchPredictionJob: &aiplatformpb.BatchPredictionJob{
			DisplayName:     jobName,
			Model:           fmt.Sprintf("publishers/google/models/%s", modelName),
			ModelParameters: modelParameters,
			// Check the API reference for `BatchPredictionJob` for supported input and output formats:
			// https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.BatchPredictionJob
			InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{
				Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{
					GcsSource: &aiplatformpb.GcsSource{
						Uris: inputURIs,
					},
				},
				InstancesFormat: "jsonl",
			},
			OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{
				Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{
					GcsDestination: &aiplatformpb.GcsDestination{
						OutputUriPrefix: outputURI,
					},
				},
				PredictionsFormat: "jsonl",
			},
		},
	}

	job, err := client.CreateBatchPredictionJob(ctx, req)
	if err != nil {
		return fmt.Errorf("unable to create batch prediction job: %w", err)
	}
	// The full resource name is what GetBatchPredictionJob expects when polling.
	fullJobID := job.GetName()
	fmt.Fprintf(w, "submitted batch predict job for model %q\n", job.GetModel())
	fmt.Fprintf(w, "job id: %q\n", fullJobID)
	fmt.Fprintf(w, "job state: %s\n", job.GetState())
	// Example response:
	// submitted batch predict job for model "publishers/google/models/gemini-1.5-pro-002"
	// job id: "projects/.../locations/.../batchPredictionJobs/1234567890000000000"
	// job state: JOB_STATE_PENDING

	// Poll until the job carries an end time, which marks it as finished
	// regardless of whether it succeeded.
	for {
		time.Sleep(5 * time.Second)

		job, err = client.GetBatchPredictionJob(ctx, &aiplatformpb.GetBatchPredictionJobRequest{
			Name: fullJobID,
		})
		if err != nil {
			return fmt.Errorf("error: couldn't get updated job state: %w", err)
		}

		if job.GetEndTime() != nil {
			break
		}
		fmt.Fprintf(w, "batch predict job is running... job state is %s\n", job.GetState())
	}
	fmt.Fprintf(w, "batch predict job finished with state %s\n", job.GetState())

	// Report where the predictions were written; only a successful job is
	// expected to have produced output.
	if job.GetState() == aiplatformpb.JobState_JOB_STATE_SUCCEEDED {
		fmt.Fprintf(w, "job output location: %q\n", job.GetOutputInfo().GetGcsOutputDirectory())
	}

	return nil
}

Python

Before trying this sample, follow the Python setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Python API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import time
import vertexai

from vertexai.batch_prediction import BatchPredictionJob

# TODO(developer): Update and un-comment below line
# PROJECT_ID = "your-project-id"
# NOTE(review): `output_uri` is passed to `BatchPredictionJob.submit` below but
# is never defined in this excerpt; presumably it is the Cloud Storage prefix
# for the results and must be supplied by the developer as well -- confirm.
# output_uri = "gs://your-bucket/your-prefix"

# Initialize vertexai
vertexai.init(project=PROJECT_ID, location="us-central1")

# Input file containing the prompts to run through the model
input_uri = "gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl"

# Submit a batch prediction job with Gemini model
batch_prediction_job = BatchPredictionJob.submit(
    source_model="gemini-1.5-flash-002",
    input_dataset=input_uri,
    output_uri_prefix=output_uri,
)

# Check job status
print(f"Job resource name: {batch_prediction_job.resource_name}")
print(f"Model resource name with the job: {batch_prediction_job.model_name}")
print(f"Job state: {batch_prediction_job.state.name}")

# Refresh the job until complete, polling every 5 seconds
while not batch_prediction_job.has_ended:
    time.sleep(5)
    batch_prediction_job.refresh()

# Check if the job succeeds; on failure, print the error recorded on the job
if batch_prediction_job.has_succeeded:
    print("Job succeeded!")
else:
    print(f"Job failed: {batch_prediction_job.error}")

# Check the location of the output (printed whether or not the job succeeded)
print(f"Job output location: {batch_prediction_job.output_location}")

# Example response:
#  Job output location: gs://your-bucket/gen-ai-batch-prediction/prediction-model-year-month-day-hour:minute:second.12345

What's next

To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.