Skip to content

Commit a2874f4

Browse files
Experimental MCP propagation (#1103)
Co-authored-by: Alex Hall <[email protected]>
1 parent ec577a5 commit a2874f4

File tree

3 files changed

+210
-67
lines changed

3 files changed

+210
-67
lines changed

logfire/_internal/integrations/mcp.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,38 @@
11
from __future__ import annotations
22

33
import functools
4+
from contextlib import ExitStack, contextmanager
45
from typing import TYPE_CHECKING, Any
56

67
from mcp.client.session import ClientSession
8+
from mcp.server import Server
79
from mcp.shared.session import BaseSession
8-
from mcp.types import CallToolRequest, LoggingMessageNotification
10+
from mcp.types import CallToolRequest, LoggingMessageNotification, RequestParams
11+
from pydantic import TypeAdapter
912

1013
from logfire._internal.utils import handle_internal_errors
14+
from logfire.propagate import attach_context, get_context
1115

1216
if TYPE_CHECKING:
1317
from logfire import LevelName, Logfire
1418

1519

16-
def instrument_mcp(logfire_instance: Logfire):
20+
def instrument_mcp(logfire_instance: Logfire, propagate_otel_context: bool):
1721
logfire_instance = logfire_instance.with_settings(custom_scope_suffix='mcp')
1822

1923
original_send_request = BaseSession.send_request # type: ignore
2024

2125
@functools.wraps(original_send_request) # type: ignore
2226
async def send_request(self: Any, request: Any, *args: Any, **kwargs: Any):
27+
root = request.root
2328
attributes: dict[str, Any] = {
24-
'request': request,
29+
'request': root,
2530
# https://opentelemetry.io/docs/specs/semconv/rpc/json-rpc/
2631
'rpc.system': 'jsonrpc',
2732
'rpc.jsonrpc.version': '2.0',
2833
}
2934
span_name = 'MCP request'
3035

31-
root = request.root
3236
# method should always exist, but it's had to verify because the request type is a RootModel
3337
# of a big union, instead of just using a base class with a method attribute.
3438
if method := getattr(root, 'method', None): # pragma: no branch
@@ -38,6 +42,20 @@ async def send_request(self: Any, request: Any, *args: Any, **kwargs: Any):
3842
span_name += f' {root.params.name}'
3943

4044
with logfire_instance.span(span_name, **attributes) as span:
45+
with handle_internal_errors:
46+
if propagate_otel_context: # pragma: no branch
47+
carrier = get_context()
48+
if params := getattr(root, 'params', None):
49+
if meta := getattr(params, 'meta', None): # pragma: no cover # TODO
50+
dumped_meta = meta.model_dump()
51+
else:
52+
dumped_meta = {}
53+
# Prioritise existing values in meta over the context carrier.
54+
# RequestParams.Meta should allow basically anything, we're being extra careful here.
55+
params.meta = RequestParams.Meta.model_validate({**carrier, **dumped_meta})
56+
else:
57+
root.params = _request_params_type_adapter(type(root)).validate_python({'_meta': carrier}) # type: ignore
58+
4159
result = await original_send_request(self, request, *args, **kwargs)
4260
span.set_attribute('response', result)
4361
return result
@@ -63,3 +81,44 @@ async def _received_notification(self: Any, notification: Any, *args: Any, **kwa
6381
await original_received_notification(self, notification, *args, **kwargs)
6482

6583
ClientSession._received_notification = _received_notification # type: ignore
84+
85+
original_handle_client_request = ClientSession._received_request # type: ignore
86+
87+
@functools.wraps(original_handle_client_request)
88+
async def _received_request_client(self: Any, responder: Any) -> None: # pragma: no cover
89+
request = responder.request.root
90+
span_name = 'MCP client handle request'
91+
with _handle_request_with_context(request, span_name):
92+
await original_handle_client_request(self, responder)
93+
94+
ClientSession._received_request = _received_request_client # type: ignore
95+
96+
original_handle_server_request = Server._handle_request # type: ignore
97+
98+
@functools.wraps(original_handle_server_request) # type: ignore
99+
async def _handle_request(self: Any, message: Any, request: Any, *args: Any, **kwargs: Any) -> Any:
100+
span_name = 'MCP server handle request'
101+
with _handle_request_with_context(request, span_name):
102+
return await original_handle_server_request(self, message, request, *args, **kwargs)
103+
104+
Server._handle_request = _handle_request # type: ignore
105+
106+
@contextmanager
107+
def _handle_request_with_context(request: Any, span_name: str):
108+
with ExitStack() as exit_stack:
109+
if ( # pragma: no branch
110+
propagate_otel_context
111+
and (params := getattr(request, 'params', None))
112+
and (meta := getattr(params, 'meta', None))
113+
):
114+
exit_stack.enter_context(attach_context(meta.model_dump()))
115+
if method := getattr(request, 'method', None): # pragma: no branch
116+
span_name += f': {method}'
117+
with logfire_instance.span(span_name, request=request):
118+
yield
119+
120+
121+
@functools.lru_cache
122+
def _request_params_type_adapter(root_type: Any):
123+
params_type = root_type.model_fields['params'].annotation
124+
return TypeAdapter(params_type)

logfire/_internal/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -900,12 +900,17 @@ def install_auto_tracing(
900900
def _warn_if_not_initialized_for_instrumentation(self):
901901
self.config.warn_if_not_initialized('Instrumentation will have no effect')
902902

903-
def instrument_mcp(self) -> None:
904-
"""Instrument [MCP](https://modelcontextprotocol.io/) requests such as tool calls."""
903+
def instrument_mcp(self, *, propagate_otel_context: bool = True) -> None:
904+
"""Instrument [MCP](https://modelcontextprotocol.io/) requests such as tool calls.
905+
906+
Args:
907+
propagate_otel_context: Whether to enable propagation of the OpenTelemetry context.
908+
Set to False to prevent setting extra fields like `traceparent` on the metadata of requests.
909+
"""
905910
from .integrations.mcp import instrument_mcp
906911

907912
self._warn_if_not_initialized_for_instrumentation()
908-
instrument_mcp(self)
913+
instrument_mcp(self, propagate_otel_context)
909914

910915
def instrument_pydantic(
911916
self,

0 commit comments

Comments
 (0)