45 changes: 23 additions & 22 deletions kairon/actions/definitions/prompt.py
@@ -152,27 +152,31 @@ async def __get_llm_params(self, k_faq_action_config: dict, dispatcher: Collecti
similarity_prompt = []
params = {}
num_bot_responses = k_faq_action_config['num_bot_responses']
for prompt in k_faq_action_config['llm_prompts']:
if prompt['type'] == LlmPromptType.system.value and prompt['is_enabled']:
system_prompt = f"{prompt['data']}\n"
elif prompt['type'] == LlmPromptType.user.value and prompt['is_enabled']:
system_prompt = k_faq_action_config.get("system_prompt", "")
user_prompt = k_faq_action_config.get("user_prompt", "")

for prompt in k_faq_action_config.get("contexts", []):
if prompt['type'] == LlmPromptType.user.value and prompt['is_enabled']:
if prompt['source'] == LlmPromptSource.history.value:
history_prompt = ActionUtility.prepare_bot_responses(tracker, num_bot_responses)

elif prompt['source'] == LlmPromptSource.bot_content.value and prompt['is_enabled']:
use_similarity_prompt = True
hyperparameters = prompt.get("hyperparameters", {})
similarity_prompt.append({'similarity_prompt_name': prompt['name'],
'similarity_prompt_instructions': prompt['instructions'],
'collection': prompt['data'],
'use_similarity_prompt': use_similarity_prompt,
'top_results': hyperparameters.get('top_results', 10),
'similarity_threshold': hyperparameters.get('similarity_threshold',
0.70)})
sim_cfg = prompt.get("similarity_config") or {}
top_results = sim_cfg.get("top_results", 10)
similarity_threshold = sim_cfg.get("similarity_threshold", 0.7)
similarity_prompt.append({
'similarity_prompt_name': prompt['name'],
'collection': prompt['data'],
'use_similarity_prompt': use_similarity_prompt,
'top_results': top_results,
'similarity_threshold': similarity_threshold
})

elif prompt['source'] == LlmPromptSource.slot.value:
slot_data = tracker.get_slot(prompt['data'])
context_prompt += f"{prompt['name']}:\n{slot_data}\n"
if prompt['instructions']:
context_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"

elif prompt['source'] == LlmPromptSource.crud.value:
crud_config = prompt.get('crud_config', {})
collections = crud_config.get('collections', [])
@@ -210,29 +214,26 @@ async def __get_llm_params(self, k_faq_action_config: dict, dispatcher: Collecti
data_list = [rec["data"] for rec in records]
context_prompt += f"Collection data for {prompt['name']}:\n{data_list}\n"

if prompt['instructions']:
context_prompt += f"Instructions on how to use collection {prompt['name']}:\n{prompt['instructions']}\n\n"
elif prompt['source'] == LlmPromptSource.action.value:
action = ActionFactory.get_instance(self.bot, prompt['data'])
await action.execute(dispatcher, tracker, domain, **kwargs)
if action.is_success:
response = action.response
context_prompt += f"{prompt['name']}:\n{response}\n"
if prompt['instructions']:
context_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"
# if prompt.get('instructions'):
# context_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"

else:
context_prompt += f"{prompt['name']}:\n{prompt['data']}\n"
if prompt['instructions']:
context_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"

elif prompt['type'] == LlmPromptType.query.value and prompt['is_enabled']:
query_prompt += f"{prompt['name']}:\n{prompt['data']}\n"
if prompt['instructions']:
query_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"
is_query_prompt_enabled = True
query_prompt_dict.update({'query_prompt': query_prompt, 'use_query_prompt': is_query_prompt_enabled})

params["hyperparameters"] = k_faq_action_config['hyperparameters']
params["system_prompt"] = system_prompt
params["user_prompt"] = user_prompt
params["context_prompt"] = context_prompt
params["query_prompt"] = query_prompt_dict
params["previous_bot_responses"] = history_prompt
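
To make the new retrieval path concrete, here is a minimal sketch (not part of the diff) of how a bot_content context entry maps to similarity-prompt parameters under the renamed similarity_config key. The entry values are illustrative; the 10/0.70 fallbacks mirror the hunk above, and note the new dict no longer carries similarity_prompt_instructions.

def build_similarity_prompt(prompt: dict) -> dict:
    # Retrieval settings now live under similarity_config rather than
    # hyperparameters; missing keys fall back to the same defaults as above.
    sim_cfg = prompt.get("similarity_config") or {}
    return {
        'similarity_prompt_name': prompt['name'],
        'collection': prompt['data'],
        'use_similarity_prompt': True,
        'top_results': sim_cfg.get('top_results', 10),
        'similarity_threshold': sim_cfg.get('similarity_threshold', 0.7),
    }

example = {
    'name': 'Similarity Prompt', 'type': 'user', 'source': 'bot_content',
    'data': 'default', 'is_enabled': True,
    'similarity_config': {'top_results': 5, 'similarity_threshold': 0.8},
}
print(build_similarity_prompt(example))
# {'similarity_prompt_name': 'Similarity Prompt', 'collection': 'default',
#  'use_similarity_prompt': True, 'top_results': 5, 'similarity_threshold': 0.8}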
16 changes: 7 additions & 9 deletions kairon/importer/validator/file_validator.py
@@ -716,7 +716,7 @@ def __validate_prompt_actions(prompt_actions: list, bot: Text = None):
action['num_bot_responses'] > 5 or not isinstance(action['num_bot_responses'], int)):
data_error.append(
f'num_bot_responses should not be greater than 5 and of type int: {action.get("name")}')
llm_prompts_errors = TrainingDataValidator.__validate_llm_prompts(action['llm_prompts'])
llm_prompts_errors = TrainingDataValidator.__validate_llm_prompts(action['contexts'])
if action.get('hyperparameters'):
llm_hyperparameters_errors = TrainingDataValidator.__validate_llm_prompts_hyperparameters(
action.get('hyperparameters'), action.get("llm_type", "openai"), bot)
@@ -753,13 +753,13 @@ def __validate_database_actions(database_actions: list):
return data_error

@staticmethod
def __validate_llm_prompts(llm_prompts: dict):
def __validate_llm_prompts(contexts: dict):
error_list = []
system_prompt_count = 0
history_prompt_count = 0
for prompt in llm_prompts:
if prompt.get('hyperparameters') is not None:
hyperparameters = prompt.get('hyperparameters')
for prompt in contexts:
if prompt.get('similarity_config') is not None:
hyperparameters = prompt.get('similarity_config')
for key, value in hyperparameters.items():
if key == 'similarity_threshold':
if not (0.3 <= value <= 1.0) or not (
@@ -769,13 +769,11 @@ def __validate_llm_prompts(llm_prompts: dict):
if key == 'top_results' and (value > 30 or not isinstance(value, int)):
error_list.append("top_results should not be greater than 30 and of type int!")

if prompt.get('type') == 'system':
system_prompt_count += 1
elif prompt.get('source') == 'history':
if prompt.get('source') == 'history':
history_prompt_count += 1
if prompt.get('type') not in ['user', 'system', 'query']:
error_list.append('Invalid prompt type')
if prompt.get('source') not in ['static', 'slot', 'action', 'history', 'bot_content']:
if prompt.get('source') not in ['static', 'slot', 'action', 'history', 'bot_content', 'crud']:
error_list.append('Invalid prompt source')
if prompt.get('type') and not isinstance(prompt.get('type'), str):
error_list.append('type in LLM Prompts should be of type string.')
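
The bounds themselves are unchanged; only the key being read moved from hyperparameters to similarity_config. A standalone sketch of the checks, condensed from the hunk above (assumed behaviour, not the committed code):

def validate_similarity_config(sim_cfg: dict) -> list:
    # Mirrors the validator's bounds: similarity_threshold numeric and in
    # [0.3, 1.0]; top_results an int no greater than 30.
    errors = []
    if 'similarity_threshold' in sim_cfg:
        value = sim_cfg['similarity_threshold']
        if not isinstance(value, (int, float)) or not (0.3 <= value <= 1.0):
            errors.append("similarity_threshold should be within 0.3 and 1.0 and of type int or float!")
    if 'top_results' in sim_cfg:
        value = sim_cfg['top_results']
        if not isinstance(value, int) or value > 30:
            errors.append("top_results should not be greater than 30 and of type int!")
    return errors

assert validate_similarity_config({'top_results': 10, 'similarity_threshold': 0.7}) == []
assert validate_similarity_config({'top_results': 50}) != []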
20 changes: 9 additions & 11 deletions kairon/shared/actions/data_objects.py
@@ -792,11 +792,9 @@ def validate(self, clean=True):
def clean(self):
self.collections = [col.strip() for col in self.collections if col and col.strip()]

class LlmPrompt(EmbeddedDocument):
class Context(EmbeddedDocument):
name = StringField(required=True)
hyperparameters = EmbeddedDocumentField(PromptHyperparameter)
data = StringField()
instructions = StringField()
type = StringField(
required=True,
choices=[
@@ -818,6 +816,7 @@ class LlmPrompt(EmbeddedDocument):
)
is_enabled = BooleanField(default=True)
crud_config = EmbeddedDocumentField(CrudConfig)
similarity_config = EmbeddedDocumentField(PromptHyperparameter)

def validate(self, clean=True):
if (
@@ -833,8 +832,6 @@ def validate(self, clean=True):
self.crud_config.validate()
elif self.crud_config:
raise ValidationError("crud_config should only be provided when source is 'crud'")
if self.hyperparameters:
self.hyperparameters.validate()
if self.source == LlmPromptSource.bot_content.value and Utility.check_empty_string(self.data):
self.data = "default"

@@ -857,11 +854,12 @@ class PromptAction(Auditlog):
timestamp = DateTimeField(default=datetime.utcnow)
llm_type = StringField(default=DEFAULT_LLM, choices=Utility.get_llms())
hyperparameters = DictField(default=Utility.get_default_llm_hyperparameters)
llm_prompts = EmbeddedDocumentListField(LlmPrompt, required=True)
instructions = ListField(StringField())
contexts = EmbeddedDocumentListField(Context)
system_prompt = StringField(required=True)
user_prompt = StringField(required=True)
set_slots = EmbeddedDocumentListField(SetSlotsFromResponse)
dispatch_response = BooleanField(default=True)
process_media=BooleanField(default=False)
process_media = BooleanField(default=False)
status = BooleanField(default=True)

meta = {"indexes": [{"fields": ["bot", ("bot", "name", "status")]}]}
@@ -876,13 +874,13 @@ def validate(self, clean=True):
self.clean()
if self.num_bot_responses > 5:
raise ValidationError("num_bot_responses should not be greater than 5")
if not self.llm_prompts:
if not self.contexts:
raise ValidationError("llm_prompts are required!")
Comment on lines +877 to 878
⚠️ Potential issue | 🟡 Minor

Update error message to reflect the new field name.

The error message references the old field name llm_prompts, but the field has been renamed to contexts. This inconsistency could confuse users.

Apply this diff to fix the error message:

         if not self.contexts:
-            raise ValidationError("llm_prompts are required!")
+            raise ValidationError("contexts are required!")
🧰 Tools: 🪛 Ruff (0.14.3) — 878-878: Avoid specifying long messages outside the exception class (TRY003)


for prompts in self.llm_prompts:
for prompts in self.contexts:
prompts.validate()
dict_data = self.to_mongo().to_dict()
Utility.validate_kairon_faq_llm_prompts(
dict_data["llm_prompts"], ValidationError
dict_data["contexts"], ValidationError
)
Utility.validate_llm_hyperparameters(
dict_data["hyperparameters"], self.llm_type, self.bot, ValidationError
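
For reference, a hypothetical construction of the renamed document, assuming PromptHyperparameter exposes top_results and similarity_threshold fields (consistent with the validators above); bot/user values and prompt texts are placeholders, not part of the diff:

prompt_action = PromptAction(
    name='faq_action',
    bot='<bot_id>',
    user='<user>',
    system_prompt='You are a helpful FAQ assistant.',
    user_prompt='Answer strictly from the supplied context.',
    contexts=[
        Context(
            name='Similarity Prompt',
            type='user',
            source='bot_content',
            data='default',  # empty data for bot_content falls back to "default"
            similarity_config=PromptHyperparameter(top_results=10,
                                                   similarity_threshold=0.7),
        ),
    ],
)
prompt_action.validate()  # runs Context.validate() on every entry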
11 changes: 1 addition & 10 deletions kairon/shared/chat/user_media.py
@@ -448,14 +448,6 @@ def add_media_extraction_flow_if_not_exist( bot: str):
try:
instructions = ['Extract all relevant information from the media file and return as a markdown']

llm_prompts = [
{'name': 'System Prompt',
'instructions': f'{instructions[0]}',
'type': 'system',
'data': 'Extract information',
},
]

action_name = f"{UserMedia.MEDIA_EXTRACTION_FLOW_NAME}_prompt_action"
if not Actions.objects(bot=bot, name=action_name).first():
action = Actions(
@@ -471,8 +463,7 @@
if not PromptAction.objects(bot=bot, name=action_name).first():
prompt_action = PromptAction(
name=action_name,
instructions=instructions,
llm_prompts=llm_prompts,
system_prompt = f'{instructions[0]}',
dispatch_response=True,
process_media=True,
bot=bot,
2 changes: 1 addition & 1 deletion kairon/shared/cognition/processor.py
@@ -314,7 +314,7 @@ def find_matching_metadata(bot: Text, data: Any, collection: Text = None):

@staticmethod
def validate_collection_name(bot: Text, collection: Text):
prompt_action = list(PromptAction.objects(bot=bot, llm_prompts__data__iexact=collection))
prompt_action = list(PromptAction.objects(bot=bot, contexts__data__iexact=collection))
database_action = list(DatabaseAction.objects(bot=bot, collection__iexact=collection))
if prompt_action:
raise AppException(f'Cannot remove collection {collection} linked to action "{prompt_action[0].name}"!')
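
An illustrative guard with a concrete name (a sketch, same semantics as the renamed query above): the double-underscore path contexts__data__iexact traverses the embedded contexts list and compares case-insensitively, so the check fires when any context still references the collection.

# Hypothetical bot id and collection name, for illustration only.
linked = PromptAction.objects(bot='<bot_id>', contexts__data__iexact='faq_docs').first()
if linked:
    raise AppException(f'Cannot remove collection faq_docs linked to action "{linked.name}"!')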
54 changes: 39 additions & 15 deletions kairon/shared/data/data_models.py
@@ -1102,9 +1102,9 @@ class CrudConfigRequest(BaseModel):
result_limit: int = 10
query_source: Optional[Literal["value", "slot"]] = None

class LlmPromptRequest(BaseModel, use_enum_values=True):
class Context(BaseModel, use_enum_values=True):
name: str
hyperparameters: PromptHyperparameters = None
similarity_config: Optional[PromptHyperparameters] = None
data: str = None
instructions: str = None
type: LlmPromptType
@@ -1135,8 +1135,11 @@ def check(cls, values):
else:
values.pop('crud_config', None)

if values.get('source') == LlmPromptSource.bot_content.value and Utility.check_empty_string(values.get('data')):
values['data'] = "default"
if values.get('source') == LlmPromptSource.bot_content.value:
if Utility.check_empty_string(values.get('data')):
values['data'] = "default"
else:
values.pop('similarity_config', None)

return values

@@ -1153,8 +1156,9 @@ class PromptActionConfigUploadValidation(BaseModel):
user_question: UserQuestionModel = UserQuestionModel()
llm_type: str
hyperparameters: dict
llm_prompts: List[LlmPromptRequest]
instructions: List[str] = []
contexts: List[Context] = []
system_prompt: str
user_prompt: str
set_slots: List[SetSlotsUsingActionResponse] = []
dispatch_response: bool = True

@@ -1166,8 +1170,9 @@ class PromptActionConfigRequest(BaseModel):
user_question: UserQuestionModel = UserQuestionModel()
llm_type: str
hyperparameters: dict
llm_prompts: List[LlmPromptRequest]
instructions: List[str] = []
contexts: List[Context] = Field(default_factory=list)
system_prompt: str
user_prompt: str
set_slots: List[SetSlotsUsingActionResponse] = []
dispatch_response: bool = True
process_media: bool = False
@@ -1180,7 +1185,7 @@ def validate_llm_type(cls, v, values, **kwargs):
raise ValueError("Invalid llm type")
return v

@validator("llm_prompts")
@validator("contexts")
def validate_llm_prompts(cls, v, values, **kwargs):
from kairon.shared.utils import Utility

@@ -1195,15 +1200,34 @@ def validate_num_bot_responses(cls, v, values, **kwargs):
raise ValueError("num_bot_responses should not be greater than 5")
return v

@validator("hyperparameters")
def validate_hyperparameters(cls, v, values, **kwargs):
@validator("user_prompt", "system_prompt", pre=True, always=True)
def validate_user_and_system_prompts(cls, v, field):
from kairon.shared.utils import Utility
bot = values.get('bot')
llm_type = values.get('llm_type')
if llm_type and v:
Utility.validate_llm_hyperparameters(v, llm_type, bot, ValueError)

if isinstance(v, list):
if len(v) != 1:
raise ValueError(f"Only one {field.name.replace('_', ' ')} is allowed.")
v = v[0]

if Utility.check_empty_string(v):
raise ValueError(f"{field.name.replace('_', ' ').capitalize()} cannot be empty!")

if not isinstance(v, str):
raise ValueError(f"{field.name.replace('_', ' ').capitalize()} must be a string.")

return v

@root_validator()
def validate_hyperparameters(cls, values):
from kairon.shared.utils import Utility
bot = values.get('bot')
llm_type = values.get("llm_type")
hyperparams = values.get("hyperparameters")
if llm_type and hyperparams:
Utility.validate_llm_hyperparameters(hyperparams, llm_type, bot, ValueError)

return values

@root_validator(pre=True)
def validate_required_fields(cls, values):
bot = values.get('bot')
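
For illustration, a payload the renamed request model would accept. Field names follow this diff; remaining fields are assumed to keep their defaults, the bot value is a placeholder, and the hyperparameters value is whatever Utility.get_default_llm_hyperparameters returns for the chosen llm_type.

payload = {
    'name': 'faq_action',
    'bot': '<bot_id>',
    'llm_type': 'openai',
    'hyperparameters': Utility.get_default_llm_hyperparameters(),
    'system_prompt': 'You are a helpful FAQ assistant.',
    'user_prompt': 'Answer strictly from the supplied context.',
    'contexts': [{
        'name': 'Similarity Prompt',
        'type': 'user',
        'source': 'bot_content',
        'data': 'default',
        'similarity_config': {'top_results': 10, 'similarity_threshold': 0.7},
    }],
}
request = PromptActionConfigRequest(**payload)
# Per the pre-validator above, a single-element list is also tolerated for
# the prompts and collapsed: {'system_prompt': ['You are ...']} passes too.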
15 changes: 7 additions & 8 deletions kairon/shared/data/data_validation.py
@@ -89,7 +89,7 @@ def validate_prompt_action(bot: str, data: dict):
data['num_bot_responses'] > 5 or not isinstance(data['num_bot_responses'], int)):
data_error.append(
f'num_bot_responses should not be greater than 5 and of type int: {data.get("name")}')
llm_prompts_errors = DataValidation.validate_llm_prompts(data['llm_prompts'])
llm_prompts_errors = DataValidation.validate_llm_prompts(data['contexts'])
data_error.extend(llm_prompts_errors)
if hyperparameters := data.get('hyperparameters'):
if not data.get("llm_type"):
@@ -132,19 +132,18 @@ def validate_llm_prompts_hyperparameters(hyperparameters: dict, llm_type: str, b
return error_list

@staticmethod
def validate_llm_prompts(llm_prompts: list):
def validate_llm_prompts(contexts: list):
error_list = []
system_prompt_count = 0
history_prompt_count = 0
for prompt in llm_prompts:
if prompt.get('hyperparameters') is not None:
hyperparameters = prompt.get('hyperparameters')
for prompt in contexts:
if prompt.get('similarity_config') is not None:
hyperparameters = prompt.get('similarity_config')
for key, value in hyperparameters.items():
if key == 'similarity_threshold':
if not (0.3 <= value <= 1.0) or not (
isinstance(value, float) or isinstance(value, int)):
error_list.append(
"similarity_threshold should be within 0.3 and 1.0 and of type int or float!")
error_list.append("similarity_threshold should be within 0.3 and 1.0 and of type int or float!")
if key == 'top_results' and (value > 30 or not isinstance(value, int)):
error_list.append("top_results should not be greater than 30 and of type int!")

@@ -154,7 +153,7 @@ def validate_llm_prompts(llm_prompts: list):
history_prompt_count += 1
if prompt.get('type') not in ['user', 'system', 'query']:
error_list.append('Invalid prompt type')
if prompt.get('source') not in ['static', 'slot', 'action', 'history', 'bot_content']:
if prompt.get('source') not in ['static', 'slot', 'action', 'history', 'bot_content', 'crud']:
error_list.append('Invalid prompt source')
if prompt.get('type') and not isinstance(prompt.get('type'), str):
error_list.append('type in LLM Prompts should be of type string.')
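
A hypothetical importer-side check using the renamed key (context entries are illustrative; the expected result assumes only the checks visible in this hunk):

contexts = [
    {'name': 'History', 'type': 'user', 'source': 'history', 'is_enabled': True},
    {'name': 'Similarity Prompt', 'type': 'user', 'source': 'bot_content',
     'data': 'default',
     'similarity_config': {'top_results': 10, 'similarity_threshold': 0.7}},
]
errors = DataValidation.validate_llm_prompts(contexts)
print(errors)  # expected [] — both sources are allowed and all bounds hold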