|
| 1 | +from unittest.mock import patch |
| 2 | + |
1 | 3 | import pytest
|
2 | 4 |
|
3 | 5 | from apps.service_providers.llm_service.default_models import (
|
4 | 6 | DEFAULT_LLM_PROVIDER_MODELS,
|
| 7 | + Model, |
5 | 8 | get_default_model,
|
6 | 9 | update_llm_provider_models,
|
7 | 10 | )
|
@@ -109,6 +112,52 @@ def test_converts_old_global_models_to_custom_models_pipelines():
|
109 | 112 | assert node_data[0]["data"]["params"]["llm_provider_model_id"] == custom_model.id
|
110 | 113 |
|
111 | 114 |
|
@pytest.mark.django_db()
def test_replaces_custom_models_with_global_models():
    """A team-scoped (custom) model matching a new default is replaced by a global one.

    When ``update_llm_provider_models`` runs with defaults that include a model
    identical to an existing custom model, the custom row is deleted, a global
    (team=None) row is created, and referencing experiments are repointed.
    """
    experiment = ExperimentFactory()
    custom = experiment.llm_provider_model
    # A non-null team marks this as a team-scoped (custom) model.
    assert custom.team is not None

    # Precondition: no matching global model exists yet.
    global_qs = LlmProviderModel.objects.filter(team=None, type=custom.type, name=custom.name)
    assert not global_qs.exists()

    patched_defaults = {custom.type: [Model(custom.name, custom.max_token_limit)]}
    with patch("apps.service_providers.llm_service.default_models.DEFAULT_LLM_PROVIDER_MODELS", patched_defaults):
        update_llm_provider_models()

    # The custom model was deleted and replaced by a newly created global one.
    replacement = global_qs.get()
    assert not LlmProviderModel.objects.filter(id=custom.id).exists()
    # The experiment now references the global replacement.
    experiment.refresh_from_db()
    assert experiment.llm_provider_model_id == replacement.id
| 136 | + |
@pytest.mark.django_db()
def test_converts_custom_models_to_global_models_pipelines():
    """Pipeline nodes referencing a custom model are repointed to the new global model.

    After ``update_llm_provider_models`` promotes a custom model to a global
    default, both the pipeline's node rows and its serialized ``data`` payload
    must carry the global model's id.
    """
    team_model = LlmProviderModelFactory()
    pipeline = get_pipeline(team_model)

    # Precondition: no global equivalent of this model exists yet.
    assert not LlmProviderModel.objects.filter(team=None, type=team_model.type, name=team_model.name).exists()

    new_defaults = {team_model.type: [Model(team_model.name, team_model.max_token_limit)]}
    with patch("apps.service_providers.llm_service.default_models.DEFAULT_LLM_PROVIDER_MODELS", new_defaults):
        update_llm_provider_models()

    # The team-scoped model is removed...
    assert not LlmProviderModel.objects.filter(id=team_model.id).exists()

    # ...and a global replacement is created in its place.
    global_model = LlmProviderModel.objects.get(team=None, type=team_model.type, name=team_model.name)

    # The pipeline's node row and its serialized node data both point at the global model.
    pipeline.refresh_from_db()
    llm_node = pipeline.node_set.get(type="LLMResponseWithPrompt")
    assert llm_node.params["llm_provider_model_id"] == global_model.id
    matching_nodes = [node for node in pipeline.data["nodes"] if node["data"]["type"] == "LLMResponseWithPrompt"]
    assert matching_nodes[0]["data"]["params"]["llm_provider_model_id"] == global_model.id
| 160 | + |
112 | 161 | def get_pipeline(llm_provider_model):
|
113 | 162 | pipeline = PipelineFactory()
|
114 | 163 | pipeline.data["nodes"].append(
|
|
0 commit comments