8
8
9
9
from coagent .core import Address , BaseAgent , Context , handler , logger
10
10
from coagent .core .agent import is_async_iterator
11
+ import jsonschema
11
12
from pydantic_core import PydanticUndefined
12
13
from pydantic .fields import FieldInfo
13
14
14
15
from .aswarm import Agent as SwarmAgent , Swarm
15
16
from .aswarm .util import function_to_jsonschema
17
+ from .mcp_server import (
18
+ CallTool ,
19
+ CallToolResult ,
20
+ ListTools ,
21
+ ListToolsResult ,
22
+ MCPTool ,
23
+ MCPImageContent ,
24
+ MCPTextContent ,
25
+ )
16
26
from .messages import ChatMessage , ChatHistory , StructuredOutput
17
27
from .model_client import default_model_client , ModelClient
18
28
from .util import is_user_confirmed
@@ -207,6 +217,8 @@ def __init__(
207
217
name : str = "" ,
208
218
system : str = "" ,
209
219
tools : list [Callable ] | None = None ,
220
+ mcp_servers : list [str ] | None = None ,
221
+ mcp_server_agent_type : str = "mcp_server" ,
210
222
client : ModelClient = default_model_client ,
211
223
timeout : float = 300 ,
212
224
):
@@ -215,6 +227,8 @@ def __init__(
215
227
self ._name : str = name
216
228
self ._system : str = system
217
229
self ._tools : list [Callable ] = tools or []
230
+ self ._mcp_servers : list [str ] = mcp_servers or []
231
+ self ._mcp_server_agent_type : str = mcp_server_agent_type
218
232
self ._client : ModelClient = client
219
233
220
234
self ._swarm_client : Swarm = Swarm (self .client )
@@ -239,6 +253,14 @@ def system(self) -> str:
239
253
def tools (self ) -> list [Callable ]:
240
254
return self ._tools
241
255
256
+ @property
257
+ def mcp_servers (self ) -> list [str ]:
258
+ return self ._mcp_servers
259
+
260
+ @property
261
+ def mcp_server_agent_type (self ) -> str :
262
+ return self ._mcp_server_agent_type
263
+
242
264
@property
243
265
def client (self ) -> ModelClient :
244
266
return self ._client
@@ -264,11 +286,17 @@ def get_swarm_client(self, extensions: dict) -> Swarm:
264
286
async def get_swarm_agent (self ) -> SwarmAgent :
265
287
if not self ._swarm_agent :
266
288
tools = self .tools [:] # copy
289
+
290
+ # Collect all methods marked as tools.
267
291
methods = inspect .getmembers (self , predicate = inspect .ismethod )
268
292
for _name , meth in methods :
269
293
if getattr (meth , "is_tool" , False ):
270
294
tools .append (meth )
271
295
296
+ # Collect all tools from MCP servers.
297
+ mcp_tools = await self ._get_mcp_tools (self .mcp_servers )
298
+ tools .extend (mcp_tools )
299
+
272
300
self ._swarm_agent = SwarmAgent (
273
301
name = self .name ,
274
302
model = self .client .model ,
@@ -314,6 +342,67 @@ async def handle_structured_output(
314
342
async for resp in response :
315
343
yield resp
316
344
345
+ async def _get_mcp_tools (self , mcp_servers : list [str ]) -> list [Callable ]:
346
+ all_tools = []
347
+
348
+ for server in mcp_servers :
349
+ raw_result = await self .channel .publish (
350
+ Address (name = self .mcp_server_agent_type , id = server ),
351
+ ListTools ().encode (),
352
+ request = True ,
353
+ timeout = 10 ,
354
+ )
355
+ result = ListToolsResult .decode (raw_result )
356
+
357
+ tools = [self ._to_function_tool (server , t ) for t in result .tools ]
358
+ all_tools .extend (tools )
359
+
360
+ return all_tools
361
+
362
+ def _to_function_tool (self , server : str , t : MCPTool ) -> Callable :
363
+ async def tool (** kwargs ) -> Any :
364
+ # Validate the input against the schema
365
+ jsonschema .validate (instance = kwargs , schema = t .inputSchema )
366
+
367
+ # Actually call the tool.
368
+ raw_result = await self .channel .publish (
369
+ Address (name = self .mcp_server_agent_type , id = server ),
370
+ CallTool (
371
+ name = t .name ,
372
+ arguments = kwargs ,
373
+ ).encode (),
374
+ request = True ,
375
+ timeout = 10 ,
376
+ )
377
+ result = CallToolResult .decode (raw_result )
378
+
379
+ if not result .content :
380
+ return ""
381
+ content = result .content [0 ]
382
+
383
+ if result .isError :
384
+ raise ValueError (content .text )
385
+
386
+ match content :
387
+ case MCPTextContent ():
388
+ return content .text
389
+ case MCPImageContent ():
390
+ return content .data
391
+ case _: # EmbeddedResource() or other types
392
+ return ""
393
+
394
+ tool .__name__ = t .name
395
+ tool .__doc__ = t .description
396
+
397
+ # Attach the schema and arguments to the tool.
398
+ tool .__mcp_tool_schema__ = dict (
399
+ name = t .name ,
400
+ description = t .description ,
401
+ parameters = t .inputSchema ,
402
+ )
403
+ tool .__mcp_tool_args__ = tuple (t .inputSchema ["properties" ].keys ())
404
+ return tool
405
+
317
406
async def _handle_history (
318
407
self ,
319
408
msg : ChatHistory ,
0 commit comments