25
25
openpipe_data_converter ,
26
26
openpipe_dataset_creator ,
27
27
openpipe_finetuning_starter ,
28
+ openpipe_finetuning_starter_sdk ,
28
29
)
29
30
30
31
logger = get_logger (__name__ )
@@ -49,7 +50,7 @@ def openpipe_finetuning(
49
50
50
51
# Fine-tuning parameters
51
52
model_name : str = "zenml_finetuned_model" ,
52
- base_model : str = "meta-llama/Meta-Llama-3-8B-Instruct" ,
53
+ base_model : str = "meta-llama/Meta-Llama-3.1 -8B-Instruct" ,
53
54
enable_sft : bool = True ,
54
55
enable_preference_tuning : bool = False ,
55
56
learning_rate_multiplier : float = 1.0 ,
@@ -61,6 +62,9 @@ def openpipe_finetuning(
61
62
verbose_logs : bool = True ,
62
63
auto_rename : bool = True ,
63
64
force_overwrite : bool = False ,
65
+
66
+ # Implementation options
67
+ use_sdk : bool = False ,
64
68
):
65
69
"""
66
70
OpenPipe fine-tuning pipeline.
@@ -93,6 +97,7 @@ def openpipe_finetuning(
93
97
verbose_logs: Whether to log detailed model information during polling
94
98
auto_rename: If True, automatically append a timestamp to model name if it already exists
95
99
force_overwrite: If True, delete existing model with the same name before creating new one
100
+ use_sdk: If True, use the Python OpenPipe SDK instead of direct API calls
96
101
97
102
Returns:
98
103
A dictionary with details about the fine-tuning job, including model information
@@ -122,24 +127,46 @@ def openpipe_finetuning(
122
127
base_url = base_url ,
123
128
)
124
129
125
- # Start fine-tuning and monitor progress
126
- finetuning_result = openpipe_finetuning_starter (
127
- dataset_id = dataset_id ,
128
- model_name = model_name ,
129
- base_model = base_model ,
130
- openpipe_api_key = openpipe_api_key ,
131
- base_url = base_url ,
132
- enable_sft = enable_sft ,
133
- enable_preference_tuning = enable_preference_tuning ,
134
- learning_rate_multiplier = learning_rate_multiplier ,
135
- num_epochs = num_epochs ,
136
- batch_size = batch_size ,
137
- default_temperature = default_temperature ,
138
- wait_for_completion = wait_for_completion ,
139
- timeout_minutes = timeout_minutes ,
140
- verbose_logs = verbose_logs ,
141
- auto_rename = auto_rename ,
142
- force_overwrite = force_overwrite ,
143
- )
130
+ # Choose between SDK and direct API implementation
131
+ if use_sdk :
132
+ # Use the SDK implementation
133
+ finetuning_result = openpipe_finetuning_starter_sdk (
134
+ dataset_id = dataset_id ,
135
+ model_name = model_name ,
136
+ base_model = base_model ,
137
+ openpipe_api_key = openpipe_api_key ,
138
+ base_url = base_url ,
139
+ enable_sft = enable_sft ,
140
+ enable_preference_tuning = enable_preference_tuning ,
141
+ learning_rate_multiplier = learning_rate_multiplier ,
142
+ num_epochs = num_epochs ,
143
+ batch_size = batch_size ,
144
+ default_temperature = default_temperature ,
145
+ wait_for_completion = wait_for_completion ,
146
+ timeout_minutes = timeout_minutes ,
147
+ verbose_logs = verbose_logs ,
148
+ auto_rename = auto_rename ,
149
+ force_overwrite = force_overwrite ,
150
+ )
151
+ else :
152
+ # Use the original direct API implementation
153
+ finetuning_result = openpipe_finetuning_starter (
154
+ dataset_id = dataset_id ,
155
+ model_name = model_name ,
156
+ base_model = base_model ,
157
+ openpipe_api_key = openpipe_api_key ,
158
+ base_url = base_url ,
159
+ enable_sft = enable_sft ,
160
+ enable_preference_tuning = enable_preference_tuning ,
161
+ learning_rate_multiplier = learning_rate_multiplier ,
162
+ num_epochs = num_epochs ,
163
+ batch_size = batch_size ,
164
+ default_temperature = default_temperature ,
165
+ wait_for_completion = wait_for_completion ,
166
+ timeout_minutes = timeout_minutes ,
167
+ verbose_logs = verbose_logs ,
168
+ auto_rename = auto_rename ,
169
+ force_overwrite = force_overwrite ,
170
+ )
144
171
145
- return finetuning_result
172
+ return finetuning_result
0 commit comments