Small updates and refactors
SmittieC committed Nov 15, 2024
1 parent 19a0a46 commit dabdaf8
Showing 4 changed files with 19 additions and 8 deletions.
1 change: 1 addition & 0 deletions apps/chat/bots.py
@@ -289,6 +289,7 @@ def __init__(self, session: ExperimentSession, experiment: Experiment):
         self.ai_message_id = None
 
     def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None):
+        attachments = attachments or []
         serializable_attachments = [attachment.model_dump() for attachment in attachments]
         output: PipelineState = self.experiment.pipeline.invoke(
             PipelineState(messages=[user_input], experiment_session=self.session, attachments=serializable_attachments),
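The one-line addition is a None-guard: process_input can now be called without attachments and the list comprehension below it still works. A minimal standalone sketch of the pattern, with a hypothetical Attachment model standing in for the app's real pydantic model:

from pydantic import BaseModel

class Attachment(BaseModel):  # hypothetical stand-in for the app's Attachment model
    name: str
    content_type: str

def serialize_attachments(attachments: list[Attachment] | None) -> list[dict]:
    # Same guard as the added line: a None argument becomes an empty list,
    # so the comprehension never raises TypeError.
    attachments = attachments or []
    return [attachment.model_dump() for attachment in attachments]

print(serialize_attachments(None))  # []
print(serialize_attachments([Attachment(name="a.txt", content_type="text/plain")]))
# [{'name': 'a.txt', 'content_type': 'text/plain'}]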
5 changes: 3 additions & 2 deletions apps/experiments/models.py
@@ -992,8 +992,9 @@ def get_assistant(self):
         node_name = AssistantNode.model_json_schema()["title"]
         # TODO: What about multiple assistant nodes?
         assistant_node = Node.objects.filter(type=node_name, pipeline=self.pipeline).first()
-        if assistant_node and assistant_node.params["assistant_id"]:
-            return OpenAiAssistant.objects.get(id=assistant_node.params["assistant_id"])
+        if assistant_node:
+            assistant_id = assistant_node.params.get("assistant_id")
+            return OpenAiAssistant.objects.get(id=assistant_id)
         return self.assistant
 
 
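The lookup now uses dict.get, so a node whose params have no assistant_id no longer raises KeyError before the ORM query runs. A small illustration with an assumed params shape:

# Assumed params of a pipeline node with no assistant configured.
params = {"llm_provider_id": 1}

# Old lookup: params["assistant_id"] raises KeyError: 'assistant_id'.
# New lookup: returns None instead of raising.
assistant_id = params.get("assistant_id")
print(assistant_id)  # None

Note that OpenAiAssistant.objects.get(id=None) would still raise DoesNotExist; the change only moves the failure from the dict access to the query itself.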
19 changes: 14 additions & 5 deletions apps/pipelines/nodes/nodes.py
@@ -52,7 +52,7 @@ class RenderTemplate(PipelineNode):
         description="Use {your_variable_name} to refer to designate input",
     )
 
-    def _process(self, input, node_id: str, **kwargs) -> str:
+    def _process(self, input, node_id: str, **kwargs) -> PipelineState:
         def all_variables(in_):
             return {var: in_ for var in meta.find_undeclared_variables(env.parse(self.template_string))}
 
@@ -154,7 +154,7 @@ class LLMResponse(PipelineNode, LLMResponseMixin):
     __human_name__ = "LLM response"
     __node_description__ = "Calls an LLM with the given input"
 
-    def _process(self, input, node_id: str, **kwargs) -> str:
+    def _process(self, input, node_id: str, **kwargs) -> PipelineState:
         llm = self.get_chat_model()
         output = llm.invoke(input, config=self._config)
         return PipelineState(messages=[output.content], outputs={node_id: output.content})
@@ -230,7 +230,7 @@ def recipient_list_has_valid_emails(cls, value):
             raise PydanticCustomError("invalid_recipient_list", "Invalid list of emails addresses")
         return value
 
-    def _process(self, input, node_id: str, **kwargs) -> str:
+    def _process(self, input, node_id: str, **kwargs) -> PipelineState:
         send_email_from_pipeline.delay(
             recipient_list=self.recipient_list.split(","), subject=self.subject, message=input
         )
@@ -518,8 +518,17 @@ class AssistantNode(Passthrough):
 
     @field_validator("input_formatter")
     def ensure_input_variable_exists(cls, value):
-        if value and "{input}" not in value:
-            raise PydanticCustomError("invalid_input_formatter", "The input formatter must contain {input}")
+        value = value or ""
+        acceptable_var = "input"
+        prompt_variables = set(PromptTemplate.from_template(value).input_variables)
+        if value:
+            if acceptable_var not in prompt_variables:
+                raise PydanticCustomError("invalid_input_formatter", "The input formatter must contain {input}")
+
+        acceptable_vars = set([acceptable_var])
+        extra_vars = prompt_variables - acceptable_vars
+        if extra_vars:
+            raise PydanticCustomError("invalid_input_formatter", "Only {input} is allowed")
 
     def _process(self, input, state: PipelineState, node_id: str, **kwargs) -> str:
         assistant = OpenAiAssistant.objects.get(id=self.assistant_id)
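The rewritten validator parses the formatter as a prompt template rather than doing a substring check, so it now also rejects variables other than {input}. A standalone restatement of the same checks, assuming langchain_core's PromptTemplate and using plain ValueError in place of PydanticCustomError:

from langchain_core.prompts import PromptTemplate  # assumed import

def check_input_formatter(value: str | None) -> None:
    value = value or ""
    prompt_variables = set(PromptTemplate.from_template(value).input_variables)
    if value and "input" not in prompt_variables:
        raise ValueError("The input formatter must contain {input}")
    if prompt_variables - {"input"}:
        raise ValueError("Only {input} is allowed")

check_input_formatter("Answer briefly: {input}")  # passes
check_input_formatter("")                         # passes; an empty formatter is allowed
# check_input_formatter("Answer: {question}")     # raises: must contain {input}
# check_input_formatter("{input} and {extra}")    # raises: only {input} is allowed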
2 changes: 1 addition & 1 deletion assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx
@@ -162,7 +162,7 @@ export const getInputWidget = ({id, inputParam, params, updateParamValue}: Input
           className="toggle"
           name={inputParam.name}
           onChange={onChangeCallback}
-          defaultChecked={paramValue === "true"}
+          checked={paramValue === "true"}
           type="checkbox"
         ></input>
       </InputField>
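Swapping defaultChecked for checked turns the checkbox into a controlled input: with defaultChecked, React reads paramValue only on the first render, while checked keeps the box in sync with paramValue on every re-render, with onChangeCallback continuing to push user toggles back into the params.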
