Text mit dem Gemma2-Modell mit GPU-Beschleunigern generieren

In diesem Codebeispiel wird gezeigt, wie Sie mit GPU-Beschleunigern eine Störung auf ein Gemma2-Modell ausführen, das auf einem Vertex AI-Endpunkt bereitgestellt wird.

Weitere Informationen

Eine ausführliche Dokumentation, die dieses Codebeispiel enthält, finden Sie hier:

Codebeispiel

Go

Bevor Sie dieses Beispiel anwenden, folgen Sie den Go-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Go API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
func predictGPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Pay attention that prompt should be set in "inputs" field.
	// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"inputs":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}

Java

Bevor Sie dieses Beispiel anwenden, folgen Sie den Java-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Java API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictGpu {

  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-east4";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);

    creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with GPU accelerators.
  public String gemma2PredictGpu(String projectId, String region,
               String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);

    // Prompt used in the prediction
    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for GPUs
    // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Node.js

Bevor Sie dieses Beispiel anwenden, folgen Sie den Node.js-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Node.js API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

async function gemma2PredictGpu(predictionServiceClient) {
  // Imports the Google Cloud Prediction Service Client library
  const {
    // TODO(developer): Uncomment PredictionServiceClient before running the sample.
    // PredictionServiceClient,
    helpers,
  } = require('@google-cloud/aiplatform');
  /**
   * TODO(developer): Update these variables before running the sample.
   */
  const projectId = 'your-project-id';
  const endpointRegion = 'your-vertex-endpoint-region';
  const endpointId = 'your-vertex-endpoint-id';

  // Default configuration
  const config = {maxOutputTokens: 1024, temperature: 0.9, topP: 1.0, topK: 1};
  // Prompt used in the prediction
  const prompt = 'Why is the sky blue?';

  // Encapsulate the prompt in a correct format for GPUs
  // Example format: [{inputs: 'Why is the sky blue?', parameters: {temperature: 0.9}}]
  const input = {
    inputs: prompt,
    parameters: config,
  };

  // Convert input message to a list of GAPIC instances for model input
  const instances = [helpers.toValue(input)];

  // TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
  // const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

  // Create a client
  // predictionServiceClient = new PredictionServiceClient({apiEndpoint});

  // Call the Gemma2 endpoint
  const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

  const [response] = await predictionServiceClient.predict({
    endpoint: gemma2Endpoint,
    instances,
  });

  const predictions = response.predictions;
  const text = predictions[0].stringValue;

  console.log('Predictions:', text);
  return text;
}

module.exports = gemma2PredictGpu;

// TODO(developer): Uncomment below lines before running the sample.
// gemma2PredictGpu(...process.argv.slice(2)).catch(err => {
//   console.error(err.message);
//   process.exitCode = 1;
// });

Python

Bevor Sie dieses Beispiel anwenden, folgen Sie den Python-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Python API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for GPUs
# Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
input = {"inputs": prompt, "parameters": config}

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Nächste Schritte

Informationen zum Suchen und Filtern von Codebeispielen für andere Google Cloud -Produkte finden Sie im Google Cloud Beispielbrowser.