Skip to content

Commit de7d6fe

Browse files
authored
Merge pull request #1759 from dimagi/sk/custom-models
replace custom models
2 parents cc86c9e + 12a28f2 commit de7d6fe

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

apps/service_providers/llm_service/default_models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _update_llm_provider_models(LlmProviderModel):
137137
for m in existing_custom_by_team.values():
138138
existing_custom_global[(m.type, m.name)].append(m)
139139

140+
created_models = dict()
140141
for provider_type, provider_models in DEFAULT_LLM_PROVIDER_MODELS.items():
141142
for model in provider_models:
142143
key = (provider_type, model.name)
@@ -147,7 +148,7 @@ def _update_llm_provider_models(LlmProviderModel):
147148
existing_global_model.max_token_limit = model.token_limit
148149
existing_global_model.save()
149150
else:
150-
LlmProviderModel.objects.create(
151+
created_models[(provider_type, model.name)] = LlmProviderModel.objects.create(
151152
team=None,
152153
type=provider_type,
153154
name=model.name,
@@ -170,6 +171,22 @@ def _update_llm_provider_models(LlmProviderModel):
170171

171172
provider_model.delete()
172173

174+
# replace existing custom models with the new global model and delete the custom models
175+
for key, model in created_models.items():
176+
if key in existing_custom_global:
177+
for custom_model in existing_custom_global[key]:
178+
related_objects = get_related_objects(custom_model)
179+
for obj in related_objects:
180+
field = [f for f in obj._meta.fields if f.related_model == LlmProviderModel][0]
181+
setattr(obj, field.attname, model.id)
182+
obj.save(update_fields=[field.name])
183+
184+
related_pipeline_nodes = get_related_pipelines_queryset(custom_model, "llm_provider_model_id")
185+
for node in related_pipeline_nodes.select_related("pipeline").all():
186+
_update_pipeline_node_param(node.pipeline, node, "llm_provider_model_id", model.id)
187+
188+
custom_model.delete()
189+
173190

174191
def _get_or_create_custom_model(team_object, key, global_model, existing_custom_by_team):
175192
"""Check the `existing_custom_by_team` mapping for a custom model for the given team and key.

apps/service_providers/tests/test_default_models.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

35
from apps.service_providers.llm_service.default_models import (
46
DEFAULT_LLM_PROVIDER_MODELS,
7+
Model,
58
get_default_model,
69
update_llm_provider_models,
710
)
@@ -109,6 +112,52 @@ def test_converts_old_global_models_to_custom_models_pipelines():
109112
assert node_data[0]["data"]["params"]["llm_provider_model_id"] == custom_model.id
110113

111114

115+
@pytest.mark.django_db()
116+
def test_replaces_custom_models_with_global_models():
117+
experiment = ExperimentFactory()
118+
old_custom_model = experiment.llm_provider_model
119+
assert old_custom_model.team is not None # Ensure it's a custom model
120+
121+
# no global model should exist
122+
assert not LlmProviderModel.objects.filter(
123+
team=None, type=old_custom_model.type, name=old_custom_model.name
124+
).exists()
125+
126+
defaults = {old_custom_model.type: [Model(old_custom_model.name, old_custom_model.max_token_limit)]}
127+
with patch("apps.service_providers.llm_service.default_models.DEFAULT_LLM_PROVIDER_MODELS", defaults):
128+
update_llm_provider_models()
129+
130+
new_global_model = LlmProviderModel.objects.get(team=None, type=old_custom_model.type, name=old_custom_model.name)
131+
# custom model should now point to the global model
132+
assert not LlmProviderModel.objects.filter(id=old_custom_model.id).exists()
133+
experiment.refresh_from_db()
134+
assert experiment.llm_provider_model_id == new_global_model.id
135+
136+
137+
@pytest.mark.django_db()
138+
def test_converts_custom_models_to_global_models_pipelines():
139+
custom_model = LlmProviderModelFactory()
140+
pipeline = get_pipeline(custom_model)
141+
142+
# no custom model should exist
143+
assert not LlmProviderModel.objects.filter(team=None, type=custom_model.type, name=custom_model.name).exists()
144+
145+
defaults = {custom_model.type: [Model(custom_model.name, custom_model.max_token_limit)]}
146+
with patch("apps.service_providers.llm_service.default_models.DEFAULT_LLM_PROVIDER_MODELS", defaults):
147+
update_llm_provider_models()
148+
149+
# custom model is removed
150+
assert not LlmProviderModel.objects.filter(id=custom_model.id).exists()
151+
152+
# global model is created
153+
global_model = LlmProviderModel.objects.get(team=None, type=custom_model.type, name=custom_model.name)
154+
# pipeline is updated to use the custom model
155+
pipeline.refresh_from_db()
156+
assert pipeline.node_set.get(type="LLMResponseWithPrompt").params["llm_provider_model_id"] == global_model.id
157+
node_data = [node for node in pipeline.data["nodes"] if node["data"]["type"] == "LLMResponseWithPrompt"]
158+
assert node_data[0]["data"]["params"]["llm_provider_model_id"] == global_model.id
159+
160+
112161
def get_pipeline(llm_provider_model):
113162
pipeline = PipelineFactory()
114163
pipeline.data["nodes"].append(

0 commit comments

Comments
 (0)