From da9a32abb45a6f13e13c6525e86c26eb47826cf0 Mon Sep 17 00:00:00 2001 From: Jerron Lim Date: Thu, 2 May 2024 00:10:55 +0800 Subject: [PATCH] Fix lint errors (#9) --- backend/app/api/routes/members.py | 8 ++-- backend/app/api/routes/teams.py | 8 ++-- backend/app/core/graph/build.py | 41 ++++++++++++--------- backend/app/core/graph/members.py | 61 +++++++++++++++++-------------- backend/app/core/graph/skills.py | 8 ++-- backend/app/models.py | 18 ++++----- 6 files changed, 80 insertions(+), 64 deletions(-) diff --git a/backend/app/api/routes/members.py b/backend/app/api/routes/members.py index 09436597..5dfc4f47 100644 --- a/backend/app/api/routes/members.py +++ b/backend/app/api/routes/members.py @@ -20,7 +20,7 @@ def check_duplicate_names_on_create( session: SessionDep, team_id: int, member_in: MemberCreate -): +) -> None: """Check if (name, team_id) is unique""" statement = select(Member).where( Member.name == member_in.name, @@ -35,7 +35,7 @@ def check_duplicate_names_on_create( def check_duplicate_names_on_update( session: SessionDep, team_id: int, member_in: MemberUpdate, id: int -): +) -> None: """Check if (name, team_id) is unique""" statement = select(Member).where( Member.name == member_in.name, @@ -133,6 +133,8 @@ def create_member( """ if not current_user.is_superuser: team = session.get(Team, team_id) + if not team: + raise HTTPException(status_code=404, detail="Team not found.") if team.owner_id != current_user.id: raise HTTPException(status_code=400, detail="Not enough permissions") member = Member.model_validate(member_in, update={"belongs_to": team_id}) @@ -181,7 +183,7 @@ def update_member( if member_in.skills is not None: skill_ids = [skill.id for skill in member_in.skills] skills = session.exec(select(Skill).where(col(Skill.id).in_(skill_ids))).all() - member.skills = skills + member.skills = list(skills) update_dict = member_in.model_dump(exclude_unset=True) member.sqlmodel_update(update_dict) diff --git a/backend/app/api/routes/teams.py b/backend/app/api/routes/teams.py index 5751e5d7..af18bd05 100644 --- a/backend/app/api/routes/teams.py +++ b/backend/app/api/routes/teams.py @@ -62,7 +62,9 @@ router = APIRouter() -async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreate): +async def check_duplicate_name_on_create( + session: SessionDep, team_in: TeamCreate +) -> None: """Validate that team name is unique""" statement = select(Team).where(Team.name == team_in.name) team = session.exec(statement).first() @@ -72,7 +74,7 @@ async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreat async def check_duplicate_name_on_update( session: SessionDep, team_in: TeamUpdate, id: int -): +) -> None: """Validate that team name is unique""" statement = select(Team).where(Team.name == team_in.name, Team.id != id) team = session.exec(statement).first() @@ -199,7 +201,7 @@ def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any: @router.post("/{id}/stream") async def stream( session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat -): +) -> StreamingResponse: """ Stream a response to a user's input. 
""" diff --git a/backend/app/core/graph/build.py b/backend/app/core/graph/build.py index 97669271..e07a9eae 100644 --- a/backend/app/core/graph/build.py +++ b/backend/app/core/graph/build.py @@ -1,10 +1,13 @@ import json from collections import defaultdict, deque +from collections.abc import AsyncGenerator from functools import partial +from typing import Any -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.runnables import RunnableLambda from langgraph.graph import StateGraph +from langgraph.graph.graph import CompiledGraph from app.core.graph.members import ( Leader, @@ -41,11 +44,12 @@ def convert_team_to_dict( """ teams: dict[str, Team] = {} - in_counts = defaultdict(int) - out_counts = defaultdict(list[int]) + in_counts: defaultdict[int, int] = defaultdict(int) + out_counts: defaultdict[int, list[int]] = defaultdict(list[int]) members_lookup: dict[int, MemberModel] = {} for member in members: + assert member.id is not None, "member.id is unexpectedly None" if member.source: in_counts[member.id] += 1 out_counts[member.source].append(member.id) @@ -53,7 +57,7 @@ def convert_team_to_dict( in_counts[member.id] = 0 members_lookup[member.id] = member - queue = deque() + queue: deque[int] = deque() for member_id in in_counts: if in_counts[member_id] == 0: @@ -73,12 +77,11 @@ def convert_team_to_dict( temperature=member.temperature, ) # If member is not root team leader, add as a member - if member.type != "root": + if member.type != "root" and member.source: member_name = member.name leader = members_lookup[member.source] leader_name = leader.name teams[leader_name].members[member_name] = Member( - type=member.type, name=member_name, backstory=member.backstory or "", role=member.role, @@ -96,7 +99,7 @@ def convert_team_to_dict( return teams -def format_teams(teams: dict[str, any]) -> dict[str, Team]: +def format_teams(teams: dict[str, dict[str, Any]]) -> dict[str, Team]: """ FOR TESTING PURPOSES ONLY! @@ -113,7 +116,7 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]: for team_name, team in teams.items(): if not isinstance(team, dict): raise ValueError(f"Invalid team {team_name}. Teams must be dictionaries.") - members = team.get("members", {}) + members: dict[str, dict[str, Any]] = team.get("members", {}) for k, v in members.items(): if v["type"] == "leader": teams[team_name]["members"][k] = Leader(**v) @@ -122,11 +125,13 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]: return {team_name: Team(**team) for team_name, team in teams.items()} -def router(state: TeamState): +def router(state: TeamState) -> str: return state["next"] -def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]): +def enter_chain( + state: TeamState, team: dict[str, str | list[Member | Leader]] +) -> dict[str, Any]: """ Initialise the sub-graph state. This makes it so that the states of each graph don't get intermixed. @@ -143,7 +148,7 @@ def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]): return results -def exit_chain(state: TeamState): +def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]: """ Pass the final response back to the top-level graph's state. """ @@ -151,7 +156,7 @@ def exit_chain(state: TeamState): return {"messages": [answer]} -def create_graph(teams: dict[str, Team], leader_name: str): +def create_graph(teams: dict[str, Team], leader_name: str) -> CompiledGraph: """Create the team's graph. 
This function creates a graph representation of the given teams. The graph is represented as a dictionary where each key is a team name, @@ -220,23 +225,23 @@ def create_graph(teams: dict[str, Team], leader_name: str): async def generator( - team: TeamModel, members: list[Member], messages: list[ChatMessage] -): + team: TeamModel, members: list[MemberModel], messages: list[ChatMessage] +) -> AsyncGenerator[Any, Any]: """Create the graph and stream responses as JSON.""" teams = convert_team_to_dict(team, members) team_leader = list(teams.keys())[0] root = create_graph(teams, leader_name=team_leader) - messages = [ - HumanMessage(message.content) + formatted_messages = [ + HumanMessage(content=message.content) if message.type == "human" - else AIMessage(message.content) + else AIMessage(content=message.content) for message in messages ] # TODO: Figure out how to use async_stream to stream responses from subgraphs async for output in root.astream( { - "messages": messages, + "messages": formatted_messages, "team_name": teams[team_leader].name, "team_members": teams[team_leader].members, } diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py index 16f38c95..14f9cf25 100644 --- a/backend/app/core/graph/members.py +++ b/backend/app/core/graph/members.py @@ -1,5 +1,6 @@ import operator -from typing import Annotated, TypedDict +from collections.abc import Sequence +from typing import Annotated, Any, TypedDict from langchain.agents import ( AgentExecutor, @@ -9,7 +10,8 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import RunnableLambda +from langchain_core.runnables import RunnableLambda, RunnableSerializable +from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from app.core.graph.models import all_models @@ -50,7 +52,7 @@ class Team(BaseModel): ) -def update_name(name: str, new_name: str): +def update_name(name: str, new_name: str) -> str: """Update name at the onset.""" if not name: return new_name @@ -59,7 +61,7 @@ def update_name(name: str, new_name: str): def update_members( members: dict[str, Member | Leader] | None, new_members: dict[str, Member | Leader] -): +) -> dict[str, Member | Leader]: """Update members at the onset""" if not members: members = {} @@ -79,15 +81,15 @@ class TeamState(TypedDict): class BaseNode: def __init__(self, provider: str, model: str, temperature: float): - self.model = all_models[provider](model=model, temperature=temperature) - self.final_answer_model = all_models[provider](model=model, temperature=0) + self.model = all_models[provider](model=model, temperature=temperature) # type: ignore[call-arg] + self.final_answer_model = all_models[provider](model=model, temperature=0) # type: ignore[call-arg] - def tag_with_name(self, ai_message: AIMessage, name: str): + def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage: """Tag a name to the AI message""" ai_message.name = name return ai_message - def get_team_members_name(self, team_members: dict[str, Person]): + def get_team_members_name(self, team_members: dict[str, Member | Leader]) -> str: """Get the names of all team members as a string""" return ",".join(list(team_members)) @@ -111,23 +113,24 @@ class WorkerNode(BaseNode): ] ) - def convert_output_to_ai_message(self, state: TeamState): + def convert_output_to_ai_message(self, 
agent_output: dict[str, str]) -> AIMessage: """Convert agent executor output to ai message""" - output = state["output"] + output = agent_output["output"] return AIMessage(content=output) def create_agent( self, llm: BaseChatModel, prompt: ChatPromptTemplate, tools: list[str] - ): + ) -> AgentExecutor: """Create the agent executor. Tools must non-empty.""" - tools = [all_skills[tool].tool for tool in tools] - agent = create_tool_calling_agent(llm, tools, prompt) - executor = AgentExecutor(agent=agent, tools=tools) + formatted_tools: Sequence[BaseTool] = [all_skills[tool].tool for tool in tools] + agent = create_tool_calling_agent(llm, formatted_tools, prompt) + executor = AgentExecutor(agent=agent, tools=formatted_tools) # type: ignore[arg-type] return executor - async def work(self, state: TeamState): + async def work(self, state: TeamState) -> dict[str, list[BaseMessage]]: name = state["next"] member = state["team_members"][name] + assert isinstance(member, Member), "member is unexpectedly not a Member" tools = member.tools team_members_name = self.get_team_members_name(state["team_members"]) prompt = self.worker_prompt.partial( @@ -139,9 +142,13 @@ async def work(self, state: TeamState): agent = self.create_agent(self.model, prompt, tools) chain = agent | RunnableLambda(self.convert_output_to_ai_message) else: - chain = prompt.partial(agent_scratchpad=[]) | self.model - work_chain = chain | RunnableLambda(self.tag_with_name).bind(name=member.name) - result = await work_chain.ainvoke(state) + chain: RunnableSerializable[dict[str, Any], BaseMessage] = ( # type: ignore[no-redef] + prompt.partial(agent_scratchpad=[]) | self.model + ) + work_chain: RunnableSerializable[dict[str, Any], Any] = chain | RunnableLambda( + self.tag_with_name # type: ignore[arg-type] + ).bind(name=member.name) + result = await work_chain.ainvoke(state) # type: ignore[arg-type] return {"messages": [result]} @@ -165,14 +172,14 @@ class LeaderNode(BaseNode): ] ) - def get_team_members_info(self, team_members: list[Member]): + def get_team_members_info(self, team_members: dict[str, Member | Leader]) -> str: """Create a string containing team members name and role.""" result = "" for member in team_members.values(): result += f"name: {member.name}\nrole: {member.role}\n\n" return result - def get_tool_definition(self, options: list[str]): + def get_tool_definition(self, options: list[str]) -> dict[str, Any]: """Return the tool definition to choose next team member and provide the task.""" return { "type": "function", @@ -199,14 +206,14 @@ def get_tool_definition(self, options: list[str]): }, } - async def delegate(self, state: TeamState): + async def delegate(self, state: TeamState) -> dict[str, Any]: team_members_name = self.get_team_members_name(state["team_members"]) team_name = state["team_name"] team_members_info = self.get_team_members_info(state["team_members"]) options = list(state["team_members"]) + ["FINISH"] tools = [self.get_tool_definition(options)] - delegate_chain = ( + delegate_chain: RunnableSerializable[Any, Any] = ( self.leader_prompt.partial( team_name=team_name, team_members_name=team_members_name, @@ -216,7 +223,7 @@ async def delegate(self, state: TeamState): | self.model.bind_tools(tools=tools) | JsonOutputKeyToolsParser(key_name="route", first_tool_only=True) ) - result = await delegate_chain.ainvoke(state) + result: dict[str, Any] = await delegate_chain.ainvoke(state) if not result: return { "task": [HumanMessage(content="No further tasks.", name=team_name)], @@ -254,20 +261,20 @@ class 
SummariserNode(BaseNode): ] ) - def get_team_responses(self, messages: list[BaseMessage]): + def get_team_responses(self, messages: list[BaseMessage]) -> str: """Create a string containing the team's responses.""" result = "" for message in messages: result += f"{message.name}: {message.content}\n" return result - async def summarise(self, state: TeamState): + async def summarise(self, state: TeamState) -> dict[str, list[BaseMessage]]: team_members_name = self.get_team_members_name(state["team_members"]) team_name = state["team_name"] team_responses = self.get_team_responses(state["messages"]) team_task = state["messages"][0].content - summarise_chain = ( + summarise_chain: RunnableSerializable[Any, Any] = ( self.summariser_prompt.partial( team_name=team_name, team_members_name=team_members_name, @@ -275,7 +282,7 @@ async def summarise(self, state: TeamState): team_responses=team_responses, ) | self.final_answer_model - | RunnableLambda(self.tag_with_name).bind(name="FinalAnswer") + | RunnableLambda(self.tag_with_name).bind(name="FinalAnswer") # type: ignore[arg-type] ) result = await summarise_chain.ainvoke(state) return {"messages": [result]} diff --git a/backend/app/core/graph/skills.py b/backend/app/core/graph/skills.py index ec56b216..0220e29a 100644 --- a/backend/app/core/graph/skills.py +++ b/backend/app/core/graph/skills.py @@ -27,15 +27,15 @@ class SkillInfo(BaseModel): ), "wikipedia": SkillInfo( description="Searches Wikipedia", - tool=WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), + tool=WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()), # type: ignore[call-arg] ), "google-finance": SkillInfo( description="Get information from Google Finance Page via SerpApi.", - tool=GoogleFinanceQueryRun(api_wrapper=GoogleFinanceAPIWrapper()), + tool=GoogleFinanceQueryRun(api_wrapper=GoogleFinanceAPIWrapper()), # type: ignore[call-arg] ), "google-jobs": SkillInfo( description="Fetch current job postings from Google Jobs via SerpApi.", - tool=GoogleJobsQueryRun(api_wrapper=GoogleJobsAPIWrapper()), + tool=GoogleJobsQueryRun(api_wrapper=GoogleJobsAPIWrapper()), # type: ignore[call-arg] ), "google-scholar": SkillInfo( description="Fetch papers from Google Scholar via SerpApi.", @@ -43,7 +43,7 @@ class SkillInfo(BaseModel): ), "google-trends": SkillInfo( description="Get information from Google Trends Page via SerpApi.", - tool=GoogleTrendsQueryRun(api_wrapper=GoogleTrendsAPIWrapper()), + tool=GoogleTrendsQueryRun(api_wrapper=GoogleTrendsAPIWrapper()), # type: ignore[call-arg] ), "yahoo-finance": SkillInfo( description="Get information from Yahoo Finance News.", diff --git a/backend/app/models.py b/backend/app/models.py index 9b09db3f..f030f580 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -134,7 +134,7 @@ class TeamCreate(TeamBase): class TeamUpdate(TeamBase): - name: str | None = PydanticField(pattern=r"^[a-zA-Z0-9_-]{1,64}$", default=None) + name: str | None = PydanticField(pattern=r"^[a-zA-Z0-9_-]{1,64}$", default=None) # type: ignore[assignment] class ChatMessageType(str, Enum): @@ -201,17 +201,17 @@ class MemberCreate(MemberBase): class MemberUpdate(MemberBase): - name: str | None = PydanticField(pattern=r"^[a-zA-Z0-9_-]{1,64}$", default=None) + name: str | None = PydanticField(pattern=r"^[a-zA-Z0-9_-]{1,64}$", default=None) # type: ignore[assignment] backstory: str | None = None - role: str | None = None - type: str | None = None + role: str | None = None # type: ignore[assignment] + type: str | None = None # type: ignore[assignment] belongs_to: int | None = 
None - position_x: float | None = None - position_y: float | None = None + position_x: float | None = None # type: ignore[assignment] + position_y: float | None = None # type: ignore[assignment] skills: list["Skill"] | None = None - provider: str | None = None - model: str | None = None - temperature: float | None = None + provider: str | None = None # type: ignore[assignment] + model: str | None = None # type: ignore[assignment] + temperature: float | None = None # type: ignore[assignment] class Member(MemberBase, table=True):
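
Note on the build.py hunk above (illustration only, not part of the patch): the newly
annotated containers in convert_team_to_dict (in_counts: defaultdict[int, int],
out_counts: defaultdict[int, list[int]], queue: deque[int]) support a breadth-first walk
from root members (those with no source) down the member.source links. The sketch below
shows that traversal pattern in isolation; the Node dataclass and the sample data are
hypothetical stand-ins for MemberModel, not types from this repository.

    from collections import defaultdict, deque
    from dataclasses import dataclass

    @dataclass
    class Node:
        """Hypothetical stand-in for MemberModel: an id plus an optional parent id."""
        id: int
        source: int | None = None

    def walk_hierarchy(nodes: list[Node]) -> list[int]:
        """Visit nodes breadth-first, mirroring the typed containers in convert_team_to_dict."""
        in_counts: defaultdict[int, int] = defaultdict(int)
        out_counts: defaultdict[int, list[int]] = defaultdict(list)
        for node in nodes:
            if node.source is not None:
                in_counts[node.id] += 1          # child gains one incoming edge
                out_counts[node.source].append(node.id)
            else:
                in_counts[node.id] = 0           # roots start with zero in-degree
        queue: deque[int] = deque(
            node_id for node_id, count in in_counts.items() if count == 0
        )
        order: list[int] = []
        while queue:
            current = queue.popleft()
            order.append(current)
            queue.extend(out_counts[current])    # enqueue this node's children
        return order

    if __name__ == "__main__":
        # Root leader (1) with members 2 and 3; member 3 leads member 4.
        print(walk_hierarchy([Node(1), Node(2, 1), Node(3, 1), Node(4, 3)]))  # [1, 2, 3, 4]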