Skip to content

Commit 9bed96e

Browse files
committed
Enhance ChatAgent to support MCP servers
1 parent 45e6c4b commit 9bed96e

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

coagent/agents/chat_agent.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@
88

99
from coagent.core import Address, BaseAgent, Context, handler, logger
1010
from coagent.core.agent import is_async_iterator
11+
import jsonschema
1112
from pydantic_core import PydanticUndefined
1213
from pydantic.fields import FieldInfo
1314

1415
from .aswarm import Agent as SwarmAgent, Swarm
1516
from .aswarm.util import function_to_jsonschema
17+
from .mcp_server import (
18+
CallTool,
19+
CallToolResult,
20+
ListTools,
21+
ListToolsResult,
22+
MCPTool,
23+
MCPImageContent,
24+
MCPTextContent,
25+
)
1626
from .messages import ChatMessage, ChatHistory, StructuredOutput
1727
from .model_client import default_model_client, ModelClient
1828
from .util import is_user_confirmed
@@ -207,6 +217,8 @@ def __init__(
207217
name: str = "",
208218
system: str = "",
209219
tools: list[Callable] | None = None,
220+
mcp_servers: list[str] | None = None,
221+
mcp_server_agent_type: str = "mcp_server",
210222
client: ModelClient = default_model_client,
211223
timeout: float = 300,
212224
):
@@ -215,6 +227,8 @@ def __init__(
215227
self._name: str = name
216228
self._system: str = system
217229
self._tools: list[Callable] = tools or []
230+
self._mcp_servers: list[str] = mcp_servers or []
231+
self._mcp_server_agent_type: str = mcp_server_agent_type
218232
self._client: ModelClient = client
219233

220234
self._swarm_client: Swarm = Swarm(self.client)
@@ -239,6 +253,14 @@ def system(self) -> str:
239253
def tools(self) -> list[Callable]:
240254
return self._tools
241255

256+
@property
257+
def mcp_servers(self) -> list[str]:
258+
return self._mcp_servers
259+
260+
@property
261+
def mcp_server_agent_type(self) -> str:
262+
return self._mcp_server_agent_type
263+
242264
@property
243265
def client(self) -> ModelClient:
244266
return self._client
@@ -264,11 +286,17 @@ def get_swarm_client(self, extensions: dict) -> Swarm:
264286
async def get_swarm_agent(self) -> SwarmAgent:
265287
if not self._swarm_agent:
266288
tools = self.tools[:] # copy
289+
290+
# Collect all methods marked as tools.
267291
methods = inspect.getmembers(self, predicate=inspect.ismethod)
268292
for _name, meth in methods:
269293
if getattr(meth, "is_tool", False):
270294
tools.append(meth)
271295

296+
# Collect all tools from MCP servers.
297+
mcp_tools = await self._get_mcp_tools(self.mcp_servers)
298+
tools.extend(mcp_tools)
299+
272300
self._swarm_agent = SwarmAgent(
273301
name=self.name,
274302
model=self.client.model,
@@ -314,6 +342,67 @@ async def handle_structured_output(
314342
async for resp in response:
315343
yield resp
316344

345+
async def _get_mcp_tools(self, mcp_servers: list[str]) -> list[Callable]:
346+
all_tools = []
347+
348+
for server in mcp_servers:
349+
raw_result = await self.channel.publish(
350+
Address(name=self.mcp_server_agent_type, id=server),
351+
ListTools().encode(),
352+
request=True,
353+
timeout=10,
354+
)
355+
result = ListToolsResult.decode(raw_result)
356+
357+
tools = [self._to_function_tool(server, t) for t in result.tools]
358+
all_tools.extend(tools)
359+
360+
return all_tools
361+
362+
def _to_function_tool(self, server: str, t: MCPTool) -> Callable:
363+
async def tool(**kwargs) -> Any:
364+
# Validate the input against the schema
365+
jsonschema.validate(instance=kwargs, schema=t.inputSchema)
366+
367+
# Actually call the tool.
368+
raw_result = await self.channel.publish(
369+
Address(name=self.mcp_server_agent_type, id=server),
370+
CallTool(
371+
name=t.name,
372+
arguments=kwargs,
373+
).encode(),
374+
request=True,
375+
timeout=10,
376+
)
377+
result = CallToolResult.decode(raw_result)
378+
379+
if not result.content:
380+
return ""
381+
content = result.content[0]
382+
383+
if result.isError:
384+
raise ValueError(content.text)
385+
386+
match content:
387+
case MCPTextContent():
388+
return content.text
389+
case MCPImageContent():
390+
return content.data
391+
case _: # EmbeddedResource() or other types
392+
return ""
393+
394+
tool.__name__ = t.name
395+
tool.__doc__ = t.description
396+
397+
# Attach the schema and arguments to the tool.
398+
tool.__mcp_tool_schema__ = dict(
399+
name=t.name,
400+
description=t.description,
401+
parameters=t.inputSchema,
402+
)
403+
tool.__mcp_tool_args__ = tuple(t.inputSchema["properties"].keys())
404+
return tool
405+
317406
async def _handle_history(
318407
self,
319408
msg: ChatHistory,

coagent/agents/mcp_server.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Literal
44

55
from coagent.core import BaseAgent, Context, handler, logger, Message
6+
from coagent.core.messages import Cancel
67
from coagent.core.exceptions import InternalError
78
from mcp import ClientSession, Tool as MCPTool # noqa: F401
89
from mcp.client.sse import sse_client
@@ -33,7 +34,7 @@ class MCPServerSSEParams(BaseModel):
3334
class Connect(Message):
3435
"""A message to connect to the server.
3536
36-
To close the server, send a `Cancel` message to close the connection
37+
To close the server, send a `Close` message to close the connection
3738
and delete corresponding server agent.
3839
"""
3940

@@ -60,6 +61,13 @@ class Connect(Message):
6061
"""
6162

6263

64+
# A message to close the server.
65+
#
66+
# Note that this is an alias of the `Cancel` message since it's ok to close
67+
# the server by deleting the corresponding agent.
68+
Close = Cancel
69+
70+
6371
class InvalidateCache(Message):
6472
"""A message to invalidate the cache of the list result."""
6573

0 commit comments

Comments
 (0)