GPU 가속기를 사용하여 Gemma2 모델로 텍스트 생성

이 코드 샘플은 GPU 가속기를 사용하여 Vertex AI 엔드포인트에 배포된 Gemma2 모델에서 간섭을 실행하는 방법을 보여줍니다.

더 살펴보기

이 코드 샘플이 포함된 자세한 문서는 다음을 참조하세요.

코드 샘플

Go

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Go 설정 안내를 따르세요. 자세한 내용은 Vertex AI Go API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
func predictGPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Pay attention that prompt should be set in "inputs" field.
	// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"inputs":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}

Java

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Java 설정 안내를 따르세요. 자세한 내용은 Vertex AI Java API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictGpu {

  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-east4";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);

    creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with GPU accelerators.
  public String gemma2PredictGpu(String projectId, String region,
               String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);

    // Prompt used in the prediction
    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for GPUs
    // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Node.js

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Node.js 설정 안내를 따르세요. 자세한 내용은 Vertex AI Node.js API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

async function gemma2PredictGpu(predictionServiceClient) {
  // Imports the Google Cloud Prediction Service Client library
  const {
    // TODO(developer): Uncomment PredictionServiceClient before running the sample.
    // PredictionServiceClient,
    helpers,
  } = require('@google-cloud/aiplatform');
  /**
   * TODO(developer): Update these variables before running the sample.
   */
  const projectId = 'your-project-id';
  const endpointRegion = 'your-vertex-endpoint-region';
  const endpointId = 'your-vertex-endpoint-id';

  // Default configuration
  const config = {maxOutputTokens: 1024, temperature: 0.9, topP: 1.0, topK: 1};
  // Prompt used in the prediction
  const prompt = 'Why is the sky blue?';

  // Encapsulate the prompt in a correct format for GPUs
  // Example format: [{inputs: 'Why is the sky blue?', parameters: {temperature: 0.9}}]
  const input = {
    inputs: prompt,
    parameters: config,
  };

  // Convert input message to a list of GAPIC instances for model input
  const instances = [helpers.toValue(input)];

  // TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
  // const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

  // Create a client
  // predictionServiceClient = new PredictionServiceClient({apiEndpoint});

  // Call the Gemma2 endpoint
  const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

  const [response] = await predictionServiceClient.predict({
    endpoint: gemma2Endpoint,
    instances,
  });

  const predictions = response.predictions;
  const text = predictions[0].stringValue;

  console.log('Predictions:', text);
  return text;
}

module.exports = gemma2PredictGpu;

// TODO(developer): Uncomment below lines before running the sample.
// gemma2PredictGpu(...process.argv.slice(2)).catch(err => {
//   console.error(err.message);
//   process.exitCode = 1;
// });

Python

이 샘플을 사용해 보기 전에 Vertex AI 빠른 시작: 클라이언트 라이브러리 사용Python 설정 안내를 따르세요. 자세한 내용은 Vertex AI Python API 참고 문서를 참조하세요.

Vertex AI에 인증하려면 애플리케이션 기본 사용자 인증 정보를 설정합니다. 자세한 내용은 로컬 개발 환경의 인증 설정을 참조하세요.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for GPUs
# Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
input = {"inputs": prompt, "parameters": config}

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

다음 단계

다른 Google Cloud 제품의 코드 샘플을 검색하고 필터링하려면 Google Cloud 샘플 브라우저를 참조하세요.