Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 29 additions & 57 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import json
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from functools import partial
from typing import Any

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
Expand Down Expand Up @@ -79,6 +79,8 @@ def convert_hierarchical_team_to_dict(
teams[leader_name] = GraphTeam(
name=leader_name,
model=member.model,
role=member.role,
backstory=member.backstory or "",
members={},
provider=member.provider,
temperature=member.temperature,
Expand All @@ -102,6 +104,7 @@ def convert_hierarchical_team_to_dict(
elif member.type == "leader":
teams[leader_name].members[member_name] = GraphLeader(
name=member_name,
backstory=member.backstory or "",
role=member.role,
provider=member.provider,
model=member.model,
Expand All @@ -115,7 +118,7 @@ def convert_hierarchical_team_to_dict(
return teams


def convert_sequential_team_to_dict(team: Team) -> dict[str, GraphMember]:
def convert_sequential_team_to_dict(team: Team) -> Mapping[str, GraphMember]:
team_dict: dict[str, GraphMember] = {}

in_counts: defaultdict[int, int] = defaultdict(int)
Expand Down Expand Up @@ -158,32 +161,6 @@ def convert_sequential_team_to_dict(team: Team) -> dict[str, GraphMember]:
return team_dict


def format_teams(teams: dict[str, dict[str, Any]]) -> dict[str, GraphTeam]:
"""
FOR TESTING PURPOSES ONLY!

This function takes a dictionary of teams and formats their member lists to use instances of the `Member` or `Leader`
classes.

Args:
teams (dict[str, any]): A dictionary where each key is a team name and the value is another dictionary containing
the team's members

Returns:
dict[str, Team]: The input dictionary with its member lists formatted to use instances of `Member` or `Leader`
"""
for team_name, team in teams.items():
if not isinstance(team, dict):
raise ValueError(f"Invalid team {team_name}. Teams must be dictionaries.")
members: dict[str, dict[str, Any]] = team.get("members", {})
for k, v in members.items():
if v["type"] == "leader":
teams[team_name]["members"][k] = GraphLeader(**v)
else:
teams[team_name]["members"][k] = GraphMember(**v)
return {team_name: GraphTeam(**team) for team_name, team in teams.items()}


def router(state: TeamState) -> str:
return state["next"]

Expand All @@ -195,25 +172,14 @@ def enter_chain(state: TeamState, team: GraphTeam) -> dict[str, Any]:
"""
task = state["task"]
results = {
"messages": task,
"team_name": team.name,
"main_task": task,
"team": team,
"team_members": team.members,
}
return results


def format_messages(state: TeamState) -> TeamState:
"""Add a human message to prevent consecutive AI messages"""
if len(state.get("messages", [])) > 0 and isinstance(
state["messages"][-1], AIMessage
):
state["messages"] = state.get("messages", []) + [
HumanMessage(content="what should you do next?", name="ignore")
]
return state


def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:
def exit_chain(state: TeamState) -> dict[str, list[AnyMessage]]:
"""
Pass the final response back to the top-level graph's state.
"""
Expand All @@ -224,7 +190,7 @@ def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:

def should_continue(state: TeamState) -> str:
"""Determine if graph should go to tool node or not. For tool calling agents."""
messages: list[BaseMessage] = state["messages"]
messages: list[AnyMessage] = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
Expand Down Expand Up @@ -276,8 +242,7 @@ def create_hierarchical_graph(
# Add the start and end node
build.add_node(
leader_name,
format_messages
| RunnableLambda(
RunnableLambda(
LeaderNode(
teams[leader_name].provider,
teams[leader_name].model,
Expand All @@ -287,8 +252,7 @@ def create_hierarchical_graph(
)
build.add_node(
"FinalAnswer",
format_messages
| RunnableLambda(
RunnableLambda(
SummariserNode(
teams[leader_name].provider,
teams[leader_name].model,
Expand All @@ -302,8 +266,7 @@ def create_hierarchical_graph(
if isinstance(member, GraphMember):
build.add_node(
name,
format_messages
| RunnableLambda(
RunnableLambda(
WorkerNode(
member.provider,
member.model,
Expand Down Expand Up @@ -353,7 +316,7 @@ def create_hierarchical_graph(


def create_sequential_graph(
team: dict[str, GraphMember], memory: BaseCheckpointSaver
team: Mapping[str, GraphMember], memory: BaseCheckpointSaver
) -> CompiledGraph:
"""
Creates a sequential graph from a list of team members.
Expand Down Expand Up @@ -456,23 +419,31 @@ async def generator(
)
state = {
"messages": formatted_messages,
"team_name": teams[team_leader].name,
"team_members": teams[team_leader].members,
"team": teams[team_leader],
"main_task": formatted_messages,
}
else:
member_dict = convert_sequential_team_to_dict(team)
root = create_sequential_graph(member_dict, memory)
first_member = list(member_dict.values())[0]
state = {
"messages": formatted_messages,
"team_name": team.name,
"team_members": member_dict,
"next": list(member_dict.values())[0].name,
"team": GraphTeam(
name=first_member.name,
role=first_member.role,
backstory=first_member.backstory,
members=member_dict, # type: ignore[arg-type]
provider=first_member.provider,
model=first_member.model,
temperature=first_member.temperature,
),
"next": first_member.name,
}
async for output in root.astream_events(
state,
version="v1",
include_names=["work", "delegate", "summarise"],
config={"configurable": {"thread_id": thread_id}},
config={"configurable": {"thread_id": thread_id}, "recursion_limit": 25},
):
if output["event"] == "on_chain_end":
output_data = output["data"]["output"]
Expand All @@ -489,3 +460,4 @@ async def generator(
}
yield f"data: {json.dumps(error_message)}\n\n"
await asyncio.sleep(0.1) # Add a small delay to ensure the message is sent
raise e
9 changes: 7 additions & 2 deletions backend/app/core/graph/checkpoint/aiopostgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,20 @@ class AsyncPostgresSaver(BaseCheckpointSaver, AbstractAsyncContextManager): # t
serde = JsonPlusSerializerCompat()

conn: asyncpg.Connection # type: ignore[type-arg]
conn_string: str
lock: asyncio.Lock
is_setup: bool

def __init__(
self,
conn: asyncpg.Connection, # type: ignore[type-arg]
conn_string: str,
*,
serde: SerializerProtocol | None = None,
):
super().__init__(serde=serde)
self.conn = conn
self.conn_string = conn_string
self.lock = asyncio.Lock()
self.is_setup = False

Expand All @@ -112,7 +115,7 @@ async def from_conn_string(cls, conn_string: str) -> "AsyncPostgresSaver":
AsyncPostgresSaver: A new AsyncPostgresSaver instance.
"""
conn = await asyncpg.connect(conn_string)
return AsyncPostgresSaver(conn=conn)
return AsyncPostgresSaver(conn=conn, conn_string=conn_string)

async def __aenter__(self) -> Self:
return self
Expand Down Expand Up @@ -400,7 +403,9 @@ async def aput(
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
await self.setup()
await self.conn.execute(
# Fix cannot 'perform operation: another operation is in progress' issue
conn = await asyncpg.connect(self.conn_string)
await conn.execute(
"INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata",
str(config["configurable"]["thread_id"]),
checkpoint["id"],
Expand Down
Loading