Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] - MCP improvements, add support for using SSE MCP servers #9642

Merged
merged 27 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fe3623d
init mcp client manager
ishaan-jaff Mar 29, 2025
ec283f7
use global_mcp_server_manager
ishaan-jaff Mar 29, 2025
6aa660b
mcp server manager
ishaan-jaff Mar 29, 2025
a1ec0dd
add testing mcp server
ishaan-jaff Mar 29, 2025
1cf5cba
REST API endpoint for MCP
ishaan-jaff Mar 29, 2025
b381dde
basic UI rendering of MCP tools
ishaan-jaff Mar 29, 2025
e0cff75
endpoints to list and call tools
ishaan-jaff Mar 29, 2025
a4a0830
working MCP call tool method
ishaan-jaff Mar 29, 2025
b7b9f9d
working MCP tool call logging
ishaan-jaff Mar 29, 2025
08a52f4
log MCP tool call metadata in SLP
ishaan-jaff Mar 29, 2025
fe6c033
render MCP tools on ui logs page
ishaan-jaff Mar 29, 2025
7dd5411
fix showing list of MCP tools
ishaan-jaff Mar 29, 2025
815263f
rename transform_openai_tool_call_request_to_mcp_tool_call_request
ishaan-jaff Mar 29, 2025
79e8bbb
fix types on tools.py
ishaan-jaff Mar 29, 2025
f2885bf
add code example
ishaan-jaff Mar 29, 2025
09e073d
ui mcp tools
ishaan-jaff Mar 30, 2025
047d767
fix tests for gcs pub sub
ishaan-jaff Mar 30, 2025
3e378f2
async def test_spend_logs_payload_e2e(self):
ishaan-jaff Mar 30, 2025
4e106ce
fix test
ishaan-jaff Mar 30, 2025
a3df026
fix tests
ishaan-jaff Mar 30, 2025
0e321ee
fix import errors without mcp
ishaan-jaff Mar 30, 2025
eb4b8d9
fix linting on DataTableWrapper
ishaan-jaff Mar 30, 2025
c24470e
list_tool_rest_api
ishaan-jaff Mar 30, 2025
385e8bf
fix order of imports
ishaan-jaff Mar 30, 2025
194327b
test fixes
ishaan-jaff Mar 30, 2025
3919e24
test fix
ishaan-jaff Mar 30, 2025
10486dd
fix listing mcp tools
ishaan-jaff Mar 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@

########################### Logging Callback Constants ###########################
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
MCP_TOOL_NAME_PREFIX = "mcp_tool"

########################### LiteLLM Proxy Specific Constants ###########################
########################################################################################
Expand Down
12 changes: 7 additions & 5 deletions litellm/experimental_mcp_client/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, Literal, Union
from typing import Dict, List, Literal, Union

from mcp import ClientSession
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
Expand Down Expand Up @@ -76,8 +76,8 @@ def _get_function_arguments(function: FunctionDefinition) -> dict:
return arguments if isinstance(arguments, dict) else {}


def _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool: ChatCompletionMessageToolCall,
def transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
Expand All @@ -100,8 +100,10 @@ async def call_openai_tool(
Returns:
The result of the MCP tool call.
"""
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool=openai_tool,
mcp_tool_call_request_params = (
transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool=openai_tool,
)
)
return await call_mcp_tool(
session=session,
Expand Down
7 changes: 6 additions & 1 deletion litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
StandardCallbackDynamicParams,
StandardLoggingAdditionalHeaders,
StandardLoggingHiddenParams,
StandardLoggingMCPToolCall,
StandardLoggingMetadata,
StandardLoggingModelCostFailureDebugInformation,
StandardLoggingModelInformation,
Expand Down Expand Up @@ -1099,7 +1100,7 @@ def _success_handler_helper_fn(
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
elif isinstance(result, dict): # pass-through endpoints
elif isinstance(result, dict) or isinstance(result, list):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
Expand Down Expand Up @@ -3114,6 +3115,7 @@ def get_standard_logging_metadata(
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
Expand Down Expand Up @@ -3160,6 +3162,7 @@ def get_standard_logging_metadata(
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
mcp_tool_call_metadata=mcp_tool_call_metadata,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
Expand Down Expand Up @@ -3486,6 +3489,7 @@ def get_standard_logging_object_payload(
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
)

_request_body = proxy_server_request.get("body", {})
Expand Down Expand Up @@ -3626,6 +3630,7 @@ def get_standard_logging_metadata(
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
mcp_tool_call_metadata=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
Expand Down
153 changes: 153 additions & 0 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
MCP Client Manager

This class is responsible for managing MCP SSE clients.

This is a Proxy
"""

import asyncio
import json
from typing import Any, Dict, List, Optional

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import Tool as MCPTool

from litellm._logging import verbose_logger
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer


class MCPServerManager:
def __init__(self):
self.mcp_servers: List[MCPSSEServer] = []
"""
eg.
[
{
"name": "zapier_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
},
{
"name": "google_drive_mcp_server",
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
}
]
"""

self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}
"""
{
"gmail_send_email": "zapier_mcp_server",
}
"""

def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
"""
Load the MCP Servers from the config
"""
for server_name, server_config in mcp_servers_config.items():
_mcp_info: dict = server_config.get("mcp_info", None) or {}
mcp_info = MCPInfo(**_mcp_info)
mcp_info["server_name"] = server_name
self.mcp_servers.append(
MCPSSEServer(
name=server_name,
url=server_config["url"],
mcp_info=mcp_info,
)
)
verbose_logger.debug(
f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
)

self.initialize_tool_name_to_mcp_server_name_mapping()

async def list_tools(self) -> List[MCPTool]:
"""
List all tools available across all MCP Servers.

Returns:
List[MCPTool]: Combined list of tools from all servers
"""
list_tools_result: List[MCPTool] = []
verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")

for server in self.mcp_servers:
tools = await self._get_tools_from_server(server)
list_tools_result.extend(tools)

return list_tools_result

async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
"""
Helper method to get tools from a single MCP server.

Args:
server (MCPSSEServer): The server to query tools from

Returns:
List[MCPTool]: List of tools available on the server
"""
verbose_logger.debug(f"Connecting to url: {server.url}")

async with sse_client(url=server.url) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()

tools_result = await session.list_tools()
verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

# Update tool to server mapping
for tool in tools_result.tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

return tools_result.tools

def initialize_tool_name_to_mcp_server_name_mapping(self):
"""
On startup, initialize the tool name to MCP server name mapping
"""
try:
if asyncio.get_running_loop():
asyncio.create_task(
self._initialize_tool_name_to_mcp_server_name_mapping()
)
except RuntimeError as e: # no running event loop
verbose_logger.exception(
f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
)

async def _initialize_tool_name_to_mcp_server_name_mapping(self):
"""
Call list_tools for each server and update the tool name to MCP server name mapping
"""
for server in self.mcp_servers:
tools = await self._get_tools_from_server(server)
for tool in tools:
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

async def call_tool(self, name: str, arguments: Dict[str, Any]):
"""
Call a tool with the given name and arguments
"""
mcp_server = self._get_mcp_server_from_tool_name(name)
if mcp_server is None:
raise ValueError(f"Tool {name} not found")
async with sse_client(url=mcp_server.url) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
return await session.call_tool(name, arguments)

def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
"""
Get the MCP Server from the tool name
"""
if tool_name in self.tool_name_to_mcp_server_name_mapping:
for server in self.mcp_servers:
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
return server
return None


global_mcp_server_manager: MCPServerManager = MCPServerManager()
Loading
Loading