Ajusta un modelo de inserción con los parámetros especificados.

En este código de ejemplo se muestra cómo ajustar un modelo de inserciones con Vertex AI. En el ejemplo se usa un modelo preentrenado y se ajusta con un conjunto de datos específico.

Investigar más

Para obtener documentación detallada que incluya este código de muestra, consulta lo siguiente:

Código de ejemplo

Python

Antes de probar este ejemplo, sigue las Python instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Python de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

import re

from google.cloud.aiplatform import initializer as aiplatform_init
from vertexai.language_models import TextEmbeddingModel


def tune_embedding_model(
    api_endpoint: str,
    base_model_name: str = "text-embedding-005",
    corpus_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/corpus.jsonl",
    queries_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/queries.jsonl",
    train_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/train.tsv",
    test_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/test.tsv",
):  # noqa: ANN201
    """Tune an embedding model using the specified parameters.
    Args:
        api_endpoint (str): The API endpoint for the Vertex AI service.
        base_model_name (str): The name of the base model to use for tuning.
        corpus_path (str): GCS URI of the JSONL file containing the corpus data.
        queries_path (str): GCS URI of the JSONL file containing the queries data.
        train_label_path (str): GCS URI of the TSV file containing the training labels.
        test_label_path (str): GCS URI of the TSV file containing the test labels.
    """
    match = re.search(r"^(\w+-\w+)", api_endpoint)
    location = match.group(1) if match else "us-central1"
    base_model = TextEmbeddingModel.from_pretrained(base_model_name)
    tuning_job = base_model.tune_model(
        task_type="DEFAULT",
        corpus_data=corpus_path,
        queries_data=queries_path,
        training_data=train_label_path,
        test_data=test_label_path,
        batch_size=128,  # The batch size to use for training.
        train_steps=1000,  # The number of training steps.
        tuned_model_location=location,
        output_dimensionality=768,  # The dimensionality of the output embeddings.
        learning_rate_multiplier=1.0,  # The multiplier for the learning rate.
    )
    return tuning_job

Siguientes pasos

Para buscar y filtrar ejemplos de código de otros Google Cloud productos, consulta el Google Cloud navegador de ejemplos.