Commit d33a5d8

test: add langgraph multi agent tests (#39)
Co-authored-by: Samuel <[email protected]>
1 parent dc398f0 · commit d33a5d8

File tree: 8 files changed, +1216 −839 lines


langchain_openai_api_bridge/assistant/adapter/openai_event_factory.py (+21 −1)
@@ -11,6 +11,7 @@
     ThreadRunStepDelta,
     RunStep,
 )
+from langchain_core.messages.tool import ToolMessage

 from openai.types.beta.threads import (
     Message,
@@ -163,7 +164,26 @@ def create_langchain_function(
     output: Optional[Union[dict[object], float, str]] = None,
 ) -> function_tool_call.Function:
     arguments_json = json.dumps(arguments) if arguments else None
-    output_json = json.dumps(output) if output else None
+
+    output_json = _serialize_output(output=output)
+
     return function_tool_call.Function(
         name=name, arguments=arguments_json, output=output_json
     )
+
+
+def _serialize_output(output: Optional[Union[dict[object], float, str]] = None):
+    if output is None:
+        return None
+
+    if isinstance(output, ToolMessage):
+        output_obj = {
+            "content": output.content,
+            "tool_call_id": output.tool_call_id,
+            "status": output.status,
+        }
+        if output.artifact is not None:
+            output_obj["artifact"] = output.artifact
+        return json.dumps(output_obj)
+
+    return json.dumps(output) if output else None
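A minimal usage sketch of the new serialization path, mirroring the unit tests added later in this commit (the plain-float case follows from the helper's final branch; values are illustrative):

from langchain_core.messages.tool import ToolMessage
from langchain_openai_api_bridge.assistant.adapter.openai_event_factory import (
    create_langchain_function,
)

# ToolMessage outputs are unpacked into a JSON object
# (ToolMessage.status defaults to "success").
msg = ToolMessage(content="some", tool_call_id="123")
fn = create_langchain_function(arguments={"a": "a"}, output=msg)
print(fn.output)  # {"content": "some", "tool_call_id": "123", "status": "success"}

# Other truthy outputs still pass straight through json.dumps, as before.
fn = create_langchain_function(arguments={"a": "a"}, output=2.1)
print(fn.output)  # 2.1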

langchain_openai_api_bridge/core/types/openai/chat_completion.py (+1 −1)
@@ -31,7 +31,7 @@ class OpenAIChatCompletionChoice(BaseModel):


 class OpenAIChatCompletionObject(BaseModel):
-    id: str
+    id: Optional[str]
     object: str = ("chat.completion",)
     created: int
     model: str
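The effect of the loosened field, sketched with a stand-alone Pydantic model (a hypothetical class, not the bridge's full type; note that under Pydantic v2 an Optional annotation alone keeps the field required, so an explicit None default is shown):

from typing import Optional
from pydantic import BaseModel

class Demo(BaseModel):
    id: Optional[str] = None  # mirrors the change: a missing/None id no longer fails validation

Demo(id=None)          # valid
Demo(id="chatcmpl-1")  # valid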

poetry.lock (+914 −836)

Some generated files are not rendered by default.

pyproject.toml (+1 −1)

@@ -15,7 +15,7 @@ langchain = { version = "^0.2.6", optional = true }
 langchain-openai = { version = "^0.1.8", optional = true }
 fastapi = { version = "^0.111.0", optional = true }
 python-dotenv = { version = "^1.0.1", optional = true }
-langgraph = { version = "^0.0.62", optional = true }
+langgraph = { version = "^0.2.16", optional = true }
 langchain-anthropic = { version = "^0.1.19", optional = true }
 langchain-groq = { version = "^0.1.6", optional = true }

tests/test_functional/fastapi_chat_completion_multi_agent_openai/multi_agent_server_openai.py (new file, +35)

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv, find_dotenv
import uvicorn

from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.fastapi_chat_completion_multi_agent_openai.my_openai_multi_agent_factory import (
    MyOpenAIMultiAgentFactory,
)

_ = load_dotenv(find_dotenv())
app = FastAPI(
    title="Langgraph Multi Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langgraph multi agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: MyOpenAIMultiAgentFactory()
)
bridge.bind_openai_chat_completion()

if __name__ == "__main__":
    uvicorn.run(app, host="localhost")
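Run standalone, the app serves the bridge on an OpenAI-compatible route. A hedged client sketch: the port assumes uvicorn's default 8000 (only the host is set above), and the /openai/v1 prefix matches the functional test below:

import os
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/openai/v1",
    api_key=os.environ["OPENAI_API_KEY"],  # forwarded to ChatOpenAI by the agent factory
)
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What time is it?"}],
)
print(completion.choices[0].message.content)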
tests/test_functional/fastapi_chat_completion_multi_agent_openai/my_openai_multi_agent_factory.py (new file, +171)

import functools
import operator
from datetime import datetime
from typing import TypedDict, Annotated, Sequence
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from langgraph.constants import START, END
from langgraph.graph import StateGraph

from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


# Define a new tool that returns the current datetime
datetime_tool = Tool(
    name="Datetime",
    func=lambda x: datetime.now().isoformat(),
    description="Returns the current datetime",
)

mock_search_tool = Tool(
    name="Search",
    func=lambda x: "light",
    description="Search the web about something",
)


def create_agent(llm: ChatOpenAI, system_prompt: str, tools: list):
    # Each worker node will be given a name and some tools.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}


members = ["Researcher", "CurrentTime"]
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)
# Our team supervisor is an LLM node. It just picks the next agent to process and decides when the work is completed
options = ["FINISH"] + members

# Using openai function calling can make output parsing easier for us
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            }
        },
        "required": ["next"],
    },
}

# Create the prompt using ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))


# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    # The annotation tells the graph that new messages will always be added to the current states
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str


def create_graph(llm):
    # Construction of the chain for the supervisor agent
    supervisor_chain = (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    # Add the research agent using the create_agent helper function
    research_agent = create_agent(llm, "You are a web researcher.", [mock_search_tool])
    research_node = functools.partial(
        agent_node, agent=research_agent, name="Researcher"
    )

    # Add the time agent using the create_agent helper function
    current_time_agent = create_agent(
        llm, "You can tell the current time at", [datetime_tool]
    )
    current_time_node = functools.partial(
        agent_node, agent=current_time_agent, name="CurrentTime"
    )

    workflow = StateGraph(AgentState)

    # Add a "chatbot" node. Nodes represent units of work. They are typically regular python functions.
    workflow.add_node("Researcher", research_node)
    workflow.add_node("CurrentTime", current_time_node)
    workflow.add_node("supervisor", supervisor_chain)

    # We want our workers to ALWAYS "report back" to the supervisor when done
    for member in members:
        workflow.add_edge(member, "supervisor")

    # Conditional edges usually contain "if" statements to route
    # to different nodes depending on the current graph state.
    # These functions receive the current graph state and return a string
    # or list of strings indicating which node(s) to call next.
    conditional_map = {k: k for k in members}
    conditional_map["FINISH"] = END
    workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

    # Add an entry point. This tells our graph where to start its work each time we run it.
    workflow.add_edge(START, "supervisor")

    # To be able to run our graph, call "compile()" on the graph builder.
    # This creates a "CompiledGraph" we can use invoke on our state.
    graph = workflow.compile(debug=True).with_config(
        RunnableConfig(
            recursion_limit=10,
        )
    )

    return graph


class MyOpenAIMultiAgentFactory(BaseAgentFactory):

    def create_agent(self, dto: CreateAgentDto) -> Runnable:
        llm = ChatOpenAI(
            model=dto.model,
            api_key=dto.api_key,
            streaming=True,
            temperature=dto.temperature,
        )
        return create_graph(llm)
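The compiled graph can also be exercised without the bridge; a minimal sketch (model name and prompt are illustrative):

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

graph = create_graph(ChatOpenAI(model="gpt-4o-mini", streaming=True))

# The supervisor should route to CurrentTime, whose worker calls the Datetime
# tool and reports back; the supervisor then returns FINISH, and the last
# message in the final state carries the answer.
final_state = graph.invoke({"messages": [HumanMessage(content="What time is it?")]})
print(final_state["messages"][-1].content)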
tests/test_functional/fastapi_chat_completion_multi_agent_openai/ (new test file, +43)

import pytest
from openai import OpenAI
from fastapi.testclient import TestClient
from multi_agent_server_openai import app

test_api = TestClient(app)


@pytest.fixture
def openai_client():
    # TestClient subclasses httpx.Client, so the OpenAI SDK can use it as its
    # transport and requests stay in-process against the FastAPI app.
    return OpenAI(
        base_url="http://testserver/openai/v1",
        http_client=test_api,
    )


def test_chat_completion_invoke(openai_client):
    chat_completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": "What time is it?",
            }
        ],
    )
    assert "time" in chat_completion.choices[0].message.content


def test_chat_completion_stream(openai_client):
    chunks = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "How does photosynthesis work?"}],
        stream=True,
    )
    every_content = []
    for chunk in chunks:
        if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
            every_content.append(chunk.choices[0].delta.content)

    stream_output = "".join(every_content)

    assert "light" in stream_output

tests/test_unit/assistant/adapter/test_openai_event_factory.py (+30)
@@ -1,6 +1,8 @@
+import json
 from langchain_openai_api_bridge.assistant.adapter.openai_event_factory import (
     create_langchain_function,
 )
+from langchain_core.messages.tool import ToolMessage


 class TestCreateLangchainFunction:
@@ -20,3 +22,31 @@ def test_float_output_is_set_to_string(self):
         result = create_langchain_function(arguments={"a": "a"}, output=2.1)

         assert result.output == "2.1"
+
+    def test_ToolMessageOutput_is_serialized_to_json(self):
+        tool_message_output = ToolMessage(
+            content="some",
+            tool_call_id="123",
+        )
+        result = create_langchain_function(
+            arguments={"a": "a"}, output=tool_message_output
+        )
+
+        output = json.loads(result.output)
+        assert output["content"] == "some"
+        assert output["tool_call_id"] == "123"
+        assert output["status"] == "success"
+        assert output.get("artifact") is None
+
+    def test_ToolMessageOutput_with_artifact_is_serialized_to_json(self):
+        tool_message_output = ToolMessage(
+            content="some",
+            tool_call_id="123",
+            artifact={"test": "test"},
+        )
+        result = create_langchain_function(
+            arguments={"a": "a"}, output=tool_message_output
+        )
+
+        output = json.loads(result.output)
+        assert output["artifact"] == {"test": "test"}
