|
| 1 | +import functools |
| 2 | +import operator |
| 3 | +from datetime import datetime |
| 4 | +from typing import TypedDict, Annotated, Sequence |
| 5 | +from langchain.agents import create_openai_tools_agent, AgentExecutor |
| 6 | +from langchain_core.messages import HumanMessage, BaseMessage |
| 7 | +from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser |
| 8 | +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| 9 | +from langchain_core.runnables import Runnable, RunnableConfig |
| 10 | +from langchain_core.tools import Tool |
| 11 | +from langchain_openai import ChatOpenAI |
| 12 | +from langgraph.constants import START, END |
| 13 | +from langgraph.graph import StateGraph |
| 14 | + |
| 15 | +from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory |
| 16 | +from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto |
| 17 | + |
| 18 | + |
# Tool that reports the current datetime as an ISO-8601 string.
# Agent tools always receive one input argument; this one ignores it.
datetime_tool = Tool(
    name="Datetime",
    description="Returns the current datetime",
    func=lambda _query: datetime.now().isoformat(),
)

# Stand-in for a real web-search tool: every query gets the same answer.
mock_search_tool = Tool(
    name="Search",
    description="Search the web about something",
    func=lambda _query: "light",
)
| 31 | + |
| 32 | + |
def create_agent(llm: ChatOpenAI, system_prompt: str, tools: list):
    """Build an AgentExecutor for one worker agent.

    The worker gets its own system prompt, the running conversation
    (``messages``), and the scratchpad slot that OpenAI tool-calling
    agents require (``agent_scratchpad``).
    """
    worker_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    worker = create_openai_tools_agent(llm, tools, worker_prompt)
    return AgentExecutor(agent=worker, tools=tools)
| 48 | + |
| 49 | + |
def agent_node(state, agent, name):
    """Run *agent* on the graph state and return its output as a named message.

    The result is appended to ``state["messages"]`` by the graph (see
    AgentState's ``operator.add`` annotation).
    """
    output_text = agent.invoke(state)["output"]
    return {"messages": [HumanMessage(content=output_text, name=name)]}
| 53 | + |
| 54 | + |
# Worker agents the supervisor can delegate to.
members = ["Researcher", "CurrentTime"]

# Routing instructions for the supervisor LLM; {members} is filled in
# when the prompt is built below.
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)

# The supervisor is itself an LLM node: on each turn it either picks the
# next worker or declares the job done with FINISH.
options = ["FINISH", *members]

# OpenAI function schema so the routing decision comes back as
# structured JSON instead of free text.
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [{"enum": options}],
            },
        },
        "required": ["next"],
    },
}
| 84 | + |
# Question appended after the conversation so the supervisor must route.
_routing_question = (
    "Given the conversation above, who should act next?"
    " Or should we FINISH? Select one of: {options}"
)

# Supervisor prompt: system instructions, the running conversation, then
# the routing question. {options} and {members} are pre-filled here.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        ("system", _routing_question),
    ]
).partial(options=str(options), members=", ".join(members))
| 97 | + |
| 98 | + |
| 99 | +# The agent state is the input to each node in the graph |
# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    """Shared state passed between every node in the graph."""

    # `operator.add` tells langgraph to APPEND new messages to the
    # existing list rather than replace it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the node the supervisor routed to ("FINISH" maps to END).
    next: str
| 105 | + |
| 106 | + |
def create_graph(llm):
    """Assemble and compile the supervisor/worker multi-agent graph.

    The supervisor routes each turn to one of the worker agents (or
    FINISH); every worker reports back to the supervisor when done.
    """
    # Supervisor chain: prompt -> LLM forced to call the "route"
    # function -> parsed into a plain dict like {"next": "..."}.
    supervisor_chain = (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    # One node per worker; agent_node wraps each agent's output as a
    # named HumanMessage so the supervisor can see who answered.
    # NOTE(review): the CurrentTime prompt "You can tell the current
    # time at" reads as truncated — confirm intended wording.
    worker_nodes = {
        "Researcher": functools.partial(
            agent_node,
            agent=create_agent(llm, "You are a web researcher.", [mock_search_tool]),
            name="Researcher",
        ),
        "CurrentTime": functools.partial(
            agent_node,
            agent=create_agent(llm, "You can tell the current time at", [datetime_tool]),
            name="CurrentTime",
        ),
    }

    workflow = StateGraph(AgentState)
    workflow.add_node("supervisor", supervisor_chain)
    for member_name, node in worker_nodes.items():
        workflow.add_node(member_name, node)
        # Workers ALWAYS report back to the supervisor when done.
        workflow.add_edge(member_name, "supervisor")

    # The supervisor's {"next": ...} decision picks the next node;
    # "FINISH" terminates the run.
    routing = {member: member for member in members}
    routing["FINISH"] = END
    workflow.add_conditional_edges("supervisor", lambda state: state["next"], routing)

    # Every run starts at the supervisor.
    workflow.add_edge(START, "supervisor")

    # Compile with a recursion cap so supervisor<->worker loops cannot
    # run forever; debug=True traces each step.
    return workflow.compile(debug=True).with_config(
        RunnableConfig(recursion_limit=10)
    )
| 160 | + |
| 161 | + |
class MyOpenAIMultiAgentFactory(BaseAgentFactory):
    """Factory that builds the supervisor graph around a streaming ChatOpenAI."""

    def create_agent(self, dto: CreateAgentDto) -> Runnable:
        """Create the multi-agent graph using the model settings from *dto*."""
        chat_model = ChatOpenAI(
            model=dto.model,
            api_key=dto.api_key,
            streaming=True,
            temperature=dto.temperature,
        )
        return create_graph(chat_model)
0 commit comments