1
1
from __future__ import annotations
2
2
3
3
import functools
4
+ from contextlib import ExitStack , contextmanager
4
5
from typing import TYPE_CHECKING , Any
5
6
6
7
from mcp .client .session import ClientSession
8
+ from mcp .server import Server
7
9
from mcp .shared .session import BaseSession
8
- from mcp .types import CallToolRequest , LoggingMessageNotification
10
+ from mcp .types import CallToolRequest , LoggingMessageNotification , RequestParams
11
+ from pydantic import TypeAdapter
9
12
10
13
from logfire ._internal .utils import handle_internal_errors
14
+ from logfire .propagate import attach_context , get_context
11
15
12
16
if TYPE_CHECKING :
13
17
from logfire import LevelName , Logfire
14
18
15
19
16
- def instrument_mcp (logfire_instance : Logfire ):
20
+ def instrument_mcp (logfire_instance : Logfire , propagate_otel_context : bool ):
17
21
logfire_instance = logfire_instance .with_settings (custom_scope_suffix = 'mcp' )
18
22
19
23
original_send_request = BaseSession .send_request # type: ignore
20
24
21
25
@functools .wraps (original_send_request ) # type: ignore
22
26
async def send_request (self : Any , request : Any , * args : Any , ** kwargs : Any ):
27
+ root = request .root
23
28
attributes : dict [str , Any ] = {
24
- 'request' : request ,
29
+ 'request' : root ,
25
30
# https://opentelemetry.io/docs/specs/semconv/rpc/json-rpc/
26
31
'rpc.system' : 'jsonrpc' ,
27
32
'rpc.jsonrpc.version' : '2.0' ,
28
33
}
29
34
span_name = 'MCP request'
30
35
31
- root = request .root
32
36
# method should always exist, but it's had to verify because the request type is a RootModel
33
37
# of a big union, instead of just using a base class with a method attribute.
34
38
if method := getattr (root , 'method' , None ): # pragma: no branch
@@ -38,6 +42,20 @@ async def send_request(self: Any, request: Any, *args: Any, **kwargs: Any):
38
42
span_name += f' { root .params .name } '
39
43
40
44
with logfire_instance .span (span_name , ** attributes ) as span :
45
+ with handle_internal_errors :
46
+ if propagate_otel_context : # pragma: no branch
47
+ carrier = get_context ()
48
+ if params := getattr (root , 'params' , None ):
49
+ if meta := getattr (params , 'meta' , None ): # pragma: no cover # TODO
50
+ dumped_meta = meta .model_dump ()
51
+ else :
52
+ dumped_meta = {}
53
+ # Prioritise existing values in meta over the context carrier.
54
+ # RequestParams.Meta should allow basically anything, we're being extra careful here.
55
+ params .meta = RequestParams .Meta .model_validate ({** carrier , ** dumped_meta })
56
+ else :
57
+ root .params = _request_params_type_adapter (type (root )).validate_python ({'_meta' : carrier }) # type: ignore
58
+
41
59
result = await original_send_request (self , request , * args , ** kwargs )
42
60
span .set_attribute ('response' , result )
43
61
return result
@@ -63,3 +81,44 @@ async def _received_notification(self: Any, notification: Any, *args: Any, **kwa
63
81
await original_received_notification (self , notification , * args , ** kwargs )
64
82
65
83
ClientSession ._received_notification = _received_notification # type: ignore
84
+
85
+ original_handle_client_request = ClientSession ._received_request # type: ignore
86
+
87
+ @functools .wraps (original_handle_client_request )
88
+ async def _received_request_client (self : Any , responder : Any ) -> None : # pragma: no cover
89
+ request = responder .request .root
90
+ span_name = 'MCP client handle request'
91
+ with _handle_request_with_context (request , span_name ):
92
+ await original_handle_client_request (self , responder )
93
+
94
+ ClientSession ._received_request = _received_request_client # type: ignore
95
+
96
+ original_handle_server_request = Server ._handle_request # type: ignore
97
+
98
+ @functools .wraps (original_handle_server_request ) # type: ignore
99
+ async def _handle_request (self : Any , message : Any , request : Any , * args : Any , ** kwargs : Any ) -> Any :
100
+ span_name = 'MCP server handle request'
101
+ with _handle_request_with_context (request , span_name ):
102
+ return await original_handle_server_request (self , message , request , * args , ** kwargs )
103
+
104
+ Server ._handle_request = _handle_request # type: ignore
105
+
106
+ @contextmanager
107
+ def _handle_request_with_context (request : Any , span_name : str ):
108
+ with ExitStack () as exit_stack :
109
+ if ( # pragma: no branch
110
+ propagate_otel_context
111
+ and (params := getattr (request , 'params' , None ))
112
+ and (meta := getattr (params , 'meta' , None ))
113
+ ):
114
+ exit_stack .enter_context (attach_context (meta .model_dump ()))
115
+ if method := getattr (request , 'method' , None ): # pragma: no branch
116
+ span_name += f': { method } '
117
+ with logfire_instance .span (span_name , request = request ):
118
+ yield
119
+
120
+
121
+ @functools .lru_cache
122
+ def _request_params_type_adapter (root_type : Any ):
123
+ params_type = root_type .model_fields ['params' ].annotation
124
+ return TypeAdapter (params_type )
0 commit comments