Skip to content

Commit

Permalink
Fix lint errors (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
StreetLamb authored May 1, 2024
1 parent 2b56441 commit da9a32a
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 64 deletions.
8 changes: 5 additions & 3 deletions backend/app/api/routes/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def check_duplicate_names_on_create(
session: SessionDep, team_id: int, member_in: MemberCreate
):
) -> None:
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Expand All @@ -35,7 +35,7 @@ def check_duplicate_names_on_create(

def check_duplicate_names_on_update(
session: SessionDep, team_id: int, member_in: MemberUpdate, id: int
):
) -> None:
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Expand Down Expand Up @@ -133,6 +133,8 @@ def create_member(
"""
if not current_user.is_superuser:
team = session.get(Team, team_id)
if not team:
raise HTTPException(status_code=404, detail="Team not found.")
if team.owner_id != current_user.id:
raise HTTPException(status_code=400, detail="Not enough permissions")
member = Member.model_validate(member_in, update={"belongs_to": team_id})
Expand Down Expand Up @@ -181,7 +183,7 @@ def update_member(
if member_in.skills is not None:
skill_ids = [skill.id for skill in member_in.skills]
skills = session.exec(select(Skill).where(col(Skill.id).in_(skill_ids))).all()
member.skills = skills
member.skills = list(skills)

update_dict = member_in.model_dump(exclude_unset=True)
member.sqlmodel_update(update_dict)
Expand Down
8 changes: 5 additions & 3 deletions backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
router = APIRouter()


async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreate):
async def check_duplicate_name_on_create(
session: SessionDep, team_in: TeamCreate
) -> None:
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name)
team = session.exec(statement).first()
Expand All @@ -72,7 +74,7 @@ async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreat

async def check_duplicate_name_on_update(
session: SessionDep, team_in: TeamUpdate, id: int
):
) -> None:
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name, Team.id != id)
team = session.exec(statement).first()
Expand Down Expand Up @@ -199,7 +201,7 @@ def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any:
@router.post("/{id}/stream")
async def stream(
session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat
):
) -> StreamingResponse:
"""
Stream a response to a user's input.
"""
Expand Down
41 changes: 23 additions & 18 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph

from app.core.graph.members import (
Leader,
Expand Down Expand Up @@ -41,19 +44,20 @@ def convert_team_to_dict(
"""
teams: dict[str, Team] = {}

in_counts = defaultdict(int)
out_counts = defaultdict(list[int])
in_counts: defaultdict[int, int] = defaultdict(int)
out_counts: defaultdict[int, list[int]] = defaultdict(list[int])
members_lookup: dict[int, MemberModel] = {}

for member in members:
assert member.id is not None, "member.id is unexpectedly None"
if member.source:
in_counts[member.id] += 1
out_counts[member.source].append(member.id)
else:
in_counts[member.id] = 0
members_lookup[member.id] = member

queue = deque()
queue: deque[int] = deque()

for member_id in in_counts:
if in_counts[member_id] == 0:
Expand All @@ -73,12 +77,11 @@ def convert_team_to_dict(
temperature=member.temperature,
)
# If member is not root team leader, add as a member
if member.type != "root":
if member.type != "root" and member.source:
member_name = member.name
leader = members_lookup[member.source]
leader_name = leader.name
teams[leader_name].members[member_name] = Member(
type=member.type,
name=member_name,
backstory=member.backstory or "",
role=member.role,
Expand All @@ -96,7 +99,7 @@ def convert_team_to_dict(
return teams


def format_teams(teams: dict[str, any]) -> dict[str, Team]:
def format_teams(teams: dict[str, dict[str, Any]]) -> dict[str, Team]:
"""
FOR TESTING PURPOSES ONLY!
Expand All @@ -113,7 +116,7 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]:
for team_name, team in teams.items():
if not isinstance(team, dict):
raise ValueError(f"Invalid team {team_name}. Teams must be dictionaries.")
members = team.get("members", {})
members: dict[str, dict[str, Any]] = team.get("members", {})
for k, v in members.items():
if v["type"] == "leader":
teams[team_name]["members"][k] = Leader(**v)
Expand All @@ -122,11 +125,13 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]:
return {team_name: Team(**team) for team_name, team in teams.items()}


def router(state: TeamState):
def router(state: TeamState) -> str:
return state["next"]


def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]):
def enter_chain(
state: TeamState, team: dict[str, str | list[Member | Leader]]
) -> dict[str, Any]:
"""
Initialise the sub-graph state.
This makes it so that the states of each graph don't get intermixed.
Expand All @@ -143,15 +148,15 @@ def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]):
return results


def exit_chain(state: TeamState):
def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:
"""
Pass the final response back to the top-level graph's state.
"""
answer = state["messages"][-1]
return {"messages": [answer]}


def create_graph(teams: dict[str, Team], leader_name: str):
def create_graph(teams: dict[str, Team], leader_name: str) -> CompiledGraph:
"""Create the team's graph.
This function creates a graph representation of the given teams. The graph is represented as a dictionary where each key is a team name,
Expand Down Expand Up @@ -220,23 +225,23 @@ def create_graph(teams: dict[str, Team], leader_name: str):


async def generator(
team: TeamModel, members: list[Member], messages: list[ChatMessage]
):
team: TeamModel, members: list[MemberModel], messages: list[ChatMessage]
) -> AsyncGenerator[Any, Any]:
"""Create the graph and stream responses as JSON."""
teams = convert_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
root = create_graph(teams, leader_name=team_leader)
messages = [
HumanMessage(message.content)
formatted_messages = [
HumanMessage(content=message.content)
if message.type == "human"
else AIMessage(message.content)
else AIMessage(content=message.content)
for message in messages
]

# TODO: Figure out how to use async_stream to stream responses from subgraphs
async for output in root.astream(
{
"messages": messages,
"messages": formatted_messages,
"team_name": teams[team_leader].name,
"team_members": teams[team_leader].members,
}
Expand Down
61 changes: 34 additions & 27 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import operator
from typing import Annotated, TypedDict
from collections.abc import Sequence
from typing import Annotated, Any, TypedDict

from langchain.agents import (
AgentExecutor,
Expand All @@ -9,7 +10,8 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableLambda, RunnableSerializable
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from app.core.graph.models import all_models
Expand Down Expand Up @@ -50,7 +52,7 @@ class Team(BaseModel):
)


def update_name(name: str, new_name: str):
def update_name(name: str, new_name: str) -> str:
"""Update name at the onset."""
if not name:
return new_name
Expand All @@ -59,7 +61,7 @@ def update_name(name: str, new_name: str):

def update_members(
members: dict[str, Member | Leader] | None, new_members: dict[str, Member | Leader]
):
) -> dict[str, Member | Leader]:
"""Update members at the onset"""
if not members:
members = {}
Expand All @@ -79,15 +81,15 @@ class TeamState(TypedDict):

class BaseNode:
def __init__(self, provider: str, model: str, temperature: float):
self.model = all_models[provider](model=model, temperature=temperature)
self.final_answer_model = all_models[provider](model=model, temperature=0)
self.model = all_models[provider](model=model, temperature=temperature) # type: ignore[call-arg]
self.final_answer_model = all_models[provider](model=model, temperature=0) # type: ignore[call-arg]

def tag_with_name(self, ai_message: AIMessage, name: str):
def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage:
"""Tag a name to the AI message"""
ai_message.name = name
return ai_message

def get_team_members_name(self, team_members: dict[str, Person]):
def get_team_members_name(self, team_members: dict[str, Member | Leader]) -> str:
"""Get the names of all team members as a string"""
return ",".join(list(team_members))

Expand All @@ -111,23 +113,24 @@ class WorkerNode(BaseNode):
]
)

def convert_output_to_ai_message(self, state: TeamState):
def convert_output_to_ai_message(self, agent_output: dict[str, str]) -> AIMessage:
"""Convert agent executor output to ai message"""
output = state["output"]
output = agent_output["output"]
return AIMessage(content=output)

def create_agent(
self, llm: BaseChatModel, prompt: ChatPromptTemplate, tools: list[str]
):
) -> AgentExecutor:
"""Create the agent executor. Tools must non-empty."""
tools = [all_skills[tool].tool for tool in tools]
agent = create_tool_calling_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
formatted_tools: Sequence[BaseTool] = [all_skills[tool].tool for tool in tools]
agent = create_tool_calling_agent(llm, formatted_tools, prompt)
executor = AgentExecutor(agent=agent, tools=formatted_tools) # type: ignore[arg-type]
return executor

async def work(self, state: TeamState):
async def work(self, state: TeamState) -> dict[str, list[BaseMessage]]:
name = state["next"]
member = state["team_members"][name]
assert isinstance(member, Member), "member is unexpectedly not a Member"
tools = member.tools
team_members_name = self.get_team_members_name(state["team_members"])
prompt = self.worker_prompt.partial(
Expand All @@ -139,9 +142,13 @@ async def work(self, state: TeamState):
agent = self.create_agent(self.model, prompt, tools)
chain = agent | RunnableLambda(self.convert_output_to_ai_message)
else:
chain = prompt.partial(agent_scratchpad=[]) | self.model
work_chain = chain | RunnableLambda(self.tag_with_name).bind(name=member.name)
result = await work_chain.ainvoke(state)
chain: RunnableSerializable[dict[str, Any], BaseMessage] = ( # type: ignore[no-redef]
prompt.partial(agent_scratchpad=[]) | self.model
)
work_chain: RunnableSerializable[dict[str, Any], Any] = chain | RunnableLambda(
self.tag_with_name # type: ignore[arg-type]
).bind(name=member.name)
result = await work_chain.ainvoke(state) # type: ignore[arg-type]
return {"messages": [result]}


Expand All @@ -165,14 +172,14 @@ class LeaderNode(BaseNode):
]
)

def get_team_members_info(self, team_members: list[Member]):
def get_team_members_info(self, team_members: dict[str, Member | Leader]) -> str:
"""Create a string containing team members name and role."""
result = ""
for member in team_members.values():
result += f"name: {member.name}\nrole: {member.role}\n\n"
return result

def get_tool_definition(self, options: list[str]):
def get_tool_definition(self, options: list[str]) -> dict[str, Any]:
"""Return the tool definition to choose next team member and provide the task."""
return {
"type": "function",
Expand All @@ -199,14 +206,14 @@ def get_tool_definition(self, options: list[str]):
},
}

async def delegate(self, state: TeamState):
async def delegate(self, state: TeamState) -> dict[str, Any]:
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_members_info = self.get_team_members_info(state["team_members"])
options = list(state["team_members"]) + ["FINISH"]
tools = [self.get_tool_definition(options)]

delegate_chain = (
delegate_chain: RunnableSerializable[Any, Any] = (
self.leader_prompt.partial(
team_name=team_name,
team_members_name=team_members_name,
Expand All @@ -216,7 +223,7 @@ async def delegate(self, state: TeamState):
| self.model.bind_tools(tools=tools)
| JsonOutputKeyToolsParser(key_name="route", first_tool_only=True)
)
result = await delegate_chain.ainvoke(state)
result: dict[str, Any] = await delegate_chain.ainvoke(state)
if not result:
return {
"task": [HumanMessage(content="No further tasks.", name=team_name)],
Expand Down Expand Up @@ -254,28 +261,28 @@ class SummariserNode(BaseNode):
]
)

def get_team_responses(self, messages: list[BaseMessage]):
def get_team_responses(self, messages: list[BaseMessage]) -> str:
"""Create a string containing the team's responses."""
result = ""
for message in messages:
result += f"{message.name}: {message.content}\n"
return result

async def summarise(self, state: TeamState):
async def summarise(self, state: TeamState) -> dict[str, list[BaseMessage]]:
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_responses = self.get_team_responses(state["messages"])
team_task = state["messages"][0].content

summarise_chain = (
summarise_chain: RunnableSerializable[Any, Any] = (
self.summariser_prompt.partial(
team_name=team_name,
team_members_name=team_members_name,
team_task=team_task,
team_responses=team_responses,
)
| self.final_answer_model
| RunnableLambda(self.tag_with_name).bind(name="FinalAnswer")
| RunnableLambda(self.tag_with_name).bind(name="FinalAnswer") # type: ignore[arg-type]
)
result = await summarise_chain.ainvoke(state)
return {"messages": [result]}
Loading

0 comments on commit da9a32a

Please sign in to comment.