Previsione del testo in batch con un modello preaddestrato

Esegui la previsione del testo in batch utilizzando un modello di generazione di testo preaddestrato.

Esempio di codice

GoJava

Prima di provare questo esempio, segui le istruzioni di configurazione Go riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Go di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configura l'autenticazione per un ambiente di sviluppo locale.

import (
	"context"
	"fmt"
	"io"

	aiplatform "cloud.google.com/go/aiplatform/apiv1"
	aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
	"google.golang.org/api/option"
	"google.golang.org/protobuf/types/known/structpb"
)

// batchTextPredict perform batch text prediction using a pre-trained text generation model
func batchTextPredict(w io.Writer, projectID, location, name, outputURI string, inputURIs []string) error {
	// inputURI := []string{"gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl"}
	// outputURI: existing template path. Following formats are allowed:
	// 	- gs://BUCKET_NAME/DIRECTORY/
	// 	- bq://project_name.llm_dataset

	ctx := context.Background()
	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// Pretrained text model
	model := "publishers/google/models/text-bison"
	parameters := map[string]interface{}{
		"temperature":     0.2,
		"maxOutputTokens": 200,
		"topP":            0.95,
		"topK":            40,
	}

	parametersValue, err := structpb.NewValue(parameters)
	if err != nil {
		fmt.Fprintf(w, "unable to convert parameters to Value: %v", err)
		return err
	}

	client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint))
	if err != nil {
		return err
	}
	defer client.Close()

	req := &aiplatformpb.CreateBatchPredictionJobRequest{
		Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location),
		BatchPredictionJob: &aiplatformpb.BatchPredictionJob{
			DisplayName:     name,
			Model:           model,
			ModelParameters: parametersValue,
			InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{
				Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{
					GcsSource: &aiplatformpb.GcsSource{
						Uris: inputURIs,
					},
				},
				// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
				InstancesFormat: "jsonl",
			},
			OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{
				Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{
					GcsDestination: &aiplatformpb.GcsDestination{
						OutputUriPrefix: outputURI,
					},
				},
				// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
				PredictionsFormat: "jsonl",
			},
		},
	}

	job, err := client.CreateBatchPredictionJob(ctx, req)
	if err != nil {
		return err
	}
	fmt.Fprint(w, job.GetDisplayName())

	return nil
}

Prima di provare questo esempio, segui le istruzioni di configurazione Java riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Java di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configura l'autenticazione per un ambiente di sviluppo locale.


import com.google.cloud.aiplatform.v1.BatchPredictionJob;
import com.google.cloud.aiplatform.v1.GcsDestination;
import com.google.cloud.aiplatform.v1.GcsSource;
import com.google.cloud.aiplatform.v1.JobServiceClient;
import com.google.cloud.aiplatform.v1.JobServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

public class BatchTextPredictionSample {

  public static void main(String[] args)
      throws IOException, InterruptedException, ExecutionException, TimeoutException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String location = "us-central1";
    // inputUri: URI of the input dataset.
    // Could be a BigQuery table or a Google Cloud Storage file.
    // E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
    String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl";
    // outputUri: URI where the output will be stored.
    // Could be a BigQuery table or a Google Cloud Storage file.
    // E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
    String outputUri = "gs://YOUR_BUCKET/batch_text_predict_output";
    String textModel = "text-bison";

    batchTextPrediction(project, inputUri, outputUri, textModel, location);
  }

  // Perform batch text prediction using a pre-trained text generation model.
  // Example of using Google Cloud Storage bucket as the input and output data source
  static BatchPredictionJob batchTextPrediction(
      String projectId, String inputUri, String outputUri, String textModel, String location)
      throws IOException {
    BatchPredictionJob response;
    JobServiceSettings jobServiceSettings =  JobServiceSettings.newBuilder()
        .setEndpoint("us-central1-aiplatform.googleapis.com:443").build();
    String parent = String.format("projects/%s/locations/%s", projectId, location);
    String modelName = String.format(
        "projects/%s/locations/%s/publishers/google/models/%s", projectId, location, textModel);
    // Construct model parameters
    Map<String, String> modelParameters = new HashMap<>();
    modelParameters.put("maxOutputTokens", "200");
    modelParameters.put("temperature", "0.2");
    modelParameters.put("topP", "0.95");
    modelParameters.put("topK", "40");
    Value parameterValue = mapToValue(modelParameters);

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests.
    try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {

      BatchPredictionJob batchPredictionJob =
          BatchPredictionJob.newBuilder()
              .setDisplayName("my batch text prediction job " + System.currentTimeMillis())
              .setModel(modelName)
              .setInputConfig(
                  BatchPredictionJob.InputConfig.newBuilder()
                      .setGcsSource(GcsSource.newBuilder().addUris(inputUri).build())
                      .setInstancesFormat("jsonl")
                      .build())
              .setOutputConfig(
                  BatchPredictionJob.OutputConfig.newBuilder()
                      .setGcsDestination(GcsDestination.newBuilder()
                          .setOutputUriPrefix(outputUri).build())
                      .setPredictionsFormat("jsonl")
                      .build())
              .setModelParameters(parameterValue)
              .build();

      // Create the batch prediction job
      response =
          jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob);

      System.out.format("response: %s\n", response);
      System.out.format("\tName: %s\n", response.getName());
    }
    return response;
  }

  private static Value mapToValue(Map<String, String> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Passaggi successivi

Per cercare e filtrare gli esempi di codice per altri prodotti Google Cloud , consulta il browser degli esempi diGoogle Cloud .