Skip to content

Commit c204660

Browse files
committed
Fixes model name problem for the lora modeules
1 parent d917042 commit c204660

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

ads/aqua/model/model.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -295,65 +295,73 @@ def create_multi(
295295

296296
selected_models_deployment_containers = set()
297297

298-
# Process each model
298+
# Process each model in the input list
299299
for model in models:
300+
# Retrieve model metadata from the Model Catalog using the model ID
300301
source_model = DataScienceModel.from_id(model.model_id)
301302
display_name = source_model.display_name
302303
model_file_description = source_model.model_file_description
303-
# Update model name in user's input model
304+
# If model_name is not explicitly provided, use the model's display name
304305
model.model_name = model.model_name or display_name
305306

306-
# TODO Uncomment the section below, if only service models should be allowed for multi-model deployment
307-
# if not source_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, UNKNOWN):
308-
# raise AquaValueError(
309-
# f"Invalid selected model {display_name}. "
310-
# "Currently only service models are supported for multi model deployment."
311-
# )
307+
if not model_file_description:
308+
raise AquaValueError(
309+
f"Model '{source_model.display_name}' (ID: {model.model_id}) has no file description. "
310+
"Please register the model first."
311+
)
312312

313-
# check if model is a fine-tuned model and if so, add the fine tuned weights path to the fine_tune_weights_location pydantic field
313+
# Check if the model is a fine-tuned model based on its tags
314314
is_fine_tuned_model = (
315315
Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
316316
)
317317

318-
model_artifact_path = ""
318+
base_model_artifact_path = ""
319319
fine_tune_path = ""
320320

321321
if is_fine_tuned_model:
322-
model_artifact_path, fine_tune_path = extract_fine_tune_artifacts_path(
323-
source_model
322+
# Extract artifact paths for the base and fine-tuned model
323+
base_model_artifact_path, fine_tune_path = (
324+
extract_fine_tune_artifacts_path(source_model)
324325
)
325-
# once we support multiple LoRA Modules use [LoraModuleSpec(**lora_module) for lora_module in model.fine_tune_weights]
326+
327+
# Create a single LoRA module specification for the fine-tuned model
328+
# TODO: Support multiple LoRA modules in the future
326329
model.fine_tune_weights = [
327330
LoraModuleSpec(
328331
model_id=model.model_id,
329-
model_name=display_name,
332+
model_name=model.model_name,
330333
model_path=fine_tune_path,
331334
)
332335
]
336+
337+
# Use the LoRA module name as the model's display name
338+
display_name = model.model_name
339+
340+
# Temporarily override model ID and name with those of the base model
341+
# TODO: Revisit this logic once proper base/FT model handling is implemented
333342
model.model_id, model.model_name = extract_base_model_from_ft(
334343
source_model
335344
)
336345
else:
337-
# Retrieve model artifact for base models
338-
model_artifact_path = source_model.artifact
346+
# For base models, use the original artifact path
347+
base_model_artifact_path = source_model.artifact
348+
display_name = model.model_name
339349

340-
if not model_artifact_path:
350+
if not base_model_artifact_path:
351+
# Fail if no artifact is found for the base model model
341352
raise AquaValueError(
342-
f"Model '{display_name}' (ID: {model.model_id}) has no artifacts. "
353+
f"Model '{model.model_name}' (ID: {model.model_id}) has no artifacts. "
343354
"Please register the model first."
344355
)
345356

346-
# Update model artifact location in user's input model
347-
model.artifact_location = model_artifact_path
357+
# Update the artifact path in the model configuration
358+
model.artifact_location = base_model_artifact_path
348359
display_name_list.append(display_name)
349-
self._extract_model_task(model, source_model)
350360

351-
if not model_file_description:
352-
raise AquaValueError(
353-
f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
354-
"Please register the model first."
355-
)
361+
# Extract model task metadata from source model
362+
self._extract_model_task(model, source_model)
356363

364+
# Track model file description in a validated structure
357365
model_file_description_list.append(
358366
ModelFileDescription(**model_file_description)
359367
)

0 commit comments

Comments
 (0)