Tune a language foundation model by using a tuning dataset.
Explore further
For detailed documentation that includes this code sample, see the following:
Code sample
Java
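Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.
To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.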
```java
import com.google.cloud.aiplatform.v1.CreatePipelineJobRequest;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.PipelineJob;
import com.google.cloud.aiplatform.v1.PipelineJob.RuntimeConfig;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class CreatePipelineJobModelTuningSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "PROJECT";
    String location = "europe-west4"; // europe-west4 and us-central1 are the supported regions
    String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
    String modelDisplayName = "MODEL_DISPLAY_NAME";
    String outputDir = "OUTPUT_DIR";
    String datasetUri = "DATASET_URI";
    int trainingSteps = 300;

    createPipelineJobModelTuningSample(
        project,
        location,
        pipelineJobDisplayName,
        modelDisplayName,
        outputDir,
        datasetUri,
        trainingSteps);
  }

  // Create a model tuning job
  public static void createPipelineJobModelTuningSample(
      String project,
      String location,
      String pipelineJobDisplayName,
      String modelDisplayName,
      String outputDir,
      String datasetUri,
      int trainingSteps)
      throws IOException {
    final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder().setEndpoint(endpoint).build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests.
    try (PipelineServiceClient client = PipelineServiceClient.create(pipelineServiceSettings)) {
      Map<String, Value> parameterValues = new HashMap<>();
      parameterValues.put("project", stringToValue(project));
      parameterValues.put("model_display_name", stringToValue(modelDisplayName));
      parameterValues.put("dataset_uri", stringToValue(datasetUri));
      parameterValues.put(
          "location",
          stringToValue(
              "us-central1")); // Deployment is only supported in us-central1 for Public Preview
      parameterValues.put("large_model_reference", stringToValue("text-bison@001"));
      parameterValues.put("train_steps", numberToValue(trainingSteps));
      parameterValues.put("accelerator_type", stringToValue("GPU")); // Optional: GPU or TPU

      RuntimeConfig runtimeConfig =
          RuntimeConfig.newBuilder()
              .setGcsOutputDirectory(outputDir)
              .putAllParameterValues(parameterValues)
              .build();

      PipelineJob pipelineJob =
          PipelineJob.newBuilder()
              .setTemplateUri(
                  "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0")
              .setDisplayName(pipelineJobDisplayName)
              .setRuntimeConfig(runtimeConfig)
              .build();

      LocationName parent = LocationName.of(project, location);
      CreatePipelineJobRequest request =
          CreatePipelineJobRequest.newBuilder()
              .setParent(parent.toString())
              .setPipelineJob(pipelineJob)
              .build();

      PipelineJob response = client.createPipelineJob(request);
      System.out.format("response: %s\n", response);
      System.out.format("Name: %s\n", response.getName());
    }
  }

  static Value stringToValue(String str) {
    return Value.newBuilder().setStringValue(str).build();
  }

  static Value numberToValue(int n) {
    return Value.newBuilder().setNumberValue(n).build();
  }
}
```
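Before it sends any request, the Java sample derives two strings from `project` and `location`: the regional endpoint and the parent resource name. It also assembles a fixed set of pipeline parameters. The Python sketch below reproduces those values without calling any API, which can help when checking inputs before a job is created. The helper names (`pipeline_endpoint`, `pipeline_parent`, `tuning_parameter_values`) are hypothetical. The region and accelerator checks only encode the comments in the Java sample and may not match the currently supported values.

```python
# Hypothetical helpers that mirror values built by the Java sample; no API calls are made.

# Values named in the Java sample's comments; the supported lists may have changed since.
SUPPORTED_TUNING_REGIONS = ("europe-west4", "us-central1")
SUPPORTED_ACCELERATORS = ("GPU", "TPU")


def pipeline_endpoint(location: str) -> str:
    """Regional endpoint, matching String.format("%s-aiplatform.googleapis.com:443", location)."""
    if location not in SUPPORTED_TUNING_REGIONS:
        raise ValueError(f"unsupported tuning region: {location}")
    return f"{location}-aiplatform.googleapis.com:443"


def pipeline_parent(project: str, location: str) -> str:
    """Parent resource name, matching LocationName.of(project, location).toString()."""
    return f"projects/{project}/locations/{location}"


def tuning_parameter_values(
    project: str,
    model_display_name: str,
    dataset_uri: str,
    train_steps: int,
    accelerator_type: str = "GPU",
) -> dict:
    """Pipeline parameters with the same keys and fixed values as the Java sample."""
    if accelerator_type not in SUPPORTED_ACCELERATORS:
        raise ValueError(f"accelerator_type must be GPU or TPU, got {accelerator_type}")
    return {
        "project": project,
        "model_display_name": model_display_name,
        "dataset_uri": dataset_uri,
        "location": "us-central1",  # deployment region hard-coded in the Java sample
        "large_model_reference": "text-bison@001",
        # A protobuf Value stores numbers as doubles, so 300 becomes 300.0.
        "train_steps": float(train_steps),
        "accelerator_type": accelerator_type,
    }


if __name__ == "__main__":
    print(pipeline_endpoint("europe-west4"))  # europe-west4-aiplatform.googleapis.com:443
    print(pipeline_parent("my-project", "europe-west4"))  # projects/my-project/locations/europe-west4
    print(tuning_parameter_values("my-project", "my-model", "gs://my-bucket/train.jsonl", 300))
```

The sketch keeps the same two regions as the Java sample. The tuning job's region is `location`, which is used for the endpoint and the parent. The deployment region is the separate `location` pipeline parameter. In the sample, the job runs in `europe-west4` and the model is deployed to `us-central1`.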
Python
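Before trying this sample, follow the Python setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Python API reference documentation.
To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.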
```python
from __future__ import annotations


# With postponed annotations, the return type is not evaluated at definition time,
# so it can name TextGenerationModel even though the import happens inside the function.
def tuning(
    project_id: str,
) -> TextGenerationModel:
    import vertexai
    from vertexai.language_models import TextGenerationModel
    from google.auth import default

    credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])

    # Initialize Vertex AI
    # TODO(developer): Update project_id
    vertexai.init(project=project_id, location="us-central1", credentials=credentials)

    model = TextGenerationModel.from_pretrained("text-bison@002")

    # Run the tuning job in europe-west4 and deploy the tuned model to us-central1.
    tuning_job = model.tune_model(
        training_data="gs://cloud-samples-data/ai-platform/generative_ai/headline_classification.jsonl",
        tuning_job_location="europe-west4",
        tuned_model_location="us-central1",
    )

    print(tuning_job._status)

    return model
```
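The `training_data` argument (and `dataset_uri` in the Java sample) points to a JSONL file in Cloud Storage that contains one training example per line. The following is a minimal sketch that assumes the `input_text`/`output_text` record format used for tuning PaLM text models such as `text-bison`. Verify that format against the tuning documentation for your model version. The hypothetical helpers `to_jsonl` and `validate_jsonl` serialize and check such records locally. Uploading the file to a `gs://` bucket is not shown.

```python
import json
from typing import Iterable, Tuple

# Assumed record format for text-bison supervised tuning; check the docs for your model.
REQUIRED_FIELDS = ("input_text", "output_text")


def to_jsonl(examples: Iterable[Tuple[str, str]]) -> str:
    """Serialize (input_text, output_text) pairs as JSONL, one JSON object per line."""
    lines = []
    for input_text, output_text in examples:
        record = {"input_text": input_text, "output_text": output_text}
        lines.append(json.dumps(record, ensure_ascii=False))
    return "".join(line + "\n" for line in lines)


def validate_jsonl(text: str) -> int:
    """Return the number of records; raise ValueError on a malformed line."""
    count = 0
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)  # json.JSONDecodeError is a subclass of ValueError
        if not isinstance(record, dict):
            raise ValueError(f"line {line_no}: expected a JSON object")
        missing = [field for field in REQUIRED_FIELDS if field not in record]
        if missing:
            raise ValueError(f"line {line_no}: missing fields {missing}")
        count += 1
    return count


if __name__ == "__main__":
    # Hypothetical examples in the spirit of a headline-classification dataset.
    data = to_jsonl([
        ("Local team wins championship", "sports"),
        ("Central bank raises interest rates", "business"),
    ])
    print(data, end="")
    print(validate_jsonl(data))  # 2
```

Validating the file locally before upload can catch a missing field or a malformed line early. Without that check, the problem would only surface after the tuning pipeline has started.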
What's next
To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.