Tune text embeddings

This guide shows you how to tune text embedding models to improve performance for your specific tasks.

Foundation embedding models are pre-trained on a large dataset of text, which provides a strong baseline for many tasks. For scenarios that require specialized knowledge or highly tailored performance, you can use model tuning to fine-tune the model's representations with your own data. Tuning is supported for stable versions of the textembedding-gecko and textembedding-gecko-multilingual models.

Text embedding models support supervised tuning. Supervised tuning uses labeled examples that demonstrate the type of output that you want from your text embedding model during inference.

To learn more about model tuning, see How model tuning works.

Expected quality improvement

Vertex AI uses a parameter-efficient tuning method for customization. In experiments on public retrieval benchmark datasets, this method improved quality by up to 41%, with an average improvement of 12%.

Use case for tuning an embedding model

Tuning a text embedding model helps adapt the model to a specific domain or task. This is useful when the pre-trained model isn't suited to your needs. For example, fine-tuning an embedding model on your company's customer support tickets can help a chatbot better understand common issues and answer questions more effectively. Without tuning, the model lacks specific knowledge about your support tickets and product solutions.

Tuning workflow

The model tuning workflow on Vertex AI for textembedding-gecko and textembedding-gecko-multilingual involves the following steps:

  1. Prepare your model tuning dataset.
  2. Upload the model tuning dataset to a Cloud Storage bucket.
  3. Configure your project for Vertex AI Pipelines.
  4. Create a model tuning job.

After the tuning job is complete, the tuned model is added to the Model Registry. Unlike other model tuning jobs, a text embedding tuning job doesn't automatically deploy the model to an endpoint. You must manually deploy your tuned model.

Prepare your embeddings dataset

The dataset that you use to tune an embedding model should contain data that reflects the task you want the model to perform.

Dataset format for tuning an embeddings model

The training dataset consists of the following files, which must be in Cloud Storage. You define the path for each file with a parameter when you launch the tuning pipeline.

  • Corpus file: The path is defined by the corpus_path parameter. This is a JSONL file where each line has the fields _id, title, and text with string values. _id and text are required, while title is optional. The following is an example corpus.jsonl file:

    {"_id": "doc1", "title": "Get an introduction to generative AI on Vertex AI", "text": "Vertex AI Studio offers a Google Cloud console tool for rapidly prototyping and testing generative AI models. Learn how you can use Vertex AI Studio to test models using prompt samples, design and save prompts, tune a foundation model, and convert between speech and text."}
    {"_id": "doc2", "title": "Use gen AI for summarization, classification, and extraction", "text": "Learn how to create text prompts for handling any number of tasks with Vertex AI's generative AI support. Some of the most common tasks are classification, summarization, and extraction. Vertex AI's PaLM API for text lets you design prompts with flexibility in terms of their structure and format."}
    {"_id": "doc3", "title": "Custom ML training overview and documentation", "text": "Get an overview of the custom training workflow in Vertex AI, the benefits of custom training, and the various training options that are available. This page also details every step involved in the ML training workflow from preparing data to predictions."}
    {"_id": "doc4", "text": "Text embeddings are useful for clustering, information retrieval, retrieval-augmented generation (RAG), and more."}
    {"_id": "doc5", "title": "Text embedding tuning", "text": "Google's text embedding models can be tuned on Vertex AI."}
    
  • Query file: The path is defined by the queries_path parameter. The query file contains your example queries, is in JSONL format, and has the same fields as the corpus file. The following is an example queries.jsonl file:

    {"_id": "query1", "text": "Does Vertex support generative AI?"}
    {"_id": "query2", "text": "What can I do with Vertex GenAI offerings?"}
    {"_id": "query3", "text": "How do I train my models using Vertex?"}
    {"_id": "query4", "text": "What is a text embedding?"}
    {"_id": "query5", "text": "Can text embedding models be tuned on Vertex?"}
    {"_id": "query6", "text": "embeddings"}
    {"_id": "query7", "text": "embeddings for rag"}
    {"_id": "query8", "text": "custom model training"}
    {"_id": "query9", "text": "Google Cloud PaLM API"}
    
  • Training labels: The path is defined by the train_label_path parameter, which is the Cloud Storage URI of the training label data. The labels must be in a TSV file with a header row. The training labels file must reference a subset of the query IDs and corpus IDs, and it must have the columns query-id, corpus-id, and score.

    • query-id: A string that matches the _id key from the query file.
    • corpus-id: A string that matches the _id key in the corpus file.
    • score: A non-negative integer. A score greater than 0 indicates that the document is related to the query. Higher scores indicate greater relevance. If the score is omitted, the default value is 1. If a query and document pair is unrelated, you can either omit it from the labels file or include it with a score of 0.

    The following is an example train_labels.tsv file. In the actual file, columns are separated by tab characters:

    query-id  corpus-id  score
    query1    doc1       1
    query2    doc2       1
    query3    doc3       2
    query3    doc5       1
    query4    doc4       1
    query4    doc5       1
    query5    doc5       2
    query6    doc4       1
    query6    doc5       1
    query7    doc4       1
    query8    doc3       1
    query9    doc2       1
    
  • Test labels (Optional): The test labels have the same format as the training labels. You specify them with the test_label_path parameter. If you don't provide a test_label_path, the test labels are automatically split from the training labels.

  • Validation labels (Optional): The validation labels have the same format as the training labels. You specify them with the validation_label_path parameter. If you don't provide a validation_label_path, the validation labels are automatically split from the training labels.
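If you assemble these files from your own data, a minimal sketch like the following, using only the Python standard library, produces the JSONL and TSV layouts described above. The function names (`to_jsonl`, `to_labels_tsv`, `write_tuning_files`) and the local output directory are illustrative assumptions; uploading the files to Cloud Storage is left to your usual tooling.

```python
import csv
import io
import json
from pathlib import Path


def to_jsonl(records):
    """Serialize a list of dicts as JSONL: one JSON object per line."""
    return "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


def to_labels_tsv(labels):
    """Serialize (query_id, corpus_id, score) tuples as a TSV with a header row."""
    buf = io.StringIO()
    writer = csv.writer(buf, delimiter="\t", lineterminator="\n")
    writer.writerow(["query-id", "corpus-id", "score"])  # required header
    writer.writerows(labels)
    return buf.getvalue()


def write_tuning_files(out_dir, corpus, queries, labels):
    """Write corpus.jsonl, queries.jsonl, and train_labels.tsv into out_dir.

    corpus and queries are lists of dicts with "_id" and "text" keys
    (corpus entries may also have "title"); labels is a list of tuples.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "corpus.jsonl").write_text(to_jsonl(corpus), encoding="utf-8")
    (out / "queries.jsonl").write_text(to_jsonl(queries), encoding="utf-8")
    (out / "train_labels.tsv").write_text(to_labels_tsv(labels), encoding="utf-8")
    return out
```

For example, `write_tuning_files("tuning_data", corpus, queries, [("query5", "doc5", 2)])` writes the three files into a local `tuning_data` directory, where `corpus` and `queries` are lists of dicts shaped like the example lines above.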

Dataset size requirements

The provided dataset files must meet the following constraints:

  • The number of queries is between 9 and 10,000.
  • The number of documents in the corpus is between 9 and 500,000.
  • Each dataset label file includes at least 3 query IDs, and across all dataset splits there are at least 9 query IDs.
  • The total number of labels is less than 500,000.
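Before you upload, you can check the files locally against these limits. The sketch below is a hypothetical helper, not part of any Vertex AI library: it parses the JSONL IDs and TSV labels, applies the default score of 1, and reports count violations, unknown IDs, and negative scores. It only checks the label files you pass in; splits that the pipeline creates automatically aren't visible to it.

```python
import csv
import io
import json


def load_ids(jsonl_text):
    """Return the _id values from a JSONL string, skipping blank lines."""
    return [json.loads(line)["_id"] for line in jsonl_text.splitlines() if line.strip()]


def load_labels(tsv_text):
    """Parse a label TSV (with header) into (query_id, corpus_id, score) tuples.

    An empty or missing score defaults to 1; int() raises ValueError
    if a score isn't an integer.
    """
    labels = []
    for row in csv.DictReader(io.StringIO(tsv_text), delimiter="\t"):
        raw = (row.get("score") or "").strip()
        labels.append((row["query-id"], row["corpus-id"], int(raw) if raw else 1))
    return labels


def check_dataset(query_ids, corpus_ids, label_splits):
    """Return a list of problems; an empty list means every check passed.

    label_splits maps a split name (for example "train") to its labels.
    """
    problems = []
    queries, corpus = set(query_ids), set(corpus_ids)
    if not 9 <= len(queries) <= 10_000:
        problems.append(f"{len(queries)} queries; expected 9 to 10,000")
    if not 9 <= len(corpus) <= 500_000:
        problems.append(f"{len(corpus)} documents; expected 9 to 500,000")
    all_query_ids = set()
    total = 0
    for split, labels in label_splits.items():
        split_queries = {q for q, _, _ in labels}
        if len(split_queries) < 3:
            problems.append(f"{split}: {len(split_queries)} query IDs; expected at least 3")
        all_query_ids |= split_queries
        total += len(labels)
        for q, c, score in labels:
            if q not in queries:
                problems.append(f"{split}: unknown query-id {q!r}")
            if c not in corpus:
                problems.append(f"{split}: unknown corpus-id {c!r}")
            if score < 0:
                problems.append(f"{split}: negative score for ({q}, {c})")
    if len(all_query_ids) < 9:
        problems.append(f"{len(all_query_ids)} query IDs across splits; expected at least 9")
    if total >= 500_000:
        problems.append(f"{total} labels; expected fewer than 500,000")
    return problems
```

Running it on the short example files above would flag the corpus, which has only five documents, so treat those files as format illustrations rather than a complete dataset.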

Configure your project for Vertex AI Pipelines

You run tuning jobs in your project by using the Vertex AI Pipelines platform.

Configuring permissions

The pipeline runs training code under two service accounts. You must grant these accounts the following roles so that they can run training with your project and dataset.

  • Compute Engine default service account (PROJECT_NUMBER-compute@developer.gserviceaccount.com). This service account requires the following permissions:

    • Storage Object Viewer access to each dataset file you created in Cloud Storage.
    • Storage Object User access to the output Cloud Storage directory of your pipeline, PIPELINE_OUTPUT_DIRECTORY.
    • Vertex AI User access to your project.

    Instead of the Compute Engine default service account, you can specify a custom service account. For more information, see Configure a service account with granular permissions.

  • Vertex AI Tuning Service Agent (service-PROJECT_NUMBER@gcp-sa-aiplatform-ft.iam.gserviceaccount.com). This service agent requires the following permissions:

    • Storage Object Viewer access to each dataset file you created in Cloud Storage.
    • Storage Object User access to the output Cloud Storage directory of your pipeline, PIPELINE_OUTPUT_DIRECTORY.

For more information about configuring Cloud Storage dataset permissions, see Configure a Cloud Storage bucket for pipeline artifacts.
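The grants above can be summarized as data before you apply them with the Google Cloud console or gcloud. This is a hypothetical planning helper: it derives both account emails from your project number and maps each permission to its predefined role ID (`roles/storage.objectViewer`, `roles/storage.objectUser`, `roles/aiplatform.user`). The resource strings are descriptive only; how narrowly you scope each grant in Cloud Storage is up to you. If you use a custom service account instead of the Compute Engine default, substitute its email.

```python
def iam_plan(project_number, dataset_uris, output_dir):
    """List the (member, role, resource) grants described on this page."""
    compute_sa = f"serviceAccount:{project_number}-compute@developer.gserviceaccount.com"
    tuning_agent = (
        f"serviceAccount:service-{project_number}"
        "@gcp-sa-aiplatform-ft.iam.gserviceaccount.com"
    )
    grants = []
    for member in (compute_sa, tuning_agent):
        # Both accounts read every dataset file...
        for uri in dataset_uris:
            grants.append((member, "roles/storage.objectViewer", uri))
        # ...and write to the pipeline output directory.
        grants.append((member, "roles/storage.objectUser", output_dir))
    # Only the Compute Engine default service account needs Vertex AI User.
    grants.append((compute_sa, "roles/aiplatform.user", f"projects/{project_number}"))
    return grants
```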

Using accelerators

Tuning requires GPU accelerators. You can use any of the following accelerators for the text embedding tuning pipeline:

  • NVIDIA_L4
  • NVIDIA_TESLA_A100
  • NVIDIA_TESLA_T4
  • NVIDIA_TESLA_V100
  • NVIDIA_TESLA_P100

To launch a tuning job, you need sufficient Restricted image training GPUs quota for the accelerator type and region that you select, for example, Restricted image training Nvidia V100 GPUs per region. To request a quota increase for your project, see Request additional quota.

Not all accelerators are available in all regions. For more information, see Using accelerators in Vertex AI.

Create an embedding model tuning job

You can create an embedding model tuning job by using the Google Cloud console, the REST API, or the client libraries.

REST

To create an embedding model tuning job, use the projects.locations.pipelineJobs.create method.

Before using any of the request data, make the following replacements:

  • PROJECT_ID: Your Google Cloud project ID.
  • PIPELINE_OUTPUT_DIRECTORY: Path for the pipeline output artifacts, starting with "gs://".

HTTP method and URL:

POST https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/pipelineJobs

Request JSON body:

{
  "displayName": "tune_text_embeddings_model_sample",
  "runtimeConfig": {
    "gcsOutputDirectory": "PIPELINE_OUTPUT_DIRECTORY",
    "parameterValues": {
      "corpus_path": "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/corpus.jsonl",
      "queries_path": "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/queries.jsonl",
      "train_label_path": "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/train.tsv",
      "test_label_path": "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/test.tsv",
      "base_model_version_id":"text-embedding-004",
      "task_type": "DEFAULT",
      "batch_size": "128",
      "train_steps": "1000",
      "output_dimensionality": "768",
      "learning_rate_multiplier": "1.0"
    }
  },
  "templateUri": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3"
}
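If you prefer to send this request from Python without a client library, the following sketch builds the same URL and body and posts it with `urllib`. The helper names are illustrative, the template version is copied from the request above, and the access token is assumed to come from `gcloud auth print-access-token`.

```python
import json
import urllib.request

TEMPLATE_URI = (
    "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/"
    "tune-text-embedding-model/v1.1.3"
)


def build_tuning_request(project_id, location, output_dir, params,
                         display_name="tune_text_embeddings_model_sample"):
    """Return (url, body) for a projects.locations.pipelineJobs.create call."""
    url = (
        f"https://{location}-aiplatform.googleapis.com/v1/"
        f"projects/{project_id}/locations/{location}/pipelineJobs"
    )
    body = {
        "displayName": display_name,
        "runtimeConfig": {
            "gcsOutputDirectory": output_dir,
            "parameterValues": dict(params),  # e.g. corpus_path, queries_path, ...
        },
        "templateUri": TEMPLATE_URI,
    }
    return url, body


def send(url, body, access_token):
    """POST the JSON body and return the decoded JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```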

After launching the pipeline, follow the progress of your tuning job through the Google Cloud console.

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

import re

from google.cloud.aiplatform import initializer as aiplatform_init
from vertexai.language_models import TextEmbeddingModel


def tune_embedding_model(
    api_endpoint: str,
    base_model_name: str = "text-embedding-005",
    corpus_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/corpus.jsonl",
    queries_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/queries.jsonl",
    train_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/train.tsv",
    test_label_path: str = "gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/test.tsv",
):  # noqa: ANN201
    """Tune an embedding model using the specified parameters.
    Args:
        api_endpoint (str): The API endpoint for the Vertex AI service.
        base_model_name (str): The name of the base model to use for tuning.
        corpus_path (str): GCS URI of the JSONL file containing the corpus data.
        queries_path (str): GCS URI of the JSONL file containing the queries data.
        train_label_path (str): GCS URI of the TSV file containing the training labels.
        test_label_path (str): GCS URI of the TSV file containing the test labels.
    """
    match = re.search(r"^(\w+-\w+)", api_endpoint)
    location = match.group(1) if match else "us-central1"
    base_model = TextEmbeddingModel.from_pretrained(base_model_name)
    tuning_job = base_model.tune_model(
        task_type="DEFAULT",
        corpus_data=corpus_path,
        queries_data=queries_path,
        training_data=train_label_path,
        test_data=test_label_path,
        batch_size=128,  # The batch size to use for training.
        train_steps=1000,  # The number of training steps.
        tuned_model_location=location,
        output_dimensionality=768,  # The dimensionality of the output embeddings.
        learning_rate_multiplier=1.0,  # The multiplier for the learning rate.
    )
    return tuning_job
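The Python, Java, and Node.js samples all derive the region from the API endpoint host name by matching its prefix. The following sketch isolates that step so that you can check it on its own; the function name is illustrative. The pattern only needs to match the start of the host name: requiring it to match the entire string would fail on every regional endpoint and always fall back to the default region.

```python
import re


def location_from_endpoint(api_endpoint, default="us-central1"):
    """Return the region prefix of a regional Vertex AI endpoint host name."""
    # re.match anchors at the start of the string only, so the trailing
    # "-aiplatform.googleapis.com:443" doesn't have to match.
    match = re.match(r"(\w+-\w+)", api_endpoint)
    return match.group(1) if match else default
```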

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.cloud.aiplatform.v1.CreatePipelineJobRequest;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.PipelineJob;
import com.google.cloud.aiplatform.v1.PipelineJob.RuntimeConfig;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class EmbeddingModelTuningSample {
  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running this sample.
    String apiEndpoint = "us-central1-aiplatform.googleapis.com:443";
    String project = "PROJECT";
    String baseModelVersionId = "BASE_MODEL_VERSION_ID";
    String taskType = "DEFAULT";
    String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
    String outputDir = "OUTPUT_DIR";
    String queriesPath = "QUERIES_PATH";
    String corpusPath = "CORPUS_PATH";
    String trainLabelPath = "TRAIN_LABEL_PATH";
    String testLabelPath = "TEST_LABEL_PATH";
    double learningRateMultiplier = 1.0;
    int outputDimensionality = 768;
    int batchSize = 128;
    int trainSteps = 1000;

    createEmbeddingModelTuningPipelineJob(
        apiEndpoint,
        project,
        baseModelVersionId,
        taskType,
        pipelineJobDisplayName,
        outputDir,
        queriesPath,
        corpusPath,
        trainLabelPath,
        testLabelPath,
        learningRateMultiplier,
        outputDimensionality,
        batchSize,
        trainSteps);
  }

  public static PipelineJob createEmbeddingModelTuningPipelineJob(
      String apiEndpoint,
      String project,
      String baseModelVersionId,
      String taskType,
      String pipelineJobDisplayName,
      String outputDir,
      String queriesPath,
      String corpusPath,
      String trainLabelPath,
      String testLabelPath,
      double learningRateMultiplier,
      int outputDimensionality,
      int batchSize,
      int trainSteps)
      throws IOException {
    Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(apiEndpoint);
    // find() matches the region prefix; matches() would require the whole endpoint to match.
    String location = matcher.find() ? matcher.group("Location") : "us-central1";
    String templateUri =
        "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.4";
    PipelineServiceSettings settings =
        PipelineServiceSettings.newBuilder().setEndpoint(apiEndpoint).build();
    try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
      Map<String, Value> parameterValues =
          Map.of(
              "base_model_version_id", valueOf(baseModelVersionId),
              "task_type", valueOf(taskType),
              "queries_path", valueOf(queriesPath),
              "corpus_path", valueOf(corpusPath),
              "train_label_path", valueOf(trainLabelPath),
              "test_label_path", valueOf(testLabelPath),
              "learning_rate_multiplier", valueOf(learningRateMultiplier),
              "output_dimensionality", valueOf(outputDimensionality),
              "batch_size", valueOf(batchSize),
              "train_steps", valueOf(trainSteps));
      PipelineJob pipelineJob =
          PipelineJob.newBuilder()
              .setTemplateUri(templateUri)
              .setDisplayName(pipelineJobDisplayName)
              .setRuntimeConfig(
                  RuntimeConfig.newBuilder()
                      .setGcsOutputDirectory(outputDir)
                      .putAllParameterValues(parameterValues)
                      .build())
              .build();
      CreatePipelineJobRequest request =
          CreatePipelineJobRequest.newBuilder()
              .setParent(LocationName.of(project, location).toString())
              .setPipelineJob(pipelineJob)
              .build();
      return client.createPipelineJob(request);
    }
  }

  private static Value valueOf(String s) {
    return Value.newBuilder().setStringValue(s).build();
  }

  private static Value valueOf(int n) {
    return Value.newBuilder().setNumberValue(n).build();
  }

  private static Value valueOf(double n) {
    return Value.newBuilder().setNumberValue(n).build();
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

async function main(
  apiEndpoint,
  project,
  outputDir,
  pipelineJobDisplayName = 'embedding-customization-pipeline-sample',
  baseModelVersionId = 'text-embedding-005',
  taskType = 'DEFAULT',
  corpusPath = 'gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/corpus.jsonl',
  queriesPath = 'gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/queries.jsonl',
  trainLabelPath = 'gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/train.tsv',
  testLabelPath = 'gs://cloud-samples-data/ai-platform/embedding/goog-10k-2024/r11/test.tsv',
  outputDimensionality = 768,
  learningRateMultiplier = 1.0,
  batchSize = 128,
  trainSteps = 1000
) {
  const aiplatform = require('@google-cloud/aiplatform');
  const {PipelineServiceClient} = aiplatform.v1;
  const {helpers} = aiplatform; // helps construct protobuf.Value objects.

  const client = new PipelineServiceClient({apiEndpoint});
  const match = apiEndpoint.match(/(?<L>\w+-\w+)/);
  const location = match ? match.groups.L : 'us-central1';
  const parent = `projects/${project}/locations/${location}`;
  const params = {
    base_model_version_id: baseModelVersionId,
    task_type: taskType,
    queries_path: queriesPath,
    corpus_path: corpusPath,
    train_label_path: trainLabelPath,
    test_label_path: testLabelPath,
    batch_size: batchSize,
    train_steps: trainSteps,
    output_dimensionality: outputDimensionality,
    learning_rate_multiplier: learningRateMultiplier,
  };
  const runtimeConfig = {
    gcsOutputDirectory: outputDir,
    parameterValues: Object.fromEntries(
      Object.entries(params).map(([k, v]) => [k, helpers.toValue(v)])
    ),
  };
  const pipelineJob = {
    templateUri:
      'https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.4',
    displayName: pipelineJobDisplayName,
    runtimeConfig,
  };
  async function createTuneJob() {
    const [response] = await client.createPipelineJob({parent, pipelineJob});
    console.log(`job_name: ${response.name}`);
    console.log(`job_state: ${response.state}`);
  }

  await createTuneJob();
}

Console

To tune a text embedding model by using the Google Cloud console, you can launch a customization pipeline using the following steps:

  1. In the Vertex AI section of the Google Cloud console, go to the Vertex AI Pipelines page.

  2. Click Create run to open the Create pipeline run pane.
  3. Click Select from existing pipelines and enter the following details:
    1. Select "ml-pipeline" from the Select a resource drop-down.
    2. Select "llm-text-embedding" from the Repository drop-down.
    3. Select "tune-text-embedding-model" from the Pipeline or component drop-down.
    4. Select the version labeled "v1.1.3" from the Version drop-down.
  4. Specify a Run name to uniquely identify the pipeline run.
  5. In the Region drop-down list, select the region in which to create the pipeline run. Your tuned model is created in the same region.
  6. Click Continue. The Runtime configuration pane appears.
  7. Under Cloud storage location, click Browse to select the Cloud Storage bucket for storing the pipeline output artifacts, and then click Select.
  8. Under Pipeline parameters, specify your parameters for the tuning pipeline. The three required parameters are corpus_path, queries_path, and train_label_path, with formats described in Prepare your embeddings dataset. For more detailed information about each parameter, refer to the REST tab of this section.
  9. Click Submit to create your pipeline run.

Other supported features

Text embedding tuning supports VPC Service Controls. To run the tuning job within a Virtual Private Cloud (VPC), pass the network parameter when you create the PipelineJob.

To use CMEK (customer-managed encryption keys), pass the key to the parameterValues.encryption_spec_key_name pipeline parameter and the encryptionSpec.kmsKeyName parameter when you create the PipelineJob.
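As a sketch of where these settings go in the REST request body, the following hypothetical helper copies a `pipelineJobs.create` body and adds the `network` field, the `encryptionSpec.kmsKeyName` field, and the `encryption_spec_key_name` pipeline parameter. The network and key names in the docstring are placeholders in the usual resource-name formats; substitute your own.

```python
import copy


def with_network_and_cmek(body, network=None, kms_key_name=None):
    """Return a copy of a pipelineJobs.create body with optional VPC and CMEK settings.

    network: for example "projects/PROJECT_NUMBER/global/networks/NETWORK_NAME".
    kms_key_name: for example
        "projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY".
    """
    job = copy.deepcopy(body)  # leave the caller's dict untouched
    if network:
        job["network"] = network
    if kms_key_name:
        # The key is set in two places: on the PipelineJob and as a pipeline parameter.
        job["encryptionSpec"] = {"kmsKeyName": kms_key_name}
        params = job.setdefault("runtimeConfig", {}).setdefault("parameterValues", {})
        params["encryption_spec_key_name"] = kms_key_name
    return job
```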

Use your tuned model

View tuned models in Model Registry

When your tuning job completes, the tuned model isn't automatically deployed to an endpoint. You can find the tuned model as a Model resource in Model Registry. You can view a list of models in your current project, including your tuned models, by using the Google Cloud console.

To view your tuned models in the Google Cloud console, go to the Vertex AI Model Registry page.

Deploy your model

After you tune the embedding model, you need to deploy the Model resource. To deploy your tuned embedding model, see Deploy a model to an endpoint.

Unlike foundation models, you manage tuned text embedding models. This includes managing serving resources, like machine type and accelerators. To prevent out-of-memory errors during prediction, we recommend that you deploy using the NVIDIA_TESLA_A100 GPU type, which can support batch sizes up to 5 for any input length.
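To stay within the recommended batch size of 5 on NVIDIA_TESLA_A100, you can split your inputs across several `:predict` requests. This minimal sketch, with an illustrative function name, groups instances into request bodies; the default of 5 comes from the recommendation above, so adjust it if you deploy on other hardware.

```python
def predict_bodies(instances, max_batch=5):
    """Split instances into :predict request bodies of at most max_batch each."""
    if max_batch < 1:
        raise ValueError("max_batch must be at least 1")
    return [
        {"instances": instances[i:i + max_batch]}
        for i in range(0, len(instances), max_batch)
    ]
```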

Similar to the textembedding-gecko foundation model, your tuned model supports up to 3,072 tokens and can truncate longer inputs.

Get predictions on a deployed model

After your tuned model is deployed, you can use one of the following commands to send requests to the tuned model endpoint.

Example curl command for tuned textembedding-gecko@001 models

To get predictions from a tuned version of textembedding-gecko@001, use the following curl command.

PROJECT_ID=PROJECT_ID
LOCATION=LOCATION
ENDPOINT_URI=https://${LOCATION}-aiplatform.googleapis.com
MODEL_ENDPOINT=TUNED_MODEL_ENDPOINT_ID

curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" \
    -H "Content-Type: application/json"  \
    ${ENDPOINT_URI}/v1/projects/${PROJECT_ID}/locations/${LOCATION}/endpoints/${MODEL_ENDPOINT}:predict \
    -d '{
  "instances": [
    {
      "content": "Dining in New York City"
    },
    {
      "content": "Best resorts on the east coast"
    }
  ]
}'

Example curl command for other models

Tuned versions of other models (for example, textembedding-gecko@003 and textembedding-gecko-multilingual@001) require two additional inputs: task_type and title. For more information about these parameters, see curl command.

PROJECT_ID=PROJECT_ID
LOCATION=LOCATION
ENDPOINT_URI=https://${LOCATION}-aiplatform.googleapis.com
MODEL_ENDPOINT=TUNED_MODEL_ENDPOINT_ID

curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" \
    -H "Content-Type: application/json"  \
    ${ENDPOINT_URI}/v1/projects/${PROJECT_ID}/locations/${LOCATION}/endpoints/${MODEL_ENDPOINT}:predict \
    -d '{
  "instances": [
    {
      "content": "Dining in New York City",
      "task_type": "DEFAULT",
      "title": ""
    },
    {
      "content": "There are many resorts to choose from on the East coast...",
      "task_type": "RETRIEVAL_DOCUMENT",
      "title": "East Coast Resorts"
    }
  ]
}'
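Because the instance shape differs between tuned textembedding-gecko@001 models and other tuned models, a small helper can build either form. This is an illustrative sketch; matching on the exact version string `textembedding-gecko@001` is an assumption about how you track which base model you tuned.

```python
def make_instance(content, base_model_version, task_type="DEFAULT", title=""):
    """Build one :predict instance for a tuned text embedding model."""
    if base_model_version == "textembedding-gecko@001":
        # Tuned textembedding-gecko@001 models take only "content".
        return {"content": content}
    # Other tuned models also take task_type and title.
    return {"content": content, "task_type": task_type, "title": title}
```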

Example output

This output applies to both textembedding-gecko and textembedding-gecko-multilingual models, regardless of version.

{
 "predictions": [
   [ ... ],
   [ ... ],
   ...
 ],
 "deployedModelId": "...",
 "model": "projects/.../locations/.../models/...",
 "modelDisplayName": "tuned-text-embedding-model",
 "modelVersionId": "1"
}
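To use the returned embeddings, for example to rank documents against a query, you can parse the response and compare vectors with cosine similarity. This sketch assumes, based on the example output above, that each entry in `predictions` is a flat list of numbers; cosine similarity is one common choice of comparison, not something this page prescribes.

```python
import json
import math


def embeddings_from_response(response_text):
    """Return the vectors from a tuned-model :predict response as lists of floats."""
    # Assumes each entry in "predictions" is a flat list of numbers.
    return [[float(x) for x in p] for p in json.loads(response_text)["predictions"]]


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0
```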

What's next