diff --git a/camel/toolkits/__init__.py b/camel/toolkits/__init__.py index 11be755a2a..f085e61326 100644 --- a/camel/toolkits/__init__.py +++ b/camel/toolkits/__init__.py @@ -98,6 +98,7 @@ 'SymPyToolkit', 'MinerUToolkit', 'MCPToolkit', + 'MCPToolkitManager', 'AudioAnalysisToolkit', 'ExcelToolkit', 'VideoAnalysisToolkit', diff --git a/camel/toolkits/mcp_toolkit.py b/camel/toolkits/mcp_toolkit.py index af03a37a03..c8ece5301c 100644 --- a/camel/toolkits/mcp_toolkit.py +++ b/camel/toolkits/mcp_toolkit.py @@ -12,10 +12,13 @@ # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= import inspect +import json +import os from contextlib import AsyncExitStack, asynccontextmanager from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Callable, Dict, List, @@ -28,13 +31,16 @@ if TYPE_CHECKING: from mcp import ListToolsResult, Tool +from camel.logger import get_logger from camel.toolkits import BaseToolkit, FunctionTool +logger = get_logger(__name__) -class MCPToolkit(BaseToolkit): - r"""MCPToolkit provides an abstraction layer to interact with external - tools using the Model Context Protocol (MCP). It supports two modes of - connection: + +class _MCPServer(BaseToolkit): + r"""Internal class that provides an abstraction layer to interact with + external tools using the Model Context Protocol (MCP). It supports two + modes of connection: 1. stdio mode: Connects via standard input/output streams for local command-line interactions. @@ -73,20 +79,20 @@ def __init__( self._exit_stack = AsyncExitStack() self._is_connected = False - @asynccontextmanager - async def connection(self): - r"""Async context manager for establishing and managing the connection - with the MCP server. Automatically selects SSE or stdio mode based - on the provided `command_or_url`. + async def connect(self): + r"""Explicitly connect to the MCP server. - Yields: - MCPToolkit: Instance with active connection ready for tool - interaction. + Returns: + _MCPServer: The connected server instance """ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client + if self._is_connected: + logger.warning("Server is already connected") + return self + try: if urlparse(self.command_or_url).scheme in ("http", "https"): ( @@ -113,12 +119,33 @@ async def connection(self): list_tools_result = await self.list_mcp_tools() self._mcp_tools = list_tools_result.tools self._is_connected = True - yield self + return self + except Exception as e: + # Ensure resources are cleaned up on connection failure + await self.disconnect() + logger.error(f"Failed to connect to MCP server: {e}") + async def disconnect(self): + r"""Explicitly disconnect from the MCP server.""" + self._is_connected = False + await self._exit_stack.aclose() + self._session = None + + @asynccontextmanager + async def connection(self): + r"""Async context manager for establishing and managing the connection + with the MCP server. Automatically selects SSE or stdio mode based + on the provided `command_or_url`. + + Yields: + _MCPServer: Instance with active connection ready for tool + interaction. + """ + try: + await self.connect() + yield self finally: - self._is_connected = False - await self._exit_stack.aclose() - self._session = None + await self.disconnect() async def list_mcp_tools(self) -> Union[str, "ListToolsResult"]: r"""Retrieves the list of available tools from the connected MCP @@ -188,34 +215,53 @@ async def dynamic_function(**kwargs): kwargs.keys() ) if missing_params: - raise ValueError( + logger.warning( f"Missing required parameters: {missing_params}" ) + return "Missing required parameters." - result: CallToolResult = await self._session.call_tool( - func_name, kwargs - ) + if not self._session: + logger.error( + "MCP Client is not connected. Call `connection()` first." + ) + return ( + "MCP Client is not connected. Call `connection()` first." + ) - if not result.content: + try: + result: CallToolResult = await self._session.call_tool( + func_name, kwargs + ) + except Exception as e: + logger.error(f"Failed to call MCP tool '{func_name}': {e!s}") + return f"Failed to call MCP tool '{func_name}': {e!s}" + + if not result.content or len(result.content) == 0: return "No data available for this request." # Handle different content types - content = result.content[0] - if content.type == "text": - return content.text - elif content.type == "image": - # Return image URL or data URI if available - if hasattr(content, "url") and content.url: - return f"Image available at: {content.url}" - return "Image content received (data URI not shown)" - elif content.type == "embedded_resource": - # Return resource information if available - if hasattr(content, "name") and content.name: - return f"Embedded resource: {content.name}" - return "Embedded resource received" - else: - msg = f"Received content of type '{content.type}'" - return f"{msg} which is not fully supported yet." + try: + content = result.content[0] + if content.type == "text": + return content.text + elif content.type == "image": + # Return image URL or data URI if available + if hasattr(content, "url") and content.url: + return f"Image available at: {content.url}" + return "Image content received (data URI not shown)" + elif content.type == "embedded_resource": + # Return resource information if available + if hasattr(content, "name") and content.name: + return f"Embedded resource: {content.name}" + return "Embedded resource received" + else: + msg = f"Received content of type '{content.type}'" + return f"{msg} which is not fully supported yet." + except (IndexError, AttributeError) as e: + logger.error( + f"Error processing content from MCP tool response: {e!s}" + ) + return "Error processing content from MCP tool response" dynamic_function.__name__ = func_name dynamic_function.__doc__ = func_desc @@ -236,6 +282,27 @@ async def dynamic_function(**kwargs): return dynamic_function + def _build_tool_schema(self, mcp_tool: "Tool") -> Dict[str, Any]: + input_schema = mcp_tool.inputSchema + properties = input_schema.get("properties", {}) + required = input_schema.get("required", []) + + parameters = { + "type": "object", + "properties": properties, + "required": required, + } + + return { + "type": "function", + "function": { + "name": mcp_tool.name, + "description": mcp_tool.description + or "No description provided.", + "parameters": parameters, + }, + } + def get_tools(self) -> List[FunctionTool]: r"""Returns a list of FunctionTool objects representing the functions in the toolkit. Each function is dynamically generated @@ -246,6 +313,197 @@ def get_tools(self) -> List[FunctionTool]: representing the functions in the toolkit. """ return [ - FunctionTool(self.generate_function_from_mcp_tool(mcp_tool)) + FunctionTool( + self.generate_function_from_mcp_tool(mcp_tool), + openai_tool_schema=self._build_tool_schema(mcp_tool), + ) for mcp_tool in self._mcp_tools ] + + +class MCPToolkit(BaseToolkit): + r"""MCPToolkit provides a unified interface for managing multiple + MCP server connections and their tools. + + This class handles the lifecycle of multiple MCP server connections and + offers a centralized configuration mechanism for both local and remote + MCP services. + + Args: + servers (Optional[List[_MCPServer]]): List of _MCPServer + instances to manage. + config_path (Optional[str]): Path to a JSON configuration file + defining MCP servers. + + Note: + Either `servers` or `config_path` must be provided. If both are + provided, servers from both sources will be combined. + + Attributes: + servers (List[_MCPServer]): List of _MCPServer instances being managed. + """ + + def __init__( + self, + servers: Optional[List[_MCPServer]] = None, + config_path: Optional[str] = None, + ): + super().__init__() + + if servers and config_path: + logger.warning( + "Both servers and config_path are provided. " + "Servers from both sources will be combined." + ) + + self.servers = servers or [] + + if config_path: + self.servers.extend(self._load_servers_from_config(config_path)) + + self._exit_stack = AsyncExitStack() + self._connected = False + + def _load_servers_from_config(self, config_path: str) -> List[_MCPServer]: + r"""Loads MCP server configurations from a JSON file. + + Args: + config_path (str): Path to the JSON configuration file. + + Returns: + List[_MCPServer]: List of configured _MCPServer instances. + """ + try: + with open(config_path, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except json.JSONDecodeError as e: + logger.warning( + f"Invalid JSON in config file '{config_path}': {e!s}" + ) + return [] + except FileNotFoundError: + logger.warning(f"Config file not found: '{config_path}'") + return [] + + all_servers = [] + + # Process local MCP servers + mcp_servers = data.get("mcpServers", {}) + if not isinstance(mcp_servers, dict): + logger.warning("'mcpServers' is not a dictionary, skipping...") + mcp_servers = {} + + for name, cfg in mcp_servers.items(): + if not isinstance(cfg, dict): + logger.warning( + f"Configuration for server '{name}' must be a dictionary" + ) + continue + + if "command" not in cfg: + logger.warning( + f"Missing required 'command' field for server '{name}'" + ) + continue + + server = _MCPServer( + command_or_url=cfg["command"], + args=cfg.get("args", []), + env={**os.environ, **cfg.get("env", {})}, + timeout=cfg.get("timeout", None), + ) + all_servers.append(server) + + # Process remote MCP web servers + mcp_web_servers = data.get("mcpWebServers", {}) + if not isinstance(mcp_web_servers, dict): + logger.warning("'mcpWebServers' is not a dictionary, skipping...") + mcp_web_servers = {} + + for name, cfg in mcp_web_servers.items(): + if not isinstance(cfg, dict): + logger.warning( + f"Configuration for web server '{name}' must" + "be a dictionary" + ) + continue + + if "url" not in cfg: + logger.warning( + f"Missing required 'url' field for web server '{name}'" + ) + continue + + server = _MCPServer( + command_or_url=cfg["url"], + timeout=cfg.get("timeout", None), + ) + all_servers.append(server) + + return all_servers + + async def connect(self): + r"""Explicitly connect to all MCP servers. + + Returns: + MCPToolkit: The connected toolkit instance + """ + if self._connected: + logger.warning("MCPToolkit is already connected") + return self + + self._exit_stack = AsyncExitStack() + try: + # Sequentially connect to each server + for server in self.servers: + await server.connect() + self._connected = True + return self + except Exception as e: + # Ensure resources are cleaned up on connection failure + await self.disconnect() + logger.error(f"Failed to connect to one or more MCP servers: {e}") + + async def disconnect(self): + r"""Explicitly disconnect from all MCP servers.""" + if not self._connected: + return + + for server in self.servers: + await server.disconnect() + self._connected = False + await self._exit_stack.aclose() + + @asynccontextmanager + async def connection(self) -> AsyncGenerator["MCPToolkit", None]: + r"""Async context manager that simultaneously establishes connections + to all managed MCP server instances. + + Yields: + MCPToolkit: Self with all servers connected. + """ + try: + await self.connect() + yield self + finally: + await self.disconnect() + + def is_connected(self) -> bool: + r"""Checks if all the managed servers are connected. + + Returns: + bool: True if connected, False otherwise. + """ + return self._connected + + def get_tools(self) -> List[FunctionTool]: + r"""Aggregates all tools from the managed MCP server instances. + + Returns: + List[FunctionTool]: Combined list of all available function tools. + """ + all_tools = [] + for server in self.servers: + all_tools.extend(server.get_tools()) + return all_tools diff --git a/examples/toolkits/mcp/mcp_server.py b/examples/toolkits/mcp/mcp_server.py deleted file mode 100755 index 7f02d5741f..0000000000 --- a/examples/toolkits/mcp/mcp_server.py +++ /dev/null @@ -1,136 +0,0 @@ -# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= -from typing import Any - -import httpx -from mcp.server.fastmcp import FastMCP - -mcp = FastMCP("weather") - -NWS_API_BASE = "https://api.weather.gov" -USER_AGENT = "weather-app/1.0" - - -async def make_nws_request(url: str) -> dict[str, Any] | None: - r"""Make a request to the NWS API with proper error handling.""" - headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"} - async with httpx.AsyncClient() as client: - try: - response = await client.get(url, headers=headers, timeout=30.0) - response.raise_for_status() - return response.json() - except Exception: - return None - - -def format_alert(feature: dict) -> str: - r"""Format an alert feature into a readable string.""" - props = feature["properties"] - return f""" -Event: {props.get('event', 'Unknown')} -Area: {props.get('areaDesc', 'Unknown')} -Severity: {props.get('severity', 'Unknown')} -Description: {props.get('description', 'No description available')} -Instructions: {props.get('instruction', 'No specific instructions provided')} -""" - - -@mcp.tool() -async def get_alerts(state: str) -> str: - r"""Get weather alerts for a US state. - - Args: - state: Two-letter US state code (e.g. CA, NY) - """ - url = f"{NWS_API_BASE}/alerts/active/area/{state}" - data = await make_nws_request(url) - - if not data or "features" not in data: - return "Unable to fetch alerts or no alerts found." - - if not data["features"]: - return "No active alerts for this state." - - alerts = [format_alert(feature) for feature in data["features"]] - return "\n---\n".join(alerts) - - -@mcp.tool() -async def get_forecast(latitude: float, longitude: float) -> str: - r"""Get weather forecast for a location. - - Args: - latitude: Latitude of the location - longitude: Longitude of the location - """ - # First get the forecast grid endpoint - points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}" - points_data = await make_nws_request(points_url) - - if not points_data: - return "Unable to fetch forecast data for this location." - - # Get the forecast URL from the points response - forecast_url = points_data["properties"]["forecast"] - forecast_data = await make_nws_request(forecast_url) - - if not forecast_data: - return "Unable to fetch detailed forecast." - - # Format the periods into a readable forecast - periods = forecast_data["properties"]["periods"] - forecasts = [] - for period in periods[:5]: # Only show next 5 periods - forecast = f""" -{period['name']}: -Temperature: {period['temperature']}°{period['temperatureUnit']} -Wind: {period['windSpeed']} {period['windDirection']} -Forecast: {period['detailedForecast']} -""" - forecasts.append(forecast) - - return "\n---\n".join(forecasts) - - -def main(transport: str = "stdio"): - r"""Weather MCP Server - - This server provides weather-related functionalities implemented via the Model Context Protocol (MCP). - It demonstrates how to establish interactions between AI models and external tools using MCP. - - The server supports two modes of operation: - - 1. stdio mode (default): - - - Communicates with clients via standard input/output streams, ideal for local command-line usage. - - - Example usage: python mcp_server.py [--transport stdio] - - 2. SSE mode (Server-Sent Events): - - - Communicates with clients over HTTP using server-sent events, suitable for persistent network connections. - - - Runs by default at http://127.0.0.1:8000. - - - Example usage: python mcp_server.py --transport sse - """ # noqa: E501 - if transport == 'stdio': - mcp.run(transport='stdio') - elif transport == 'sse': - mcp.run(transport='sse') - - -if __name__ == "__main__": - # Hardcoded to use stdio transport mode - main("stdio") diff --git a/examples/toolkits/mcp/mcp_servers_config.json b/examples/toolkits/mcp/mcp_servers_config.json new file mode 100644 index 0000000000..d38d2842a1 --- /dev/null +++ b/examples/toolkits/mcp/mcp_servers_config.json @@ -0,0 +1,13 @@ +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem@2025.1.14", + "." + ] + } + }, + "mcpWebServers": {} +} \ No newline at end of file diff --git a/examples/toolkits/mcp/mcp_toolkit.py b/examples/toolkits/mcp/mcp_toolkit.py index 1072268220..bd0f5203d2 100644 --- a/examples/toolkits/mcp/mcp_toolkit.py +++ b/examples/toolkits/mcp/mcp_toolkit.py @@ -11,98 +11,70 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +"""MCP Server Example + +This example demonstrates how to use the MCP (Managed Code Processing) server +with CAMEL agents for file operations. + +Setup: +1. Install Node.js and npm + +2. Install MCP filesystem server globally: + ```bash + npm install -g @modelcontextprotocol/server-filesystem + ``` + +Usage: +1. Run this script to start an MCP filesystem server +2. The server will only operate within the specified directory +3. All paths in responses will be relative to maintain privacy +""" + import asyncio -import sys from pathlib import Path from camel.agents import ChatAgent -from camel.configs.openai_config import ChatGPTConfig from camel.models import ModelFactory from camel.toolkits import MCPToolkit from camel.types import ModelPlatformType, ModelType -async def main(server_transport: str = "stdio"): - if server_transport == "stdio": - mcp_toolkit = MCPToolkit( - command_or_url=sys.executable, - args=[str(Path(__file__).parent / "mcp_server.py")], - ) - else: - # SSE mode, Must run the server first - mcp_toolkit = MCPToolkit("http://127.0.0.1:8000/sse") - - async with mcp_toolkit.connection() as toolkit: - tools = toolkit.get_tools() +async def main(): + config_path = Path(__file__).parent / "mcp_servers_config.json" + mcp_toolkit = MCPToolkit(config_path=str(config_path)) - # Define system message - sys_msg = "You are a helpful assistant" - model_config_dict = ChatGPTConfig( - temperature=0.0, - ).as_dict() + # Connect to all MCP servers. + await mcp_toolkit.connect() - model = ModelFactory.create( - model_platform=ModelPlatformType.DEFAULT, - model_type=ModelType.DEFAULT, - model_config_dict=model_config_dict, - ) + sys_msg = "You are a helpful assistant" + model = ModelFactory.create( + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, + ) + camel_agent = ChatAgent( + system_message=sys_msg, + model=model, + tools=[*mcp_toolkit.get_tools()], + ) + user_msg = "List 5 files in the project, using relative paths" + response = await camel_agent.astep(user_msg) + print(response.msgs[0].content) + print(response.info['tool_calls']) - # Set agent - camel_agent = ChatAgent( - system_message=sys_msg, - model=model, - tools=tools, - ) - camel_agent.reset() - - # Define a user message - usr_msg = "How is the weather in Chicago today?" - - # Get response information - response = await camel_agent.astep(usr_msg) - print(str(response.info['tool_calls'])) - """ - ======================================================================= - [ToolCallingRecord(tool_name='get_forecast', args={'latitude': 41. - 8781, 'longitude': -87.6298}, result='\nThis Afternoon:\nTemperature: - 65°F\nWind: 15 mph SW\nForecast: Sunny, with a high near 65. Southwest - wind around 15 mph, with gusts as high as 25 mph. - \n\n---\n\nTonight:\nTemperature: 45°F\nWind: 10 to 15 mph - SW\nForecast: Mostly clear. Low around 45, with temperatures rising to - around 51 overnight. Southwest wind 10 to 15 mph, with gusts as high - as 30 mph.\n\n---\n\nTuesday:\nTemperature: 46°F\nWind: 10 to 20 mph - NW\nForecast: Mostly sunny. High near 46, with temperatures falling - to around 36 in the afternoon. Northwest wind 10 to 20 mph, with - gusts as high as 35 mph.\n\n---\n\nTuesday Night:\nTemperature: 36° - F\nWind: 10 to 20 mph ENE\nForecast: Partly cloudy, with a low around - 36. East northeast wind 10 to 20 mph. - \n\n---\n\nWednesday:\nTemperature: 46°F\nWind: 5 to 15 mph - E\nForecast: Mostly sunny, with a high near 46. East wind 5 to 15 mph. - \n', tool_call_id='call_DJjGYAzqlzb5ojirRAuKZmtk')] - ======================================================================= - """ - - usr_msg = "Please get the latest 3 weather alerts for California." - - # Get response information - response = await camel_agent.astep(usr_msg) - print(str(response.info['tool_calls'])) - """ - ======================================================================= - [ToolCallingRecord(tool_name='get_alerts', args={'state': 'CA'}, result - ='\nEvent: Wind Advisory\nArea: Central Siskiyou County\nSeverity: - Moderate\nDescription: * WHAT...South winds 20 to 30 mph with gusts up - to 50 mph expected.\n\n* WHERE...Portions of central Siskiyou County. - This includes\nInterstate 5 from Weed to Grenada and portions of - Highway 97.\n\n* WHEN...Until 5 PM PDT this afternoon.\n\n* IMPACTS... - Gusty winds will blow around unsecured objects. Tree\nlimbs could be - blown down and a few power outages may result.\nInstructions: Winds - this strong can make driving difficult, especially for high\nprofile - vehicles. Use extra caution.\n\nSecure outdoor objects.\n', - tool_call_id='call_JRDYuTjOjYrymXeFiWxcHZ5d')] - ======================================================================= - """ + # Disconnect from all MCP servers and clean up resources. + await mcp_toolkit.disconnect() if __name__ == "__main__": asyncio.run(main()) +''' +=============================================================================== +Here are 5 files in the project using relative paths: + +1. `.env` +2. `.gitignore` +3. `.pre-commit-config.yaml` +4. `CONTRIBUTING.md` +5. `README.md` +=============================================================================== +''' diff --git a/test/toolkits/test_mcp_toolkit.py b/test/toolkits/test_mcp_toolkit.py index 0424a2c550..025d76d972 100644 --- a/test/toolkits/test_mcp_toolkit.py +++ b/test/toolkits/test_mcp_toolkit.py @@ -11,270 +11,732 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +import json +import tempfile +from contextlib import AsyncExitStack +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest -from camel.toolkits.mcp_toolkit import MCPToolkit - - -@pytest.mark.asyncio -async def test_init(): - r"""Test initialization of MCPToolkit.""" - # Test with default parameters - toolkit = MCPToolkit("test_command") - assert toolkit.command_or_url == "test_command" - assert toolkit.args == [] - assert toolkit.env == {} - assert toolkit._mcp_tools == [] - assert toolkit._session is None - assert toolkit._is_connected is False - - # Test with custom parameters - toolkit = MCPToolkit( - "test_url", - args=["--arg1", "--arg2"], - env={"ENV_VAR": "value"}, - timeout=30, - ) - assert toolkit.command_or_url == "test_url" - assert toolkit.args == ["--arg1", "--arg2"] - assert toolkit.env == {"ENV_VAR": "value"} - assert toolkit._mcp_tools == [] - assert toolkit._session is None - assert toolkit._is_connected is False - - -@pytest.mark.asyncio -async def test_connection_http(): - r"""Test connection with HTTP URL.""" - with ( - patch("mcp.client.sse.sse_client") as mock_sse_client, - patch("mcp.client.session.ClientSession") as mock_session, - ): - # Setup mocks - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_sse_client.return_value.__aenter__.return_value = ( - mock_read_stream, - mock_write_stream, +from camel.toolkits.mcp_toolkit import MCPToolkit, _MCPServer + + +class Test_MCPServer: + r"""Test _MCPServer class.""" + + @pytest.mark.asyncio + async def test_init(self): + r"""Test initialization of _MCPServer.""" + # Test with default parameters + server = _MCPServer("test_command") + assert server.command_or_url == "test_command" + assert server.args == [] + assert server.env == {} + assert server._mcp_tools == [] + assert server._session is None + assert server._is_connected is False + + # Test with custom parameters + server = _MCPServer( + "test_url", + args=["--arg1", "--arg2"], + env={"ENV_VAR": "value"}, + timeout=30, ) - - mock_session_instance = AsyncMock() - mock_session.return_value.__aenter__.return_value = ( - mock_session_instance + assert server.command_or_url == "test_url" + assert server.args == ["--arg1", "--arg2"] + assert server.env == {"ENV_VAR": "value"} + assert server._mcp_tools == [] + assert server._session is None + assert server._is_connected is False + + @pytest.mark.asyncio + async def test_connection_http(self): + r"""Test connection with HTTP URL.""" + with ( + patch("mcp.client.sse.sse_client") as mock_sse_client, + patch("mcp.client.session.ClientSession") as mock_session, + ): + # Setup mocks + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_sse_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_session_instance = AsyncMock() + mock_session.return_value.__aenter__.return_value = ( + mock_session_instance + ) + + # Mock list_tools result + list_tools_result = MagicMock() + list_tools_result.tools = ["tool1", "tool2"] + mock_session_instance.list_tools.return_value = list_tools_result + + # Test HTTP connection + server = _MCPServer("https://example.com/api") + async with server.connection() as connected_server: + assert connected_server._is_connected is True + assert connected_server._mcp_tools == ["tool1", "tool2"] + + # Verify mocks were called correctly + mock_sse_client.assert_called_once_with("https://example.com/api") + mock_session.assert_called_once() + mock_session_instance.initialize.assert_called_once() + mock_session_instance.list_tools.assert_called_once() + + @pytest.mark.asyncio + async def test_connection_stdio(self): + r"""Test connection with stdio command.""" + with ( + patch("mcp.client.stdio.stdio_client") as mock_stdio_client, + patch("mcp.client.session.ClientSession") as mock_session, + ): + # Setup mocks + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_stdio_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_session_instance = AsyncMock() + mock_session.return_value.__aenter__.return_value = ( + mock_session_instance + ) + + # Mock list_tools result + list_tools_result = MagicMock() + list_tools_result.tools = ["tool1", "tool2"] + mock_session_instance.list_tools.return_value = list_tools_result + + # Test stdio connection + server = _MCPServer( + "local_command", args=["--arg1"], env={"ENV_VAR": "value"} + ) + async with server.connection() as connected_server: + assert connected_server._is_connected is True + assert connected_server._mcp_tools == ["tool1", "tool2"] + + # Verify mocks were called correctly + mock_stdio_client.assert_called_once() + mock_session.assert_called_once() + mock_session_instance.initialize.assert_called_once() + mock_session_instance.list_tools.assert_called_once() + + @pytest.mark.asyncio + async def test_list_mcp_tools_not_connected(self): + r"""Test list_mcp_tools when not connected.""" + server = _MCPServer("test_command") + result = await server.list_mcp_tools() + assert isinstance(result, str) + assert "not connected" in result + + @pytest.mark.asyncio + async def test_list_mcp_tools_connected(self): + r"""Test list_mcp_tools when connected.""" + server = _MCPServer("test_command") + server._session = AsyncMock() + + # Mock successful response + mock_result = MagicMock() + server._session.list_tools.return_value = mock_result + + result = await server.list_mcp_tools() + assert result == mock_result + server._session.list_tools.assert_called_once() + + # Mock exception + server._session.list_tools.side_effect = Exception("Test error") + result = await server.list_mcp_tools() + assert isinstance(result, str) + assert "Failed to list MCP tools" in result + + @pytest.mark.asyncio + async def test_generate_function_from_mcp_tool(self): + r"""Test generate_function_from_mcp_tool.""" + server = _MCPServer("test_command") + server._session = AsyncMock() + + # Create mock MCP tool + mock_tool = MagicMock() + mock_tool.name = "test_function" + mock_tool.description = "Test function description" + mock_tool.inputSchema = { + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + "param3": {"type": "boolean"}, + }, + "required": ["param1", "param2"], + } + + # Generate function + func = server.generate_function_from_mcp_tool(mock_tool) + + # Check function attributes + assert func.__name__ == "test_function" + assert func.__doc__ == "Test function description" + assert "param1" in func.__annotations__ + assert "param2" in func.__annotations__ + assert "param3" in func.__annotations__ + + # Mock call_tool response + mock_content = MagicMock() + mock_content.type = "text" + mock_content.text = "Test result" + + mock_result = MagicMock() + mock_result.content = [mock_content] + server._session.call_tool.return_value = mock_result + + # Test function call + result = await func(param1="test", param2=123) + assert result == "Test result" + server._session.call_tool.assert_called_once_with( + "test_function", {"param1": "test", "param2": 123} ) - # Mock list_tools result - list_tools_result = MagicMock() - list_tools_result.tools = ["tool1", "tool2"] - mock_session_instance.list_tools.return_value = list_tools_result + # Test missing required parameter - now returns a message + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + mock_logger.reset_mock() + + result = await func(param1="test") + assert result == "Missing required parameters." + mock_logger.warning.assert_called_once() + + # Test different content types + # Image content + mock_content.type = "image" + mock_content.url = "https://example.com/image.jpg" + result = await func(param1="test", param2=123) + assert "Image available at" in result + + # Image without URL + mock_content.url = None + result = await func(param1="test", param2=123) + assert "Image content received" in result + + # Embedded resource + mock_content.type = "embedded_resource" + mock_content.name = "resource.pdf" + result = await func(param1="test", param2=123) + assert "Embedded resource: resource.pdf" in result + + # Embedded resource without name + mock_content.name = None + result = await func(param1="test", param2=123) + assert "Embedded resource received" in result + + # Unknown content type + mock_content.type = "unknown" + result = await func(param1="test", param2=123) + assert "not fully supported" in result + + # No content + mock_result.content = [] + result = await func(param1="test", param2=123) + assert "No data available" in result + + @pytest.mark.asyncio + async def test_build_tool_schema(self): + r"""Test build_tool_schema method.""" + server = _MCPServer("test_command") + mock_tool = MagicMock() + mock_tool.name = "test_function" + mock_tool.description = "Test function description" + mock_tool.inputSchema = { + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + }, + "required": ["param1", "param2"], + } + schema = server._build_tool_schema(mock_tool) + + target_schema = { + "type": "function", + "function": { + "name": "test_function", + "description": "Test function description", + "parameters": { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + }, + "required": ["param1", "param2"], + }, + }, + } + assert schema == target_schema + + # No description + mock_tool.description = None + schema = server._build_tool_schema(mock_tool) + assert schema == { + "type": "function", + "function": { + "name": "test_function", + "description": "No description provided.", + "parameters": { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + }, + "required": ["param1", "param2"], + }, + }, + } + + @pytest.mark.asyncio + async def test_get_tools(self): + r"""Test get_tools method for _MCPServer.""" + with patch( + "camel.toolkits.mcp_toolkit.FunctionTool" + ) as mock_function_tool: + server = _MCPServer("test_command") + + # Mock tools + mock_tool1 = MagicMock() + mock_tool2 = MagicMock() + server._mcp_tools = [mock_tool1, mock_tool2] + + # Mock generate_function_from_mcp_tool + mock_func1 = AsyncMock() + mock_func2 = AsyncMock() + server.generate_function_from_mcp_tool = MagicMock( + side_effect=[mock_func1, mock_func2] + ) + + # Mock FunctionTool + mock_function_tool_instance1 = MagicMock() + mock_function_tool_instance2 = MagicMock() + mock_function_tool.side_effect = [ + mock_function_tool_instance1, + mock_function_tool_instance2, + ] + + # Get tools + tools = server.get_tools() + + # Verify results + assert len(tools) == 2 + assert tools[0] == mock_function_tool_instance1 + assert tools[1] == mock_function_tool_instance2 + + # Verify mocks were called correctly + server.generate_function_from_mcp_tool.assert_any_call(mock_tool1) + server.generate_function_from_mcp_tool.assert_any_call(mock_tool2) + + @pytest.mark.asyncio + async def test_connect_explicit(self): + r"""Test explicit connect method.""" + with ( + patch("mcp.client.sse.sse_client") as mock_sse_client, + patch("mcp.client.session.ClientSession") as mock_session, + ): + # Setup mocks + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_sse_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_session_instance = AsyncMock() + mock_session.return_value.__aenter__.return_value = ( + mock_session_instance + ) + + # Mock list_tools result + list_tools_result = MagicMock() + list_tools_result.tools = ["tool1", "tool2"] + mock_session_instance.list_tools.return_value = list_tools_result + + # Test HTTP connection + server = _MCPServer("https://example.com/api") + result = await server.connect() + + # Verify results + assert result == server + assert server._is_connected is True + assert server._mcp_tools == ["tool1", "tool2"] + assert server._session is not None + + # Verify mocks were called correctly + mock_sse_client.assert_called_once_with("https://example.com/api") + mock_session.assert_called_once() + mock_session_instance.initialize.assert_called_once() + mock_session_instance.list_tools.assert_called_once() + + # Test connecting when already connected + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + result = await server.connect() + assert result == server + mock_logger.warning.assert_called_once() + # Verify no new connections were made + assert mock_sse_client.call_count == 1 + + @pytest.mark.asyncio + async def test_connect_failure(self): + r"""Test connect method with failure.""" + with patch("mcp.client.sse.sse_client") as mock_sse_client: + # Setup mock to raise exception + mock_sse_client.return_value.__aenter__.side_effect = Exception( + "Connection error" + ) + + # Create server + server = _MCPServer("https://example.com/api") + + # Mock disconnect to verify it's called on failure + server.disconnect = AsyncMock() + + # Test connect with failure + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + await server.connect() + + # Verify disconnect was called to clean up + server.disconnect.assert_called_once() + mock_logger.error.assert_called_once() + + # Verify server is not connected + assert server._is_connected is False + + @pytest.mark.asyncio + async def test_disconnect_explicit(self): + r"""Test explicit disconnect method.""" + # Create server + server = _MCPServer("test_command") + + # Setup connected state + server._is_connected = True + server._exit_stack = AsyncMock() + server._exit_stack.aclose = AsyncMock() + server._session = MagicMock() + + # Test disconnect + await server.disconnect() - # Test HTTP connection - toolkit = MCPToolkit("https://example.com/api") - async with toolkit.connection() as connected_toolkit: - assert connected_toolkit._is_connected is True - assert connected_toolkit._mcp_tools == ["tool1", "tool2"] - - # Verify mocks were called correctly - mock_sse_client.assert_called_once_with("https://example.com/api") - mock_session.assert_called_once() - mock_session_instance.initialize.assert_called_once() - mock_session_instance.list_tools.assert_called_once() - - -@pytest.mark.asyncio -async def test_connection_stdio(): - r"""Test connection with stdio command.""" - with ( - patch("mcp.client.stdio.stdio_client") as mock_stdio_client, - patch("mcp.client.session.ClientSession") as mock_session, - ): - # Setup mocks - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_stdio_client.return_value.__aenter__.return_value = ( - mock_read_stream, - mock_write_stream, - ) - - mock_session_instance = AsyncMock() - mock_session.return_value.__aenter__.return_value = ( - mock_session_instance - ) - - # Mock list_tools result - list_tools_result = MagicMock() - list_tools_result.tools = ["tool1", "tool2"] - mock_session_instance.list_tools.return_value = list_tools_result + # Verify results + assert server._is_connected is False + assert server._session is None + server._exit_stack.aclose.assert_called_once() + + # Test disconnecting when not connected + server._exit_stack.aclose.reset_mock() + + # Set up disconnected state + server._is_connected = False + server._exit_stack = AsyncMock() + server._exit_stack.aclose = AsyncMock() + + await server.disconnect() + + # Verify exit stack is still closed even when not connected + server._exit_stack.aclose.assert_called_once() + + +class TestMCPToolkit: + r"""Test MCPToolkit class.""" + + def test_init(self): + r"""Test initialization of MCPToolkit.""" + # Test with servers list + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + toolkit = MCPToolkit(servers=[server1, server2]) + + assert toolkit.servers == [server1, server2] + assert isinstance(toolkit._exit_stack, AsyncExitStack) + assert toolkit._connected is False + + # Test with both servers and config_path + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + with patch.object( + MCPToolkit, "_load_servers_from_config", return_value=[] + ): + toolkit = MCPToolkit( + servers=[server1], config_path="dummy_path" + ) + assert toolkit.servers == [server1] + mock_logger.warning.assert_called_once() + + def test_init_config_file_not_found(self): + r"""Test from_config with non-existent file.""" + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + with tempfile.TemporaryDirectory() as temp_dir: + non_existent_path = Path(temp_dir) / "non_existent.json" + toolkit = MCPToolkit(config_path=str(non_existent_path)) + + # Should return an empty toolkit. + assert toolkit.servers == [] + mock_logger.warning.assert_called_once() + + def test_init_config_invalid_json(self): + r"""Test from_config with invalid JSON.""" + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "invalid.json" + config_path.write_text("{invalid json") + + toolkit = MCPToolkit(config_path=str(config_path)) + + # Should return an empty toolkit + assert toolkit.servers == [] + mock_logger.warning.assert_called_once() + + def test_init_config_valid(self): + r"""Test from_config with valid configuration.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "valid.json" + config_data = { + "mcpServers": { + "server1": { + "command": "test-command", + "args": ["--arg1"], + "env": {"TEST_ENV": "value"}, + } + }, + "mcpWebServers": {"server2": {"url": "https://test.com/sse"}}, + } + config_path.write_text(json.dumps(config_data)) + + toolkit = MCPToolkit(config_path=str(config_path)) + assert len(toolkit.servers) == 2 + + # Check local server toolkit + assert toolkit.servers[0].command_or_url == "test-command" + assert toolkit.servers[0].args == ["--arg1"] + assert "TEST_ENV" in toolkit.servers[0].env + + # Check web server toolkit + assert toolkit.servers[1].command_or_url == "https://test.com/sse" + + def test_load_servers_from_config_missing_required_fields(self): + r"""Test _load_servers_from_config with missing required fields.""" + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "invalid_fields.json" + + # Missing command field + config_data = {"mcpServers": {"server1": {"args": ["--arg1"]}}} + config_path.write_text(json.dumps(config_data)) + + mcp_toolkit = MCPToolkit() + servers = mcp_toolkit._load_servers_from_config( + str(config_path) + ) + # Should return an empty list and log a warning + assert servers == [] + mock_logger.warning.assert_called() + + mock_logger.reset_mock() + + # Missing url field + config_data = {"mcpWebServers": {"server1": {"timeout": 30}}} + config_path.write_text(json.dumps(config_data)) + + servers = mcp_toolkit._load_servers_from_config( + str(config_path) + ) + # Should return an empty list and log a warning + assert servers == [] + mock_logger.warning.assert_called() + + def test_load_servers_from_config_invalid_structure(self): + r"""Test _load_servers_from_config with invalid structure.""" + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "invalid_structure.json" + + # mcpServers is not a dictionary + config_data = {"mcpServers": "not a dictionary"} + config_path.write_text(json.dumps(config_data)) + + mcp_toolkit = MCPToolkit() + servers = mcp_toolkit._load_servers_from_config( + str(config_path) + ) + # Should return an empty list and log a warning + assert servers == [] + mock_logger.warning.assert_called_with( + "'mcpServers' is not a dictionary, skipping..." + ) + + mock_logger.reset_mock() + + # mcpWebServers is not a dictionary + config_data = {"mcpWebServers": "not a dictionary"} + config_path.write_text(json.dumps(config_data)) + + servers = mcp_toolkit._load_servers_from_config( + str(config_path) + ) + # Should return an empty list and log a warning + assert servers == [] + mock_logger.warning.assert_called_with( + "'mcpWebServers' is not a dictionary, skipping..." + ) + + @pytest.mark.asyncio + async def test_connection(self): + r"""Test connection context manager.""" + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + toolkit = MCPToolkit(servers=[server1, server2]) + + # Mock the connection context managers of both servers + class MockAsyncContextManager: + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + server1.connection = MagicMock(return_value=MockAsyncContextManager()) + server2.connection = MagicMock(return_value=MockAsyncContextManager()) - # Test stdio connection - toolkit = MCPToolkit( - "local_command", args=["--arg1"], env={"ENV_VAR": "value"} - ) async with toolkit.connection() as connected_toolkit: - assert connected_toolkit._is_connected is True - assert connected_toolkit._mcp_tools == ["tool1", "tool2"] - - # Verify mocks were called correctly - mock_stdio_client.assert_called_once() - mock_session.assert_called_once() - mock_session_instance.initialize.assert_called_once() - mock_session_instance.list_tools.assert_called_once() - - -@pytest.mark.asyncio -async def test_list_mcp_tools_not_connected(): - r"""Test list_mcp_tools when not connected.""" - toolkit = MCPToolkit("test_command") - result = await toolkit.list_mcp_tools() - assert isinstance(result, str) - assert "not connected" in result - - -@pytest.mark.asyncio -async def test_list_mcp_tools_connected(): - r"""Test list_mcp_tools when connected.""" - toolkit = MCPToolkit("test_command") - toolkit._session = AsyncMock() - - # Mock successful response - mock_result = MagicMock() - toolkit._session.list_tools.return_value = mock_result - - result = await toolkit.list_mcp_tools() - assert result == mock_result - toolkit._session.list_tools.assert_called_once() - - # Mock exception - toolkit._session.list_tools.side_effect = Exception("Test error") - result = await toolkit.list_mcp_tools() - assert isinstance(result, str) - assert "Failed to list MCP tools" in result - - -@pytest.mark.asyncio -async def test_generate_function_from_mcp_tool(): - r"""Test generate_function_from_mcp_tool.""" - toolkit = MCPToolkit("test_command") - toolkit._session = AsyncMock() - - # Create mock MCP tool - mock_tool = MagicMock() - mock_tool.name = "test_function" - mock_tool.description = "Test function description" - mock_tool.inputSchema = { - "properties": { - "param1": {"type": "string"}, - "param2": {"type": "integer"}, - "param3": {"type": "boolean"}, - }, - "required": ["param1", "param2"], - } - - # Generate function - func = toolkit.generate_function_from_mcp_tool(mock_tool) - - # Check function attributes - assert func.__name__ == "test_function" - assert func.__doc__ == "Test function description" - assert "param1" in func.__annotations__ - assert "param2" in func.__annotations__ - assert "param3" in func.__annotations__ - - # Mock call_tool response - mock_content = MagicMock() - mock_content.type = "text" - mock_content.text = "Test result" - - mock_result = MagicMock() - mock_result.content = [mock_content] - toolkit._session.call_tool.return_value = mock_result - - # Test function call - result = await func(param1="test", param2=123) - assert result == "Test result" - toolkit._session.call_tool.assert_called_once_with( - "test_function", {"param1": "test", "param2": 123} - ) - - # Test missing required parameter - with pytest.raises(ValueError) as excinfo: - await func(param1="test") - assert "Missing required parameters" in str(excinfo.value) - - # Test different content types - # Image content - mock_content.type = "image" - mock_content.url = "https://example.com/image.jpg" - result = await func(param1="test", param2=123) - assert "Image available at" in result - - # Image without URL - mock_content.url = None - result = await func(param1="test", param2=123) - assert "Image content received" in result - - # Embedded resource - mock_content.type = "embedded_resource" - mock_content.name = "resource.pdf" - result = await func(param1="test", param2=123) - assert "Embedded resource: resource.pdf" in result - - # Embedded resource without name - mock_content.name = None - result = await func(param1="test", param2=123) - assert "Embedded resource received" in result - - # Unknown content type - mock_content.type = "unknown" - result = await func(param1="test", param2=123) - assert "not fully supported" in result - - # No content - mock_result.content = [] - result = await func(param1="test", param2=123) - assert "No data available" in result - - -@pytest.mark.asyncio -async def test_get_tools(): - r"""Test get_tools method.""" - with patch( - "camel.toolkits.mcp_toolkit.FunctionTool" - ) as mock_function_tool: - toolkit = MCPToolkit("test_command") - - # Mock tools + assert connected_toolkit._connected is True + assert isinstance(connected_toolkit._exit_stack, AsyncExitStack) + + assert toolkit._connected is False + assert isinstance(toolkit._exit_stack, AsyncExitStack) + + def test_is_connected(self): + r"""Test is_connected method.""" + toolkit = MCPToolkit(servers=[_MCPServer("test_command")]) + assert toolkit.is_connected() is False + toolkit._connected = True + assert toolkit.is_connected() is True + + @pytest.mark.asyncio + async def test_get_tools(self): + r"""Test get_tools method.""" + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + toolkit = MCPToolkit(servers=[server1, server2]) + + # Mock get_tools for both servers mock_tool1 = MagicMock() mock_tool2 = MagicMock() - toolkit._mcp_tools = [mock_tool1, mock_tool2] + mock_tool3 = MagicMock() - # Mock generate_function_from_mcp_tool - mock_func1 = AsyncMock() - mock_func2 = AsyncMock() - toolkit.generate_function_from_mcp_tool = MagicMock( - side_effect=[mock_func1, mock_func2] - ) + server1.get_tools = MagicMock(return_value=[mock_tool1, mock_tool2]) + server2.get_tools = MagicMock(return_value=[mock_tool3]) - # Mock FunctionTool - mock_function_tool_instance1 = MagicMock() - mock_function_tool_instance2 = MagicMock() - mock_function_tool.side_effect = [ - mock_function_tool_instance1, - mock_function_tool_instance2, - ] - - # Get tools tools = toolkit.get_tools() + assert len(tools) == 3 + assert tools == [mock_tool1, mock_tool2, mock_tool3] + server1.get_tools.assert_called_once() + server2.get_tools.assert_called_once() + + @pytest.mark.asyncio + async def test_connect(self): + r"""Test explicit connect method.""" + # Create mock servers + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + + # Mock connect methods + server1.connect = AsyncMock(return_value=server1) + server2.connect = AsyncMock(return_value=server2) + + # Create toolkit with mock servers + toolkit = MCPToolkit(servers=[server1, server2]) + + # Test connect + result = await toolkit.connect() + + # Verify results + assert result == toolkit + assert toolkit._connected is True + assert toolkit._exit_stack is not None + server1.connect.assert_called_once() + server2.connect.assert_called_once() + + # Test connecting when already connected + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + result = await toolkit.connect() + assert result == toolkit + mock_logger.warning.assert_called_once() + # Verify servers not connected again + assert server1.connect.call_count == 1 + assert server2.connect.call_count == 1 + + @pytest.mark.asyncio + async def test_connect_failure(self): + r"""Test connect method with failure.""" + # Create mock servers + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + + # First server connects successfully, second fails + server1.connect = AsyncMock(return_value=server1) + server2.connect = AsyncMock(side_effect=Exception("Connection error")) + + # Create toolkit with mock servers + toolkit = MCPToolkit(servers=[server1, server2]) + + # Mock disconnect to verify it's called on failure + toolkit.disconnect = AsyncMock() + + # Test connect with failure + with patch("camel.toolkits.mcp_toolkit.logger") as mock_logger: + await toolkit.connect() + + # Verify disconnect was called to clean up + toolkit.disconnect.assert_called_once() + mock_logger.error.assert_called_once() + + # Verify first server was connected + server1.connect.assert_called_once() + # Verify second server was attempted + server2.connect.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect(self): + r"""Test explicit disconnect method.""" + # Create mock servers + server1 = _MCPServer("test_command1") + server2 = _MCPServer("test_command2") + + # Mock disconnect methods + server1.disconnect = AsyncMock() + server2.disconnect = AsyncMock() + + # Create toolkit with mock servers + toolkit = MCPToolkit(servers=[server1, server2]) + + # Setup connected state + toolkit._connected = True + toolkit._exit_stack = AsyncMock() + toolkit._exit_stack.aclose = AsyncMock() + + # Test disconnect + await toolkit.disconnect() + # Verify results - assert len(tools) == 2 - assert tools[0] == mock_function_tool_instance1 - assert tools[1] == mock_function_tool_instance2 - - # Verify mocks were called correctly - toolkit.generate_function_from_mcp_tool.assert_any_call(mock_tool1) - toolkit.generate_function_from_mcp_tool.assert_any_call(mock_tool2) - mock_function_tool.assert_any_call(mock_func1) - mock_function_tool.assert_any_call(mock_func2) + assert toolkit._connected is False + server1.disconnect.assert_called_once() + server2.disconnect.assert_called_once() + toolkit._exit_stack.aclose.assert_called_once() + + # Test disconnecting when not connected + server1.disconnect.reset_mock() + server2.disconnect.reset_mock() + toolkit._exit_stack.aclose.reset_mock() + + await toolkit.disconnect() + + # Verify no actions taken when not connected + server1.disconnect.assert_not_called() + server2.disconnect.assert_not_called() + toolkit._exit_stack.aclose.assert_not_called()