Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MCPToolkitManager to manage multiple MCPToolkits. #1817

Merged
merged 10 commits into from
Mar 12, 2025
1 change: 1 addition & 0 deletions camel/toolkits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
'SymPyToolkit',
'MinerUToolkit',
'MCPToolkit',
'MCPToolkitManager',
'AudioAnalysisToolkit',
'ExcelToolkit',
'VideoAnalysisToolkit',
Expand Down
276 changes: 258 additions & 18 deletions camel/toolkits/mcp_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import inspect
import json
import os
from contextlib import AsyncExitStack, asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
List,
Expand All @@ -28,13 +31,16 @@
if TYPE_CHECKING:
from mcp import ListToolsResult, Tool

from camel.logger import get_logger
from camel.toolkits import BaseToolkit, FunctionTool

logger = get_logger(__name__)

class MCPToolkit(BaseToolkit):
r"""MCPToolkit provides an abstraction layer to interact with external
tools using the Model Context Protocol (MCP). It supports two modes of
connection:

class _MCPServer(BaseToolkit):
r"""Internal class that provides an abstraction layer to interact with
external tools using the Model Context Protocol (MCP). It supports two
modes of connection:

1. stdio mode: Connects via standard input/output streams for local
command-line interactions.
Expand Down Expand Up @@ -73,20 +79,20 @@ def __init__(
self._exit_stack = AsyncExitStack()
self._is_connected = False

@asynccontextmanager
async def connection(self):
r"""Async context manager for establishing and managing the connection
with the MCP server. Automatically selects SSE or stdio mode based
on the provided `command_or_url`.
async def connect(self):
r"""Explicitly connect to the MCP server.

Yields:
MCPToolkit: Instance with active connection ready for tool
interaction.
Returns:
_MCPServer: The connected server instance
"""
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client

if self._is_connected:
logger.warning("Server is already connected")
return self

try:
if urlparse(self.command_or_url).scheme in ("http", "https"):
(
Expand All @@ -113,12 +119,33 @@ async def connection(self):
list_tools_result = await self.list_mcp_tools()
self._mcp_tools = list_tools_result.tools
self._is_connected = True
yield self
return self
except Exception as e:
# Ensure resources are cleaned up on connection failure
await self.disconnect()
logger.error(f"Failed to connect to MCP server: {e}")

async def disconnect(self):
r"""Explicitly disconnect from the MCP server."""
self._is_connected = False
await self._exit_stack.aclose()
self._session = None

@asynccontextmanager
async def connection(self):
r"""Async context manager for establishing and managing the connection
with the MCP server. Automatically selects SSE or stdio mode based
on the provided `command_or_url`.

Yields:
_MCPServer: Instance with active connection ready for tool
interaction.
"""
try:
await self.connect()
yield self
finally:
self._is_connected = False
await self._exit_stack.aclose()
self._session = None
await self.disconnect()

async def list_mcp_tools(self) -> Union[str, "ListToolsResult"]:
r"""Retrieves the list of available tools from the connected MCP
Expand Down Expand Up @@ -188,9 +215,10 @@ async def dynamic_function(**kwargs):
kwargs.keys()
)
if missing_params:
raise ValueError(
logger.warning(
f"Missing required parameters: {missing_params}"
)
return "Missing required parameters."

result: CallToolResult = await self._session.call_tool(
func_name, kwargs
Expand Down Expand Up @@ -236,6 +264,27 @@ async def dynamic_function(**kwargs):

return dynamic_function

def _build_tool_schema(self, mcp_tool: "Tool") -> Dict[str, Any]:
input_schema = mcp_tool.inputSchema
properties = input_schema.get("properties", {})
required = input_schema.get("required", [])

parameters = {
"type": "object",
"properties": properties,
"required": required,
}

return {
"type": "function",
"function": {
"name": mcp_tool.name,
"description": mcp_tool.description
or "No description provided.",
"parameters": parameters,
},
}

def get_tools(self) -> List[FunctionTool]:
r"""Returns a list of FunctionTool objects representing the
functions in the toolkit. Each function is dynamically generated
Expand All @@ -246,6 +295,197 @@ def get_tools(self) -> List[FunctionTool]:
representing the functions in the toolkit.
"""
return [
FunctionTool(self.generate_function_from_mcp_tool(mcp_tool))
FunctionTool(
self.generate_function_from_mcp_tool(mcp_tool),
openai_tool_schema=self._build_tool_schema(mcp_tool),
)
for mcp_tool in self._mcp_tools
]


class MCPToolkit(BaseToolkit):
r"""MCPToolkit provides a unified interface for managing multiple
MCP server connections and their tools.

This class handles the lifecycle of multiple MCP server connections and
offers a centralized configuration mechanism for both local and remote
MCP services.

Args:
servers (Optional[List[_MCPServer]]): List of _MCPServer
instances to manage.
config_path (Optional[str]): Path to a JSON configuration file
defining MCP servers.

Note:
Either `servers` or `config_path` must be provided. If both are
provided, servers from both sources will be combined.

Attributes:
servers (List[_MCPServer]): List of _MCPServer instances being managed.
"""

def __init__(
self,
servers: Optional[List[_MCPServer]] = None,
config_path: Optional[str] = None,
):
super().__init__()

if servers and config_path:
logger.warning(
"Both servers and config_path are provided. "
"Servers from both sources will be combined."
)

self.servers = servers or []

if config_path:
self.servers.extend(self._load_servers_from_config(config_path))

self._exit_stack = AsyncExitStack()
self._connected = False

def _load_servers_from_config(self, config_path: str) -> List[_MCPServer]:
r"""Loads MCP server configurations from a JSON file.

Args:
config_path (str): Path to the JSON configuration file.

Returns:
List[_MCPServer]: List of configured _MCPServer instances.
"""
try:
with open(config_path, "r", encoding="utf-8") as f:
try:
data = json.load(f)
except json.JSONDecodeError as e:
logger.warning(
f"Invalid JSON in config file '{config_path}': {e!s}"
)
return []
except FileNotFoundError:
logger.warning(f"Config file not found: '{config_path}'")
return []

all_servers = []

# Process local MCP servers
mcp_servers = data.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning("'mcpServers' is not a dictionary, skipping...")
mcp_servers = {}

for name, cfg in mcp_servers.items():
if not isinstance(cfg, dict):
logger.warning(
f"Configuration for server '{name}' must be a dictionary"
)
continue

if "command" not in cfg:
logger.warning(
f"Missing required 'command' field for server '{name}'"
)
continue

server = _MCPServer(
command_or_url=cfg["command"],
args=cfg.get("args", []),
env={**os.environ, **cfg.get("env", {})},
timeout=cfg.get("timeout", None),
)
all_servers.append(server)

# Process remote MCP web servers
mcp_web_servers = data.get("mcpWebServers", {})
if not isinstance(mcp_web_servers, dict):
logger.warning("'mcpWebServers' is not a dictionary, skipping...")
mcp_web_servers = {}

for name, cfg in mcp_web_servers.items():
if not isinstance(cfg, dict):
logger.warning(
f"Configuration for web server '{name}' must"
"be a dictionary"
)
continue

if "url" not in cfg:
logger.warning(
f"Missing required 'url' field for web server '{name}'"
)
continue

server = _MCPServer(
command_or_url=cfg["url"],
timeout=cfg.get("timeout", None),
)
all_servers.append(server)

return all_servers

async def connect(self):
r"""Explicitly connect to all MCP servers.

Returns:
MCPToolkit: The connected toolkit instance
"""
if self._connected:
logger.warning("MCPToolkit is already connected")
return self

self._exit_stack = AsyncExitStack()
try:
# Sequentially connect to each server
for server in self.servers:
await server.connect()
self._connected = True
return self
except Exception as e:
# Ensure resources are cleaned up on connection failure
await self.disconnect()
logger.error(f"Failed to connect to one or more MCP servers: {e}")

async def disconnect(self):
r"""Explicitly disconnect from all MCP servers."""
if not self._connected:
return

for server in self.servers:
await server.disconnect()
self._connected = False
await self._exit_stack.aclose()

@asynccontextmanager
async def connection(self) -> AsyncGenerator["MCPToolkit", None]:
r"""Async context manager that simultaneously establishes connections
to all managed MCP server instances.

Yields:
MCPToolkit: Self with all servers connected.
"""
try:
await self.connect()
yield self
finally:
await self.disconnect()

def is_connected(self) -> bool:
r"""Checks if all the managed servers are connected.

Returns:
bool: True if connected, False otherwise.
"""
return self._connected

def get_tools(self) -> List[FunctionTool]:
r"""Aggregates all tools from the managed MCP server instances.

Returns:
List[FunctionTool]: Combined list of all available function tools.
"""
all_tools = []
for server in self.servers:
all_tools.extend(server.get_tools())
return all_tools
Loading