Ajustar um modelo de embedding usando os parâmetros especificados

Este exemplo de código demonstra como ajustar um modelo de incorporação usando a Vertex AI. O exemplo usa um modelo pré-treinado e o ajusta em um conjunto de dados específico.

Mais informações

Para ver a documentação detalhada que inclui este exemplo de código, consulte:

Exemplo de código

Python

Antes de testar esse exemplo, siga as instruções de configuração para Python no Guia de início rápido da Vertex AI sobre como usar bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Vertex AI para Python.

Para autenticar na Vertex AI, configure o Application Default Credentials. Para mais informações, consulte Configurar a autenticação para um ambiente de desenvolvimento local.

import re

from google.cloud.aiplatform import initializer as aiplatform_init
from vertexai.language_models import TextEmbeddingModel


def tune_embedding_model(
    api_endpoint: str,
    base_model_name: str = "text-embedding-005",
    corpus_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/corpus.jsonl",
    queries_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/queries.jsonl",
    train_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/train.tsv",
    test_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/test.tsv",
):  # noqa: ANN201
    """Tune an embedding model using the specified parameters.
    Args:
        api_endpoint (str): The API endpoint for the Vertex AI service.
        base_model_name (str): The name of the base model to use for tuning.
        corpus_path (str): GCS URI of the JSONL file containing the corpus data.
        queries_path (str): GCS URI of the JSONL file containing the queries data.
        train_label_path (str): GCS URI of the TSV file containing the training labels.
        test_label_path (str): GCS URI of the TSV file containing the test labels.
    """
    match = re.search(r"^(\w+-\w+)", api_endpoint)
    location = match.group(1) if match else "us-central1"
    base_model = TextEmbeddingModel.from_pretrained(base_model_name)
    tuning_job = base_model.tune_model(
        task_type="DEFAULT",
        corpus_data=corpus_path,
        queries_data=queries_path,
        training_data=train_label_path,
        test_data=test_label_path,
        batch_size=128,  # The batch size to use for training.
        train_steps=1000,  # The number of training steps.
        tuned_model_location=location,
        output_dimensionality=768,  # The dimensionality of the output embeddings.
        learning_rate_multiplier=1.0,  # The multiplier for the learning rate.
    )
    return tuning_job

A seguir

Para pesquisar e filtrar exemplos de código de outros Google Cloud produtos, consulte a pesquisa de exemplos de código doGoogle Cloud .