Skip to content

Commit

Permalink
Enhance graph creation and prompting to reduce reliability issues dur…
Browse files Browse the repository at this point in the history
…ing conversations (#47)

* Fix reliability issues in graph creation and chat invocation

* Fix pg connection issue

* Fix linter errors

* Improve dummy human message
  • Loading branch information
StreetLamb authored Jun 8, 2024
1 parent 4c23607 commit 517b456
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 161 deletions.
86 changes: 29 additions & 57 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import json
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from functools import partial
from typing import Any

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
Expand Down Expand Up @@ -79,6 +79,8 @@ def convert_hierarchical_team_to_dict(
teams[leader_name] = GraphTeam(
name=leader_name,
model=member.model,
role=member.role,
backstory=member.backstory or "",
members={},
provider=member.provider,
temperature=member.temperature,
Expand All @@ -102,6 +104,7 @@ def convert_hierarchical_team_to_dict(
elif member.type == "leader":
teams[leader_name].members[member_name] = GraphLeader(
name=member_name,
backstory=member.backstory or "",
role=member.role,
provider=member.provider,
model=member.model,
Expand All @@ -115,7 +118,7 @@ def convert_hierarchical_team_to_dict(
return teams


def convert_sequential_team_to_dict(team: Team) -> dict[str, GraphMember]:
def convert_sequential_team_to_dict(team: Team) -> Mapping[str, GraphMember]:
team_dict: dict[str, GraphMember] = {}

in_counts: defaultdict[int, int] = defaultdict(int)
Expand Down Expand Up @@ -158,32 +161,6 @@ def convert_sequential_team_to_dict(team: Team) -> dict[str, GraphMember]:
return team_dict


def format_teams(teams: dict[str, dict[str, Any]]) -> dict[str, GraphTeam]:
"""
FOR TESTING PURPOSES ONLY!
This function takes a dictionary of teams and formats their member lists to use instances of the `Member` or `Leader`
classes.
Args:
teams (dict[str, any]): A dictionary where each key is a team name and the value is another dictionary containing
the team's members
Returns:
dict[str, Team]: The input dictionary with its member lists formatted to use instances of `Member` or `Leader`
"""
for team_name, team in teams.items():
if not isinstance(team, dict):
raise ValueError(f"Invalid team {team_name}. Teams must be dictionaries.")
members: dict[str, dict[str, Any]] = team.get("members", {})
for k, v in members.items():
if v["type"] == "leader":
teams[team_name]["members"][k] = GraphLeader(**v)
else:
teams[team_name]["members"][k] = GraphMember(**v)
return {team_name: GraphTeam(**team) for team_name, team in teams.items()}


def router(state: TeamState) -> str:
return state["next"]

Expand All @@ -195,25 +172,14 @@ def enter_chain(state: TeamState, team: GraphTeam) -> dict[str, Any]:
"""
task = state["task"]
results = {
"messages": task,
"team_name": team.name,
"main_task": task,
"team": team,
"team_members": team.members,
}
return results


def format_messages(state: TeamState) -> TeamState:
"""Add a human message to prevent consecutive AI messages"""
if len(state.get("messages", [])) > 0 and isinstance(
state["messages"][-1], AIMessage
):
state["messages"] = state.get("messages", []) + [
HumanMessage(content="what should you do next?", name="ignore")
]
return state


def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:
def exit_chain(state: TeamState) -> dict[str, list[AnyMessage]]:
"""
Pass the final response back to the top-level graph's state.
"""
Expand All @@ -224,7 +190,7 @@ def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:

def should_continue(state: TeamState) -> str:
"""Determine if graph should go to tool node or not. For tool calling agents."""
messages: list[BaseMessage] = state["messages"]
messages: list[AnyMessage] = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
Expand Down Expand Up @@ -276,8 +242,7 @@ def create_hierarchical_graph(
# Add the start and end node
build.add_node(
leader_name,
format_messages
| RunnableLambda(
RunnableLambda(
LeaderNode(
teams[leader_name].provider,
teams[leader_name].model,
Expand All @@ -287,8 +252,7 @@ def create_hierarchical_graph(
)
build.add_node(
"FinalAnswer",
format_messages
| RunnableLambda(
RunnableLambda(
SummariserNode(
teams[leader_name].provider,
teams[leader_name].model,
Expand All @@ -302,8 +266,7 @@ def create_hierarchical_graph(
if isinstance(member, GraphMember):
build.add_node(
name,
format_messages
| RunnableLambda(
RunnableLambda(
WorkerNode(
member.provider,
member.model,
Expand Down Expand Up @@ -353,7 +316,7 @@ def create_hierarchical_graph(


def create_sequential_graph(
team: dict[str, GraphMember], memory: BaseCheckpointSaver
team: Mapping[str, GraphMember], memory: BaseCheckpointSaver
) -> CompiledGraph:
"""
Creates a sequential graph from a list of team members.
Expand Down Expand Up @@ -456,23 +419,31 @@ async def generator(
)
state = {
"messages": formatted_messages,
"team_name": teams[team_leader].name,
"team_members": teams[team_leader].members,
"team": teams[team_leader],
"main_task": formatted_messages,
}
else:
member_dict = convert_sequential_team_to_dict(team)
root = create_sequential_graph(member_dict, memory)
first_member = list(member_dict.values())[0]
state = {
"messages": formatted_messages,
"team_name": team.name,
"team_members": member_dict,
"next": list(member_dict.values())[0].name,
"team": GraphTeam(
name=first_member.name,
role=first_member.role,
backstory=first_member.backstory,
members=member_dict, # type: ignore[arg-type]
provider=first_member.provider,
model=first_member.model,
temperature=first_member.temperature,
),
"next": first_member.name,
}
async for output in root.astream_events(
state,
version="v1",
include_names=["work", "delegate", "summarise"],
config={"configurable": {"thread_id": thread_id}},
config={"configurable": {"thread_id": thread_id}, "recursion_limit": 25},
):
if output["event"] == "on_chain_end":
output_data = output["data"]["output"]
Expand All @@ -489,3 +460,4 @@ async def generator(
}
yield f"data: {json.dumps(error_message)}\n\n"
await asyncio.sleep(0.1) # Add a small delay to ensure the message is sent
raise e
9 changes: 7 additions & 2 deletions backend/app/core/graph/checkpoint/aiopostgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,20 @@ class AsyncPostgresSaver(BaseCheckpointSaver, AbstractAsyncContextManager): # t
serde = JsonPlusSerializerCompat()

conn: asyncpg.Connection # type: ignore[type-arg]
conn_string: str
lock: asyncio.Lock
is_setup: bool

def __init__(
self,
conn: asyncpg.Connection, # type: ignore[type-arg]
conn_string: str,
*,
serde: SerializerProtocol | None = None,
):
super().__init__(serde=serde)
self.conn = conn
self.conn_string = conn_string
self.lock = asyncio.Lock()
self.is_setup = False

Expand All @@ -112,7 +115,7 @@ async def from_conn_string(cls, conn_string: str) -> "AsyncPostgresSaver":
AsyncPostgresSaver: A new AsyncPostgresSaver instance.
"""
conn = await asyncpg.connect(conn_string)
return AsyncPostgresSaver(conn=conn)
return AsyncPostgresSaver(conn=conn, conn_string=conn_string)

async def __aenter__(self) -> Self:
return self
Expand Down Expand Up @@ -400,7 +403,9 @@ async def aput(
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
await self.setup()
await self.conn.execute(
# Fix cannot 'perform operation: another operation is in progress' issue
conn = await asyncpg.connect(self.conn_string)
await conn.execute(
"INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata",
str(config["configurable"]["thread_id"]),
checkpoint["id"],
Expand Down
Loading

0 comments on commit 517b456

Please sign in to comment.