SourceModel(base_model: str, custom_base_model: str = "")

A model that is used as the starting point for managed OSS supervised tuning.
Usage:
```python
import time

from google.cloud import aiplatform
from vertexai.preview.tuning import SourceModel, sft

model = SourceModel(
    base_model="meta/llama3.1-8b",  # OSS model name <publisher>/<model_name>
    custom_base_model="gs://user-bucket/custom-weights",  # optional custom weights
)
sft_tuning_job = sft.train(
    source_model=model,
    train_dataset="gs://my-bucket/train.jsonl",
    validation_dataset="gs://my-bucket/validation.jsonl",
    epochs=4,
    tuned_model_display_name="my-tuned-model",
    output_uri="gs://user-bucket/tuned-model",
)

# Poll once a minute until the tuning job finishes.
while not sft_tuning_job.has_ended:
    time.sleep(60)
    sft_tuning_job.refresh()

tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
```
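The polling loop in the usage example waits indefinitely. As a minimal sketch, the same loop can be wrapped in a helper that gives up after a deadline. It relies only on the `has_ended` attribute and `refresh()` method that the example already uses; the helper name `wait_for_job` and its `poll_seconds`/`timeout_seconds` parameters are hypothetical and not part of the SDK.

```python
import time
from typing import Any


def wait_for_job(job: Any, poll_seconds: float = 60.0,
                 timeout_seconds: float = 24 * 3600.0) -> None:
    """Hypothetical helper: poll `job` until `job.has_ended` is true.

    Raises TimeoutError if the job is still running after `timeout_seconds`.
    """
    deadline = time.monotonic() + timeout_seconds
    while not job.has_ended:
        # Check the deadline before sleeping so a timeout is reported promptly.
        if time.monotonic() >= deadline:
            raise TimeoutError("tuning job did not finish before the timeout")
        time.sleep(poll_seconds)
        job.refresh()  # re-fetch the job state, as in the usage example
```

With the job from the usage example, the `while` loop could then be replaced by `wait_for_job(sft_tuning_job)`.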
Methods
SourceModel
SourceModel(base_model: str, custom_base_model: str = "")

Initializes SourceModel.
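Before calling the constructor, it can be useful to sanity-check the arguments on the client side. The sketch below assumes the `<publisher>/<model_name>` form shown in the usage comment for `base_model`, and assumes `custom_base_model` is either empty or a `gs://` URI, as in the example. The function `check_source_model_args` and its rules are hypothetical illustrations, not the SDK's own validation.

```python
import re

# Assumed shape of base_model: "<publisher>/<model_name>", e.g. "meta/llama3.1-8b".
_BASE_MODEL_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$")


def check_source_model_args(base_model: str,
                            custom_base_model: str = "") -> tuple[str, str]:
    """Hypothetical pre-check; returns (publisher, model_name) or raises ValueError."""
    if not _BASE_MODEL_RE.match(base_model):
        raise ValueError(
            f"base_model should look like '<publisher>/<model_name>', got {base_model!r}"
        )
    # Assumption: custom weights, when given, live in Cloud Storage.
    if custom_base_model and not custom_base_model.startswith("gs://"):
        raise ValueError("custom_base_model is expected to be empty or a gs:// URI")
    publisher, model_name = base_model.split("/", 1)
    return publisher, model_name
```

For example, `check_source_model_args("meta/llama3.1-8b", "gs://user-bucket/custom-weights")` returns `("meta", "llama3.1-8b")`.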