Skip to content

Commit

Permalink
Fix no tools logic, initial team leader name and delegate bug. Rename…
Browse files Browse the repository at this point in the history
… summariser to FinalAnswer.
  • Loading branch information
StreetLamb committed Apr 28, 2024
1 parent 687dea4 commit 3078e1f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 54 deletions.
2 changes: 1 addition & 1 deletion backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def create_team(
# Create team leader
member = Member(
**{
"name": "Team Leader",
"name": "TeamLeader",
"type": "root",
"role": "Gather inputs from your team and answer the question.",
"owner_of": team.id,
Expand Down
6 changes: 3 additions & 3 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def create_graph(teams: dict[str, Team], leader_name: str):
),
)
build.add_node(
"summariser",
"FinalAnswer",
RunnableLambda(
SummariserNode(
teams[leader_name].provider,
Expand Down Expand Up @@ -210,11 +210,11 @@ def create_graph(teams: dict[str, Team], leader_name: str):
build.add_edge(name, leader_name)

conditional_mapping = {v: v for v in members}
conditional_mapping["FINISH"] = "summariser"
conditional_mapping["FINISH"] = "FinalAnswer"
build.add_conditional_edges(leader_name, router, conditional_mapping)

build.set_entry_point(leader_name)
build.set_finish_point("summariser")
build.set_finish_point("FinalAnswer")
graph = build.compile()
return graph

Expand Down
107 changes: 65 additions & 42 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
AgentExecutor,
create_tool_calling_agent,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
Expand Down Expand Up @@ -80,24 +80,32 @@ class TeamState(TypedDict):
class BaseNode:
def __init__(self, provider: str, model: str, temperature: float):
self.model = all_models[provider](model=model, temperature=temperature)
self.final_answer_model = all_models[provider](model=model, temperature=0)

def tag_with_name(self, ai_message: AIMessage, name: str):
"""Tag a name to the AI message"""
ai_message.name = name
return ai_message

def get_team_members_name(self, team_members: dict[str, Person]):
"""Get the names of all team members as a string"""
return ",".join(list(team_members))


class WorkerNode(BaseNode):
worker_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a team member of {team_name} and you are one of the following team members: {team_members}.
Your team members (and other teams) will collaborate with you with their own set of skills.
You are chosen by one of your team member to perform this task. Try your best to perform it using your skills.
Stay true to your perspective:
{persona}""",
(
"You are a team member of {team_name} and you are one of the following team members: {team_members_name}.\n"
"Your team members (and other teams) will collaborate with you with their own set of skills. "
"You are chosen by one of your team member to perform this task. Try your best to perform it using your skills. "
"Stay true to your perspective:\n"
"{persona}"
),
),
MessagesPlaceholder(variable_name="messages"),
MessagesPlaceholder(variable_name="task"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
Expand All @@ -109,30 +117,30 @@ def convert_output_to_ai_message(self, state: TeamState):
return AIMessage(content=output)

def create_agent(
self, llm: BaseLanguageModel, prompt: ChatPromptTemplate, tools: list[str]
self, llm: BaseChatModel, prompt: ChatPromptTemplate, tools: list[str]
):
"""Create the agent executor"""
"""Create the agent executor. Tools must non-empty."""
tools = [all_skills[tool].tool for tool in tools]
# Tools cannot be empty, add a placeholder
if len(tools) < 1:
tools = [all_skills["nothing"].tool]

agent = create_tool_calling_agent(llm, tools, prompt)
# agent = create_openai_functions_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
return executor

async def work(self, state: TeamState):
name = state["next"]
member = state["team_members"][name]
tools = member.tools
prompt = self.worker_prompt.partial(persona=member.persona)
agent = self.create_agent(self.model, prompt, tools)
work_chain = (
agent
| RunnableLambda(self.convert_output_to_ai_message)
| RunnableLambda(self.tag_with_name).bind(name=member.name)
team_members_name = self.get_team_members_name(state["team_members"])
prompt = self.worker_prompt.partial(
team_members_name=team_members_name,
persona=member.persona,
)
# If member has no tools, then use a regular model instead of an agent
if len(tools) >= 1:
agent = self.create_agent(self.model, prompt, tools)
chain = agent | RunnableLambda(self.convert_output_to_ai_message)
else:
chain = prompt.partial(agent_scratchpad=[]) | self.model
work_chain = chain | RunnableLambda(self.tag_with_name).bind(name=member.name)
result = await work_chain.ainvoke(state)
return {"messages": [result]}

Expand All @@ -142,10 +150,12 @@ class LeaderNode(BaseNode):
[
(
"system",
"""You are the team leader of {team_name} and you have the following team members: {team_members}.
Your team is given a task and you have to delegate the work among your team members based on their skills.
Team member info:
{team_members_info}""",
(
"You are the team leader of {team_name} and you have the following team members: {team_members_name}.\n"
"Your team is given a task and you have to delegate the work among your team members based on their skills.\n"
"Team member info:\n\n"
"{team_members_info}"
),
),
MessagesPlaceholder(variable_name="messages"),
(
Expand Down Expand Up @@ -190,7 +200,7 @@ def get_tool_definition(self, options: list[str]):
}

async def delegate(self, state: TeamState):
team_members = ", ".join(state["team_members"])
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_members_info = self.get_team_members_info(state["team_members"])
options = list(state["team_members"]) + ["FINISH"]
Expand All @@ -199,34 +209,47 @@ async def delegate(self, state: TeamState):
delegate_chain = (
self.leader_prompt.partial(
team_name=team_name,
team_members=team_members,
team_members_name=team_members_name,
team_members_info=team_members_info,
options=str(options),
)
| self.model.bind_tools(tools=tools)
| JsonOutputKeyToolsParser(key_name="route", first_tool_only=True)
)
result = await delegate_chain.ainvoke(state)
# Convert task from string to list[HumanMessage] because Worker's MessagesPlaceholder only accepts list of messages.
result["task"] = [
HumanMessage(content=result.get("task", "None"), name=team_name)
]
return result
if not result:
return {
"task": [HumanMessage(content="No further tasks.", name=team_name)],
"next": "FINISH",
}
else:
# Convert task from string to list[HumanMessage] because Worker's MessagesPlaceholder only accepts list of messages.
result["task"] = [
HumanMessage(
content=result.get("task", "No further tasks."), name=team_name
)
]
return result


class SummariserNode(BaseNode):
summariser_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a team member of {team_name} and you have the following team members: {team_members}.
Your team was given a task and your team members have performed their roles and returned their responses to the team leader.
Your role as a Summariser is to summarise the responses by your team members and give the final answer.
Here is the team's task:
{team_task}
These are the responses from your team members:
{team_responses}""",
(
"You are a team member of {team_name} and you have the following team members: {team_members_name}. "
"Your team was given a task and your team members have performed their roles and returned their responses to the team leader.\n\n"
"Here is the team's task:\n"
"'''\n"
"{team_task}\n"
"'''\n\n"
"These are the responses from your team members:\n"
"'''\n"
"{team_responses}\n"
"'''\n\n"
"Your role is to interpret all the responses and give the final answer to the team's task.\n"
),
)
]
)
Expand All @@ -239,20 +262,20 @@ def get_team_responses(self, messages: list[BaseMessage]):
return result

async def summarise(self, state: TeamState):
team_members = ", ".join(state["team_members"])
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_responses = self.get_team_responses(state["messages"])
team_task = state["messages"][0].content

summarise_chain = (
self.summariser_prompt.partial(
team_name=team_name,
team_members=team_members,
team_members_name=team_members_name,
team_task=team_task,
team_responses=team_responses,
)
| self.model
| RunnableLambda(self.tag_with_name).bind(name="summariser")
| self.final_answer_model
| RunnableLambda(self.tag_with_name).bind(name="FinalAnswer")
)
result = await summarise_chain.ainvoke(state)
return {"messages": [result]}
3 changes: 2 additions & 1 deletion backend/app/core/graph/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

# Define a dictionary to store all models
all_models = {
all_models: dict[str, type[BaseChatModel]] = {
"ChatOpenAI": ChatOpenAI,
"ChatAnthropic": ChatAnthropic,
"ChatCohere": ChatCohere,
Expand Down
7 changes: 0 additions & 7 deletions backend/app/core/graph/skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,7 @@ class SkillInfo(BaseModel):
tool: Callable


@tool
def nothing(query: str) -> str:
"""Placeholder Tool. Does nothing"""
return ""


all_skills: dict[str, SkillInfo] = {
"nothing": SkillInfo(description="Does nothing", tool=nothing),
"search": SkillInfo(
description="Searches the web using Duck Duck Go", tool=DuckDuckGoSearchRun()
),
Expand Down

0 comments on commit 3078e1f

Please sign in to comment.