From dabdaf8df80d38d1e40857d1f1affe642b069c2f Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 15 Nov 2024 15:31:33 +0200 Subject: [PATCH] Small updates and refactors --- apps/chat/bots.py | 1 + apps/experiments/models.py | 5 +++-- apps/pipelines/nodes/nodes.py | 19 ++++++++++++++----- .../apps/pipeline/nodes/GetInputWidget.tsx | 2 +- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/apps/chat/bots.py b/apps/chat/bots.py index 8e0cc4a27..63fc5f284 100644 --- a/apps/chat/bots.py +++ b/apps/chat/bots.py @@ -289,6 +289,7 @@ def __init__(self, session: ExperimentSession, experiment: Experiment): self.ai_message_id = None def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None): + attachments = attachments or [] serializable_attachments = [attachment.model_dump() for attachment in attachments] output: PipelineState = self.experiment.pipeline.invoke( PipelineState(messages=[user_input], experiment_session=self.session, attachments=serializable_attachments), diff --git a/apps/experiments/models.py b/apps/experiments/models.py index d72eb9196..301c7fb41 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -992,8 +992,9 @@ def get_assistant(self): node_name = AssistantNode.model_json_schema()["title"] # TODO: What about multiple assistant nodes? assistant_node = Node.objects.filter(type=node_name, pipeline=self.pipeline).first() - if assistant_node and assistant_node.params["assistant_id"]: - return OpenAiAssistant.objects.get(id=assistant_node.params["assistant_id"]) + if assistant_node: + assistant_id = assistant_node.params.get("assistant_id") + return OpenAiAssistant.objects.get(id=assistant_id) return self.assistant diff --git a/apps/pipelines/nodes/nodes.py b/apps/pipelines/nodes/nodes.py index 23992a2f9..7a09110ce 100644 --- a/apps/pipelines/nodes/nodes.py +++ b/apps/pipelines/nodes/nodes.py @@ -52,7 +52,7 @@ class RenderTemplate(PipelineNode): description="Use {your_variable_name} to refer to designate input", ) - def _process(self, input, node_id: str, **kwargs) -> str: + def _process(self, input, node_id: str, **kwargs) -> PipelineState: def all_variables(in_): return {var: in_ for var in meta.find_undeclared_variables(env.parse(self.template_string))} @@ -154,7 +154,7 @@ class LLMResponse(PipelineNode, LLMResponseMixin): __human_name__ = "LLM response" __node_description__ = "Calls an LLM with the given input" - def _process(self, input, node_id: str, **kwargs) -> str: + def _process(self, input, node_id: str, **kwargs) -> PipelineState: llm = self.get_chat_model() output = llm.invoke(input, config=self._config) return PipelineState(messages=[output.content], outputs={node_id: output.content}) @@ -230,7 +230,7 @@ def recipient_list_has_valid_emails(cls, value): raise PydanticCustomError("invalid_recipient_list", "Invalid list of emails addresses") return value - def _process(self, input, node_id: str, **kwargs) -> str: + def _process(self, input, node_id: str, **kwargs) -> PipelineState: send_email_from_pipeline.delay( recipient_list=self.recipient_list.split(","), subject=self.subject, message=input ) @@ -518,8 +518,17 @@ class AssistantNode(Passthrough): @field_validator("input_formatter") def ensure_input_variable_exists(cls, value): - if value and "{input}" not in value: - raise PydanticCustomError("invalid_input_formatter", "The input formatter must contain {input}") + value = value or "" + acceptable_var = "input" + prompt_variables = set(PromptTemplate.from_template(value).input_variables) + if value: + if acceptable_var not in prompt_variables: + raise PydanticCustomError("invalid_input_formatter", "The input formatter must contain {input}") + + acceptable_vars = set([acceptable_var]) + extra_vars = prompt_variables - acceptable_vars + if extra_vars: + raise PydanticCustomError("invalid_input_formatter", "Only {input} is allowed") def _process(self, input, state: PipelineState, node_id: str, **kwargs) -> str: assistant = OpenAiAssistant.objects.get(id=self.assistant_id) diff --git a/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx b/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx index 4ca442d8b..14f9c83d1 100644 --- a/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx +++ b/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx @@ -162,7 +162,7 @@ export const getInputWidget = ({id, inputParam, params, updateParamValue}: Input className="toggle" name={inputParam.name} onChange={onChangeCallback} - defaultChecked={paramValue === "true"} + checked={paramValue === "true"} type="checkbox" >