From 87bb87576a32e6f40e27e620acc997abbd3d336f Mon Sep 17 00:00:00 2001 From: Jerron Lim Date: Mon, 8 Jul 2024 23:46:26 +0800 Subject: [PATCH] Enhance agent prompts (#68) * Enhance prompts for hierarchical and sequential worker and leader so that they stick to their personas better. * Fix potential error in SummariserNode if team_responses exceed the context_length of the model --- backend/app/core/graph/members.py | 49 ++++++++++++++++--------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py index b993614e..f642c1fe 100644 --- a/backend/app/core/graph/members.py +++ b/backend/app/core/graph/members.py @@ -60,7 +60,7 @@ class GraphPerson(BaseModel): @property def persona(self) -> str: - return f"Name: {self.name}\nRole: {self.role}\nBackstory: {self.backstory}\n" + return f"\nName: {self.name}\nRole: {self.role}\nBackstory: {self.backstory}\n" class GraphMember(GraphPerson): @@ -163,14 +163,18 @@ class WorkerNode(BaseNode): "You are a team member of {team_name} and you are one of the following team members: {team_members_name}.\n" "Your team members (and other teams) will collaborate with you with their own set of skills. " "You are chosen by one of your team member to perform this task. Try your best to perform it using your skills. " - "Stay true to your persona and role:\n" - "{persona}" - "\nBEGIN!\n" + "Stay true to your persona and role:\n{persona}\n" + "" ), ), MessagesPlaceholder(variable_name="task"), MessagesPlaceholder(variable_name="messages"), - ("human", "?"), + ( + "human", + "\n" + "Remember to stay true to your persona and role:\n{persona}\n" + "BEGIN!", + ), ] ) @@ -216,12 +220,17 @@ class SequentialWorkerNode(WorkerNode): "If you are unable to perform the task, that's OK, another member with different tools " "will help where you left off. Do not attempt to communicate with other members. " "Execute what you can to make progress. " - "Stay true to your persona and role:\n" - "{persona}" - "\nBEGIN!\n" + "Stay true to your persona and role:\n{persona}\n" + "" ), ), MessagesPlaceholder(variable_name="messages"), + ( + "human", + "\n" + "Remember to stay true to your persona and role:\n{persona}\n" + "BEGIN!", + ), ] ) @@ -240,9 +249,7 @@ async def work(self, state: TeamState) -> ReturnTeamState: name = state["next"] member = team.members[name] assert isinstance(member, GraphMember), "member is unexpectedly not a Member" - team_members_name = self.get_team_members_name(team.members) prompt = self.worker_prompt.partial( - team_members_name=team_members_name, persona=member.persona, ) # If member has no tools, then use a regular model instead of an agent @@ -289,7 +296,10 @@ class LeaderNode(BaseNode): ), MessagesPlaceholder(variable_name="main_task"), MessagesPlaceholder(variable_name="messages"), - ("human", "?"), + ( + "human", + "Who should act next? Or should we FINISH? Select one of: {options}.", + ), ] ) @@ -385,25 +395,19 @@ class SummariserNode(BaseNode): "Here is the team's task:" "\n\n{team_task}\n\n" "These are the responses from your team members:" - "\n\n{team_responses}\n" - "Your role is to interpret all the responses and give the final answer to the team's task.\n" ), ), - ("human", "?"), + MessagesPlaceholder(variable_name="messages"), + ( + "human", + "Your role is to interpret all the responses and give the final answer to the team's task.\n", + ), ] ) - def get_team_responses(self, messages: list[AnyMessage]) -> str: - """Create a string containing the team's responses.""" - result = "" - for message in messages: - result += f"{message.name}: {message.content}\n" - return result - async def summarise(self, state: TeamState) -> dict[str, list[AnyMessage]]: team = state["team"] team_members_name = self.get_team_members_name(team.members) - team_responses = self.get_team_responses(state["messages"]) # TODO: optimise looking for task team_task = state["main_task"][0].content @@ -412,7 +416,6 @@ async def summarise(self, state: TeamState) -> dict[str, list[AnyMessage]]: team_name=team.name, team_members_name=team_members_name, team_task=team_task, - team_responses=team_responses, ) | self.final_answer_model | RunnableLambda(self.tag_with_name).bind(name=f"{team.name}_answer") # type: ignore[arg-type]