チューニング用データセットを使用して言語基盤モデルを調整する。
もっと見る
このコードサンプルを含む詳細なドキュメントについては、以下をご覧ください。
コードサンプル
Java
このサンプルを試す前に、Vertex AI クライアント ライブラリをインストールするにある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。
Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。
import com.google.cloud.aiplatform.v1beta1.CreatePipelineJobRequest;
import com.google.cloud.aiplatform.v1beta1.LocationName;
import com.google.cloud.aiplatform.v1beta1.PipelineJob;
import com.google.cloud.aiplatform.v1beta1.PipelineJob.RuntimeConfig;
import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class CreatePipelineJobModelTuningSample {
public static void main(String[] args) throws IOException {
// TODO(developer): Replace these variables before running the sample.
String project = "PROJECT";
String location = "europe-west4"; // europe-west4 and us-central1 are the supported regions
String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
String modelDisplayName = "MODEL_DISPLAY_NAME";
String outputDir = "OUTPUT_DIR";
String datasetUri = "DATASET_URI";
int trainingSteps = 300;
createPipelineJobModelTuningSample(
project,
location,
pipelineJobDisplayName,
modelDisplayName,
outputDir,
datasetUri,
trainingSteps);
}
// Create a model tuning job
public static void createPipelineJobModelTuningSample(
String project,
String location,
String pipelineJobDisplayName,
String modelDisplayName,
String outputDir,
String datasetUri,
int trainingSteps)
throws IOException {
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
PipelineServiceSettings pipelineServiceSettings =
PipelineServiceSettings.newBuilder().setEndpoint(endpoint).build();
// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests.
try (PipelineServiceClient client = PipelineServiceClient.create(pipelineServiceSettings)) {
Map<String, Value> parameterValues = new HashMap<>();
parameterValues.put("project", stringToValue(project));
parameterValues.put("model_display_name", stringToValue(modelDisplayName));
parameterValues.put("dataset_uri", stringToValue(datasetUri));
parameterValues.put(
"location",
stringToValue(
"us-central1")); // Deployment is only supported in us-central1 for Public Preview
parameterValues.put("large_model_reference", stringToValue("text-bison@001"));
parameterValues.put("train_steps", numberToValue(trainingSteps));
parameterValues.put("accelerator_type", stringToValue("GPU")); // Optional: GPU or TPU
RuntimeConfig runtimeConfig =
RuntimeConfig.newBuilder()
.setGcsOutputDirectory(outputDir)
.putAllParameterValues(parameterValues)
.build();
PipelineJob pipelineJob =
PipelineJob.newBuilder()
.setTemplateUri(
"https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0")
.setDisplayName(pipelineJobDisplayName)
.setRuntimeConfig(runtimeConfig)
.build();
LocationName parent = LocationName.of(project, location);
CreatePipelineJobRequest request =
CreatePipelineJobRequest.newBuilder()
.setParent(parent.toString())
.setPipelineJob(pipelineJob)
.build();
PipelineJob response = client.createPipelineJob(request);
System.out.format("response: %s\n", response);
System.out.format("Name: %s\n", response.getName());
}
}
static Value stringToValue(String str) {
return Value.newBuilder().setStringValue(str).build();
}
static Value numberToValue(int n) {
return Value.newBuilder().setNumberValue(n).build();
}
}
Python
このサンプルを試す前に、Vertex AI クライアント ライブラリをインストールするにある Python の設定手順を完了してください。詳細については、Vertex AI Python API のリファレンス ドキュメントをご覧ください。
Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。
from __future__ import annotations
from typing import Optional
from google.auth import default
from google.cloud import aiplatform
import pandas as pd
import vertexai
from vertexai.language_models import TextGenerationModel
from vertexai.preview.language_models import TuningEvaluationSpec
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
def tuning(
project_id: str,
location: str,
model_display_name: str,
training_data: pd.DataFrame | str,
train_steps: int = 10,
evaluation_dataset: Optional[str] = None,
tensorboard_instance_name: Optional[str] = None,
) -> TextGenerationModel:
"""Tune a new model, based on a prompt-response data.
"training_data" can be either the GCS URI of a file formatted in JSONL format
(for example: training_data=f'gs://{bucket}/{filename}.jsonl'), or a pandas
DataFrame. Each training example should be JSONL record with two keys, for
example:
{
"input_text": <input prompt>,
"output_text": <associated output>
},
or the pandas DataFame should contain two columns:
['input_text', 'output_text']
with rows for each training example.
Args:
project_id: GCP Project ID, used to initialize vertexai
location: GCP Region, used to initialize vertexai
model_display_name: Customized Tuned LLM model name.
training_data: GCS URI of jsonl file or pandas dataframe of training data.
train_steps: Number of training steps to use when tuning the model.
evaluation_dataset: GCS URI of jsonl file of evaluation data.
tensorboard_instance_name: The full name of the existing Vertex AI TensorBoard instance:
projects/PROJECT_ID/locations/LOCATION_ID/tensorboards/TENSORBOARD_INSTANCE_ID
Note that this instance must be in the same region as your tuning job.
"""
vertexai.init(project=project_id, location=location, credentials=credentials)
eval_spec = TuningEvaluationSpec(evaluation_data=evaluation_dataset)
eval_spec.tensorboard = aiplatform.Tensorboard(
tensorboard_name=tensorboard_instance_name
)
model = TextGenerationModel.from_pretrained("text-bison@002")
model.tune_model(
training_data=training_data,
# Optional:
model_display_name=model_display_name,
train_steps=train_steps,
tuning_job_location="europe-west4",
tuned_model_location=location,
tuning_evaluation_spec=eval_spec,
)
print(model._job.status)
return model
次のステップ
他の Google Cloud プロダクトに関連するコードサンプルの検索およびフィルタ検索を行うには、Google Cloud のサンプルをご覧ください。