Genera testo con il modello Gemma2 con acceleratori TPU

Questo esempio di codice mostra come eseguire l'interferenza su un modello Gemma2 di cui è stato eseguito il deployment in un endpoint Vertex AI con acceleratori TPU.

Per saperne di più

Per la documentazione dettagliata che include questo esempio di codice, consulta quanto segue:

Esempio di codice

Go

Prima di provare questo esempio, segui le istruzioni di configurazione Go riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Go di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}

Java

Prima di provare questo esempio, segui le istruzioni di configurazione Java riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Java di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Node.js

Prima di provare questo esempio, segui le istruzioni di configurazione Node.js riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Node.js di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Python

Prima di provare questo esempio, segui le istruzioni di configurazione Python riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Python di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Passaggi successivi

Per cercare e filtrare i sample di codice per altri prodotti Google Cloud , consulta il Google Cloud browser di sample.