Skip to content

Commit 2a1e1ae

Browse files
committed
Merge pull request 'feat: stream all mcp logs' (#505) from listen-log into main
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/505
2 parents 1b327ae + 398d551 commit 2a1e1ae

File tree

6 files changed

+165
-51
lines changed

6 files changed

+165
-51
lines changed

dive_mcp_host/host/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,5 @@ class LogBufferNotFoundError(MCPHostError):
115115

116116
def __init__(self, name: str) -> None:
117117
"""Initialize the error."""
118+
self.mcp_name = name
118119
super().__init__(f"Log buffer {name} not found")

dive_mcp_host/host/tools/log.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import asyncio
4343
import sys
4444
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
45-
from contextlib import asynccontextmanager
45+
from contextlib import AsyncExitStack, asynccontextmanager
4646
from datetime import UTC, datetime
4747
from enum import StrEnum
4848
from logging import INFO, getLogger
@@ -373,30 +373,53 @@ async def register_buffer(self, buffer: LogBuffer) -> AsyncGenerator[None, None]
373373
@asynccontextmanager
374374
async def listen_log(
375375
self,
376-
name: str,
376+
names: list[str],
377377
listener: Callable[[LogMsg], Coroutine[None, None, None]],
378378
) -> AsyncGenerator[None, None]:
379-
"""Listen to log updates from a specific MCP server.
379+
"""Listen to log updates from MCP servers.
380380
381381
The listener is a context manager,
382382
user can decide how long it will listen to the buffer.
383383
384-
Only buffers that are registered to the log manager
385-
can be listened to. If the buffer is not registered,
386-
`LogBufferNotFoundError` will be raised.
384+
Args:
385+
names: The names of the MCP servers to listen to.
386+
listener: The callback function to call when a log is received.
387+
388+
Raises:
389+
LogBufferNotFoundError: If a specific name is provided but not found.
387390
388391
Example:
389392
```python
390393
async def listener(log: LogMsg) -> None:
391394
print(log)
392395
393396
394-
async with log_manager.listen_log(buffer.name, listener):
397+
async with log_manager.listen_log(["mcp_server"], listener):
395398
await asyncio.sleep(10)
396399
```
397400
"""
398-
buffer = self._buffers.get(name)
399-
if buffer is None:
400-
raise LogBufferNotFoundError(name)
401-
async with buffer.add_listener(listener):
402-
yield
401+
if not names:
402+
raise LogBufferNotFoundError("no name provided")
403+
404+
async with AsyncExitStack() as stack:
405+
found_buffer: list[LogBuffer] = []
406+
not_found: list[str] = []
407+
408+
for name in names:
409+
if buffer := self._buffers.get(name):
410+
found_buffer.append(self._buffers[name])
411+
else:
412+
not_found.append(name)
413+
414+
if not_found:
415+
raise LogBufferNotFoundError(",".join(not_found))
416+
417+
for buffer in found_buffer:
418+
await stack.enter_async_context(buffer.add_listener(listener))
419+
420+
try:
421+
yield
422+
except Exception:
423+
raise
424+
finally:
425+
pass

dive_mcp_host/httpd/routers/tools.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from fastapi import APIRouter, Depends
44
from fastapi.responses import StreamingResponse
5-
from pydantic import ValidationError
5+
from pydantic import BaseModel, ValidationError
66

77
from dive_mcp_host.host.tools.model_types import ClientState
88
from dive_mcp_host.httpd.conf.mcp_servers import Config
@@ -125,21 +125,28 @@ async def list_tools( # noqa: PLR0912, C901
125125
return ToolsResult(success=True, message=None, tools=list(result.values()))
126126

127127

128-
@tools.get("/logs/stream")
128+
class LogsStreamBody(BaseModel):
129+
"""Body for logs stream API."""
130+
131+
names: list[str]
132+
stream_until: ClientState | None = None
133+
stop_on_notfound: bool = True
134+
max_retries: int = 10
135+
136+
137+
@tools.post("/logs/stream")
129138
async def stream_server_logs(
130-
server_name: str,
131-
stream_until: ClientState | None = None,
132-
stop_on_notfound: bool = True,
133-
max_retries: int = 10,
139+
body: LogsStreamBody,
134140
app: DiveHostAPI = Depends(get_app),
135141
) -> StreamingResponse:
136-
"""Stream logs from a specific MCP server.
142+
"""Stream logs from MCP servers.
137143
138144
Args:
139-
server_name (str): The name of the MCP server to stream logs from.
140-
stream_until (ClientState | None): stream until client state is reached.
141-
stop_on_notfound (bool): If True, stop streaming if the server is not found.
142-
max_retries (int): The maximum number of retries to stream logs.
145+
body (LogsStreamBody):
146+
- names: MCP servers to listen for logs
147+
- stream_until: Stream until mcp server state matches the provided state
148+
- stop_on_notfound: Stop streaming if mcp server is not found
149+
- max_retries: Retry N times to listen for logs
143150
app (DiveHostAPI): The DiveHostAPI instance.
144151
145152
Returns:
@@ -155,11 +162,12 @@ async def process() -> None:
155162
processor = LogStreamHandler(
156163
stream=stream,
157164
log_manager=log_manager,
158-
stream_until=stream_until,
159-
stop_on_notfound=stop_on_notfound,
160-
max_retries=max_retries,
165+
stream_until=body.stream_until,
166+
stop_on_notfound=body.stop_on_notfound,
167+
max_retries=body.max_retries,
168+
server_names=body.names,
161169
)
162-
await processor.stream_logs(server_name)
170+
await processor.stream_logs()
163171

164172
stream.add_task(process)
165173
return response

dive_mcp_host/httpd/routers/utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -814,13 +814,14 @@ async def _get_history_user_input(
814814
class LogStreamHandler:
815815
"""Handles streaming of logs."""
816816

817-
def __init__(
817+
def __init__( # noqa: PLR0913
818818
self,
819819
stream: EventStreamContextManager,
820820
log_manager: LogManager,
821821
stream_until: ClientState | None = None,
822822
stop_on_notfound: bool = True,
823823
max_retries: int = 10,
824+
server_names: list[str] | None = None,
824825
) -> None:
825826
"""Initialize the log processor."""
826827
self._stream = stream
@@ -836,26 +837,34 @@ def __init__(
836837
if stream_until:
837838
self._stream_until.add(stream_until)
838839

840+
self._servers_reached_target: set[str] = set()
841+
self._server_names: set[str] = set(server_names) if server_names else set()
842+
839843
async def _log_listener(self, msg: LogMsg) -> None:
840844
await self._stream.write(msg.model_dump_json())
841845
if msg.client_state in self._stream_until:
842-
self._end_event.set()
846+
self._servers_reached_target.add(msg.mcp_server_name)
847+
if self._server_names == self._servers_reached_target:
848+
self._end_event.set()
843849

844-
async def stream_logs(self, server_name: str) -> None:
845-
"""Stream logs from specific MCP server.
850+
async def stream_logs(self) -> None:
851+
"""Stream logs from MCP servers.
846852
847853
Keep the connection open until client disconnects or
848854
client state is reached.
849855
856+
Streams logs from the given server names.
857+
850858
If self._stop_on_notfound is False, it will keep retrying until
851859
the log buffer is found or max retries is reached.
852860
"""
853861
while self._max_retries > 0:
854862
self._max_retries -= 1
863+
self._servers_reached_target = set()
855864

856865
try:
857866
async with self._log_manager.listen_log(
858-
name=server_name,
867+
names=list(self._server_names),
859868
listener=self._log_listener,
860869
):
861870
with suppress(asyncio.CancelledError):
@@ -864,28 +873,30 @@ async def stream_logs(self, server_name: str) -> None:
864873
except LogBufferNotFoundError as e:
865874
logger.warning(
866875
"Log buffer not found for server %s, retries left: %d",
867-
server_name,
876+
e.mcp_name,
868877
self._max_retries,
869878
)
870879

871880
msg = LogMsg(
872881
event=LogEvent.STREAMING_ERROR,
873882
body=f"Error streaming logs: {e}",
874-
mcp_server_name=server_name,
883+
mcp_server_name=e.mcp_name,
875884
)
876885
await self._stream.write(msg.model_dump_json())
877886

878887
if self._stop_on_notfound or self._max_retries == 0:
879888
break
880889

881-
await asyncio.sleep(1)
890+
await asyncio.sleep(0.5)
882891

883892
except Exception as e:
884-
logger.exception("Error in log streaming for server %s", server_name)
893+
logger.exception(
894+
"Error in log streaming for servers %s", self._server_names
895+
)
885896
msg = LogMsg(
886897
event=LogEvent.STREAMING_ERROR,
887898
body=f"Error streaming logs: {e}",
888-
mcp_server_name=server_name,
899+
mcp_server_name="unknown",
889900
)
890901
await self._stream.write(msg.model_dump_json())
891902
break

tests/httpd/routers/test_tools.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,7 @@ def test_tools_cache_after_update(test_client):
377377
def test_stream_logs_notfound(test_client: tuple[TestClient, DiveHostAPI]):
378378
"""Test stream_logs function with not found server."""
379379
client, _ = test_client
380-
response = client.get(
381-
"/api/tools/logs/stream", params={"server_name": "missing_server"}
382-
)
380+
response = client.post("/api/tools/logs/stream", json={"names": ["missing_server"]})
383381
for line in response.iter_lines():
384382
content = line.removeprefix("data: ")
385383
if content in ("[DONE]", ""):
@@ -397,7 +395,7 @@ def test_stream_logs_notfound_wait(test_client: tuple[TestClient, DiveHostAPI]):
397395

398396
def update_tools():
399397
time.sleep(2)
400-
_ = client.post(
398+
response = client.post(
401399
"/api/config/mcpserver",
402400
json={
403401
"mcpServers": {
@@ -418,12 +416,12 @@ def update_tools():
418416

419417
with ThreadPoolExecutor(1) as executor:
420418
executor.submit(update_tools)
421-
response = client.get(
419+
response = client.post(
422420
"/api/tools/logs/stream",
423-
params={
424-
"server_name": "missing_server",
421+
json={
422+
"names": ["missing_server"],
425423
"stop_on_notfound": False,
426-
"max_retries": 5,
424+
"max_retries": 10,
427425
"stream_until": "running",
428426
},
429427
)
@@ -451,7 +449,7 @@ def test_stream_logs_name_with_slash(test_client: tuple[TestClient, DiveHostAPI]
451449
client, _ = test_client
452450

453451
def update_tools():
454-
_ = client.post(
452+
response = client.post(
455453
"/api/config/mcpserver",
456454
json={
457455
"mcpServers": {
@@ -472,10 +470,10 @@ def update_tools():
472470

473471
with ThreadPoolExecutor(1) as executor:
474472
executor.submit(update_tools)
475-
response = client.get(
473+
response = client.post(
476474
"/api/tools/logs/stream",
477-
params={
478-
"server_name": "name/with/slash",
475+
json={
476+
"names": ["name/with/slash"],
479477
"stop_on_notfound": False,
480478
"max_retries": 5,
481479
"stream_until": "running",
@@ -498,3 +496,76 @@ def update_tools():
498496

499497
assert responses[-1].event == LogEvent.STATUS_CHANGE
500498
assert responses[-1].client_state == ClientState.RUNNING
499+
500+
501+
def test_stream_multiple_logs(test_client: tuple[TestClient, DiveHostAPI]):
502+
"""Test streaming multiple server logs."""
503+
client, _ = test_client
504+
505+
def setup_multiple_servers():
506+
time.sleep(1)
507+
response = client.post(
508+
"/api/config/mcpserver",
509+
json={
510+
"mcpServers": {
511+
"echo": {
512+
"transport": "stdio",
513+
"enabled": True,
514+
"command": "python",
515+
"args": [
516+
"-m",
517+
"dive_mcp_host.host.tools.echo",
518+
"--transport=stdio",
519+
],
520+
},
521+
"server_two": {
522+
"transport": "stdio",
523+
"enabled": True,
524+
"command": "python",
525+
"args": [
526+
"-m",
527+
"dive_mcp_host.host.tools.echo",
528+
"--transport=stdio",
529+
],
530+
},
531+
}
532+
},
533+
)
534+
assert response.status_code == status.HTTP_200_OK
535+
536+
with ThreadPoolExecutor(1) as executer:
537+
executer.submit(setup_multiple_servers)
538+
response = client.post(
539+
"/api/tools/logs/stream",
540+
json={
541+
"names": ["echo", "server_two"],
542+
"stream_until": "running",
543+
"stop_on_notfound": False,
544+
"max_retries": 5,
545+
},
546+
)
547+
responses: list[LogMsg] = []
548+
server_names: set[str] = set()
549+
servers_reached_running: set[str] = set()
550+
551+
for line in response.iter_lines():
552+
content = line.removeprefix("data: ")
553+
if content in ("[DONE]", ""):
554+
continue
555+
556+
data = LogMsg.model_validate_json(content)
557+
responses.append(data)
558+
server_names.add(data.mcp_server_name)
559+
560+
if data.client_state == ClientState.RUNNING:
561+
servers_reached_running.add(data.mcp_server_name)
562+
563+
assert len(responses) > 0, "Should receive logs"
564+
assert len(server_names) >= 2, "Should receive logs from multiple servers"
565+
assert "echo" in server_names or "server_two" in server_names
566+
567+
running_states = [r for r in responses if r.client_state == ClientState.RUNNING]
568+
assert len(running_states) >= 2, "Should have at least 2 servers reach RUNNING"
569+
assert len(servers_reached_running) >= 2, (
570+
"At least 2 different servers should reach RUNNING state"
571+
)

tests/test_log.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ async def test_listener(msg: LogMsg) -> None:
205205
await log_buffer.push_log(test_log)
206206

207207
# Listen to logs
208-
async with log_manager.listen_log("test_listen", test_listener):
208+
async with log_manager.listen_log(["test_listen"], test_listener):
209209
# Should receive the existing log
210210
assert len(captured_logs) == 1
211211
assert captured_logs[0].body == "before listener"
@@ -240,7 +240,7 @@ async def test_listener(msg: LogMsg) -> None:
240240

241241
# Try to listen to a buffer that doesn't exist
242242
with pytest.raises(LogBufferNotFoundError):
243-
async with log_manager.listen_log("nonexistent", test_listener):
243+
async with log_manager.listen_log(["nonexistent"], test_listener):
244244
pass
245245
finally:
246246
import shutil

0 commit comments

Comments
 (0)