From 601563699ca533c4549a66101ecdc73f8c363013 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 25 Jan 2025 13:28:34 -0800 Subject: [PATCH 01/16] initial team manager refactor --- .../autogenstudio/teammanager.py | 73 --------- .../autogenstudio/teammanager/__init__.py | 6 + .../autogenstudio/teammanager/teammanager.py | 141 ++++++++++++++++++ 3 files changed, 147 insertions(+), 73 deletions(-) delete mode 100644 python/packages/autogen-studio/autogenstudio/teammanager.py create mode 100644 python/packages/autogen-studio/autogenstudio/teammanager/__init__.py create mode 100644 python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py diff --git a/python/packages/autogen-studio/autogenstudio/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager.py deleted file mode 100644 index b4ba92460ef3..000000000000 --- a/python/packages/autogen-studio/autogenstudio/teammanager.py +++ /dev/null @@ -1,73 +0,0 @@ -import time -from typing import AsyncGenerator, Callable, Optional, Union - -from autogen_agentchat.base import TaskResult -from autogen_agentchat.messages import AgentEvent, ChatMessage -from autogen_core import CancellationToken - -from .database import Component, ComponentFactory -from .datamodel import ComponentConfigInput, TeamResult - - -class TeamManager: - def __init__(self) -> None: - self.component_factory = ComponentFactory() - - async def _create_team(self, team_config: ComponentConfigInput, input_func: Optional[Callable] = None) -> Component: - """Create team instance with common setup logic""" - return await self.component_factory.load(team_config, input_func=input_func) - - def _create_result(self, task_result: TaskResult, start_time: float) -> TeamResult: - """Create TeamResult with timing info""" - return TeamResult(task_result=task_result, usage="", duration=time.time() - start_time) - - async def run_stream( - self, - task: str, - team_config: ComponentConfigInput, - input_func: Optional[Callable] = None, - cancellation_token: Optional[CancellationToken] = None, - ) -> AsyncGenerator[Union[AgentEvent | ChatMessage, ChatMessage, TaskResult], None]: - """Stream the team's execution results""" - start_time = time.time() - - try: - team = await self._create_team(team_config, input_func) - stream = team.run_stream(task=task, cancellation_token=cancellation_token) - - async for message in stream: - if cancellation_token and cancellation_token.is_cancelled(): - break - - if isinstance(message, TaskResult): - yield self._create_result(message, start_time) - else: - yield message - - # close agent resources - for agent in team._participants: - if hasattr(agent, "close"): - await agent.close() - - except Exception as e: - raise e - - async def run( - self, - task: str, - team_config: ComponentConfigInput, - input_func: Optional[Callable] = None, - cancellation_token: Optional[CancellationToken] = None, - ) -> TeamResult: - """Original non-streaming run method with optional cancellation""" - start_time = time.time() - - team = await self._create_team(team_config, input_func) - result = await team.run(task=task, cancellation_token=cancellation_token) - - # close agent resources - for agent in team._participants: - if hasattr(agent, "close"): - await agent.close() - - return self._create_result(result, start_time) diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/__init__.py b/python/packages/autogen-studio/autogenstudio/teammanager/__init__.py new file mode 100644 index 000000000000..74d5821674e7 --- /dev/null +++ b/python/packages/autogen-studio/autogenstudio/teammanager/__init__.py @@ -0,0 +1,6 @@ +from .teammanager import TeamManager + + +__all__ = [ + "TeamManager" +] \ No newline at end of file diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py new file mode 100644 index 000000000000..78a9b9bac3ba --- /dev/null +++ b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py @@ -0,0 +1,141 @@ +import json +import logging +import time +from pathlib import Path +from typing import AsyncGenerator, Callable, List, Optional, Union + +import aiofiles +import yaml +from autogen_agentchat.base import TaskResult +from autogen_agentchat.messages import AgentEvent, ChatMessage +from autogen_core import CancellationToken, Component, ComponentModel +from autogen_agentchat.base import Team +from ..datamodel.types import TeamResult + +logger = logging.getLogger(__name__) + + + +class TeamManager: + """Manages team operations including loading configs and running teams""" + + @staticmethod + async def load_from_file(path: Union[str, Path]) -> dict: + """Load team configuration from JSON/YAML file""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {path}") + + async with aiofiles.open(path) as f: + content = await f.read() + if path.suffix == ".json": + return json.loads(content) + elif path.suffix in (".yml", ".yaml"): + return yaml.safe_load(content) + raise ValueError(f"Unsupported file format: {path.suffix}") + + @staticmethod + async def load_from_directory(directory: Union[str, Path]) -> List[dict]: + """Load all team configurations from a directory""" + directory = Path(directory) + configs = [] + + for path in directory.glob("*.[json,yaml,yml]"): + try: + config = await TeamManager.load_from_file(path) + configs.append(config) + except Exception as e: + logger.error(f"Failed to load {path}: {e}") + + return configs + + async def _create_team( + self, + team_config: Union[str, Path, dict, ComponentModel], + input_func: Optional[Callable] = None + ) -> Component: + """Create team instance from config""" + # Handle different input types + if isinstance(team_config, (str, Path)): + config = await self.load_from_file(team_config) + elif isinstance(team_config, dict): + config = team_config + else: + config = team_config.model_dump() + + # Use Component.load_component directly + team = Team.load_component(config) + + # Set input function if provided + if input_func and hasattr(team, "set_input_func"): + team.set_input_func(input_func) + + return team + + async def run_stream( + self, + task: str, + team_config: Union[str, Path, dict, ComponentModel], + input_func: Optional[Callable] = None, + cancellation_token: Optional[CancellationToken] = None, + ) -> AsyncGenerator[Union[AgentEvent | ChatMessage, ChatMessage, TaskResult], None]: + """Stream team execution results""" + start_time = time.time() + team = None + + try: + team = await self._create_team(team_config, input_func) + + async for message in team.run_stream( + task=task, + cancellation_token=cancellation_token + ): + if cancellation_token and cancellation_token.is_cancelled(): + break + + if isinstance(message, TaskResult): + yield TeamResult( + task_result=message, + usage="", + duration=time.time() - start_time + ) + else: + yield message + + finally: + # Ensure cleanup happens + if team and hasattr(team, "_participants"): + for agent in team._participants: + if hasattr(agent, "close"): + await agent.close() + + async def run( + self, + task: str, + team_config: Union[str, Path, dict, ComponentModel], + input_func: Optional[Callable] = None, + cancellation_token: Optional[CancellationToken] = None, + ) -> TeamResult: + """Run team synchronously""" + start_time = time.time() + team = None + + try: + team = await self._create_team(team_config, input_func) + result = await team.run( + task=task, + cancellation_token=cancellation_token + ) + + return TeamResult( + task_result=result, + usage="", + duration=time.time() - start_time + ) + + finally: + # Ensure cleanup happens + if team and hasattr(team, "_participants"): + for agent in team._participants: + if hasattr(agent, "close"): + await agent.close() \ No newline at end of file From 71a893d261e79d989a81b231982279fcff6d36d2 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 25 Jan 2025 13:29:46 -0800 Subject: [PATCH 02/16] update team manager --- .../autogen-studio/autogenstudio/teammanager/teammanager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py index 78a9b9bac3ba..4641d5d1302b 100644 --- a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py +++ b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py @@ -66,10 +66,7 @@ async def _create_team( # Use Component.load_component directly team = Team.load_component(config) - # Set input function if provided - if input_func and hasattr(team, "set_input_func"): - team.set_input_func(input_func) - + # TBD - set input function return team async def run_stream( From e2c16ee44b312d6637a11d6fe9a48f5f3169d478 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 25 Jan 2025 15:46:43 -0800 Subject: [PATCH 03/16] move db import to db manager --- .../autogenstudio/database/db_manager.py | 238 ++++++++---------- .../autogenstudio/teammanager/teammanager.py | 27 +- 2 files changed, 118 insertions(+), 147 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/database/db_manager.py b/python/packages/autogen-studio/autogenstudio/database/db_manager.py index 2b9ec43e3bd5..d27764b8cd93 100644 --- a/python/packages/autogen-studio/autogenstudio/database/db_manager.py +++ b/python/packages/autogen-studio/autogenstudio/database/db_manager.py @@ -1,13 +1,14 @@ import threading from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Optional, Union from loguru import logger from sqlalchemy import exc, func, inspect, text from sqlmodel import Session, SQLModel, and_, create_engine, select -from ..datamodel import LinkTypes, Response +from ..datamodel import LinkTypes, Response, Team +from ..teammanager import TeamManager from .schema_manager import SchemaManager # from .dbutils import init_db_samples @@ -247,152 +248,113 @@ def delete(self, model_class: SQLModel, filters: dict = None): return Response(message=status_message, status=status, data=None) - def link( + + async def import_team( self, - link_type: LinkTypes, - primary_id: int, - secondary_id: int, - sequence: Optional[int] = None, - ): - """Link two entities with automatic sequence handling.""" - with Session(self.engine) as session: - try: - # Get classes from LinkTypes - primary_class = link_type.primary_class - secondary_class = link_type.secondary_class - link_table = link_type.link_table - - # Get entities - primary_entity = session.get(primary_class, primary_id) - secondary_entity = session.get(secondary_class, secondary_id) - - if not primary_entity or not secondary_entity: - return Response(message="One or both entities do not exist", status=False) - - # Get field names - primary_id_field = f"{primary_class.__name__.lower()}_id" - secondary_id_field = f"{secondary_class.__name__.lower()}_id" - - # Check for existing link - existing_link = session.exec( - select(link_table).where( - and_( - getattr(link_table, primary_id_field) == primary_id, - getattr(link_table, secondary_id_field) == secondary_id, - ) + team_config: Union[str, Path, dict], + user_id: str, + check_exists: bool = False + ) -> Response: + try: + # Load config if path provided + if isinstance(team_config, (str, Path)): + config = await TeamManager.load_from_file(team_config) + else: + config = team_config + + # Check existence if requested + if check_exists: + existing = await self._check_team_exists(config, user_id) + if existing: + return Response( + message="Identical team configuration already exists", + status=True, + data={"id": existing.id} ) - ).first() - - if existing_link: - return Response(message="Link already exists", status=False) - - # Get the next sequence number if not provided - if sequence is None: - max_seq_result = session.exec( - select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id) - ).first() - sequence = 0 if max_seq_result is None else max_seq_result + 1 - - # Create new link - new_link = link_table( - **{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence} - ) - session.add(new_link) - session.commit() - return Response(message=f"Entities linked successfully with sequence {sequence}", status=True) + # Store in database + team_db = Team( + user_id=user_id, + config=config + ) + + result = self.upsert(team_db) + return result - except Exception as e: - session.rollback() - return Response(message=f"Error linking entities: {str(e)}", status=False) + except Exception as e: + logger.error(f"Failed to import team: {str(e)}") + return Response(message=str(e), status=False) - def unlink(self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None): - """Unlink two entities and reorder sequences if needed.""" - with Session(self.engine) as session: - try: - # Get classes from LinkTypes - primary_class = link_type.primary_class - secondary_class = link_type.secondary_class - link_table = link_type.link_table - - # Get field names - primary_id_field = f"{primary_class.__name__.lower()}_id" - secondary_id_field = f"{secondary_class.__name__.lower()}_id" - - # Find existing link - statement = select(link_table).where( - and_( - getattr(link_table, primary_id_field) == primary_id, - getattr(link_table, secondary_id_field) == secondary_id, + async def import_teams_from_directory( + self, + directory: Union[str, Path], + user_id: str, + check_exists: bool = False + ) -> Response: + """ + Import all team configurations from a directory. + + Args: + directory: Path to directory containing team configs + user_id: User ID to associate with imported teams + check_exists: Whether to check for existing teams + + Returns: + Response containing import results for all files + """ + try: + # Load all configs from directory + configs = await TeamManager.load_from_directory(directory) + + results = [] + for config in configs: + try: + result = await self.import_team( + team_config=config, + user_id=user_id, + check_exists=check_exists ) - ) - - if sequence is not None: - statement = statement.where(link_table.sequence == sequence) - - existing_link = session.exec(statement).first() - - if not existing_link: - return Response(message="Link does not exist", status=False) - - deleted_sequence = existing_link.sequence - session.delete(existing_link) - - # Reorder sequences for remaining links - remaining_links = session.exec( - select(link_table) - .where(getattr(link_table, primary_id_field) == primary_id) - .where(link_table.sequence > deleted_sequence) - .order_by(link_table.sequence) - ).all() - - # Decrease sequence numbers to fill the gap - for link in remaining_links: - link.sequence -= 1 - - session.commit() + + # Add result info + results.append({ + "status": result.status, + "message": result.message, + "id": result.data.get("id") if result.status else None + }) + + except Exception as e: + logger.error(f"Failed to import team config: {str(e)}") + results.append({ + "status": False, + "message": str(e), + "id": None + }) - return Response(message="Entities unlinked successfully and sequences reordered", status=True) + return Response( + message="Directory import complete", + status=True, + data=results + ) - except Exception as e: - session.rollback() - return Response(message=f"Error unlinking entities: {str(e)}", status=False) + except Exception as e: + logger.error(f"Failed to import directory: {str(e)}") + return Response(message=str(e), status=False) - def get_linked_entities( + async def _check_team_exists( self, - link_type: LinkTypes, - primary_id: int, - return_json: bool = False, - ): - """Get linked entities based on link type and primary ID, ordered by sequence.""" - with Session(self.engine) as session: - try: - # Get classes from LinkTypes - primary_class = link_type.primary_class - secondary_class = link_type.secondary_class - link_table = link_type.link_table - - # Get field names - primary_id_field = f"{primary_class.__name__.lower()}_id" - secondary_id_field = f"{secondary_class.__name__.lower()}_id" - - # Query both link and entity, ordered by sequence - items = session.exec( - select(secondary_class) - .join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id) - .where(getattr(link_table, primary_id_field) == primary_id) - .order_by(link_table.sequence) - ).all() - - result = [item.model_dump() if return_json else item for item in items] - - return Response(message="Linked entities retrieved successfully", status=True, data=result) - - except Exception as e: - logger.error(f"Error getting linked entities: {str(e)}") - return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[]) - - # Add new close method + config: dict, + user_id: str + ) -> Optional[Team]: + """Check if identical team config already exists""" + teams = self.get(Team, {"user_id": user_id}).data + + + for team in teams: + print(team.config, "******" ,config) + if team.config == config: + return team + + return None async def close(self): """Close database connections and cleanup resources""" diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py index 4641d5d1302b..663a5442627c 100644 --- a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py +++ b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py @@ -34,19 +34,28 @@ async def load_from_file(path: Union[str, Path]) -> dict: return yaml.safe_load(content) raise ValueError(f"Unsupported file format: {path.suffix}") - @staticmethod + @staticmethod async def load_from_directory(directory: Union[str, Path]) -> List[dict]: - """Load all team configurations from a directory""" + """Load all team configurations from a directory + + Args: + directory (Union[str, Path]): Path to directory containing config files + + Returns: + List[dict]: List of loaded team configurations + """ directory = Path(directory) configs = [] + valid_extensions = {'.json', '.yaml', '.yml'} - for path in directory.glob("*.[json,yaml,yml]"): - try: - config = await TeamManager.load_from_file(path) - configs.append(config) - except Exception as e: - logger.error(f"Failed to load {path}: {e}") - + for path in directory.iterdir(): + if path.is_file() and path.suffix.lower() in valid_extensions: + try: + config = await TeamManager.load_from_file(path) + configs.append(config) + except Exception as e: + logger.error(f"Failed to load {path}: {e}") + return configs async def _create_team( From 7e86ad1aed3adcc06dc902642ede90cf8ce8685c Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 25 Jan 2025 18:02:00 -0800 Subject: [PATCH 04/16] ui update --- .../autogen-studio/frontend/src/components/views/atoms.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx b/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx index b3db483b4a29..fe6199083e9c 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx @@ -1,5 +1,6 @@ import React, { memo, useState } from "react"; import { Loader2, Maximize2, Minimize2, X } from "lucide-react"; +import ReactMarkdown from "react-markdown"; export const LoadingIndicator = ({ size = 16 }: { size: number }) => (
@@ -81,6 +82,7 @@ export const TruncatableText = memo( `} > {displayContent} + {displayContent} {shouldTruncate && !isExpanded && (
)} From fc5565e7d68bf22f2163b08331d60a021ba32445 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 26 Jan 2025 11:02:54 -0800 Subject: [PATCH 05/16] initial updates/refactoring --- .../autogen-studio/autogenstudio/__init__.py | 16 +- .../autogenstudio/database/__init__.py | 7 +- .../database/component_factory.py | 503 ------------------ .../autogenstudio/database/config_manager.py | 268 ---------- .../autogenstudio/database/db_manager.py | 7 +- .../autogenstudio/datamodel/__init__.py | 17 +- .../autogenstudio/datamodel/db.py | 190 +------ .../autogenstudio/datamodel/types.py | 181 +------ .../autogen-studio/autogenstudio/web/app.py | 24 +- .../autogen-studio/autogenstudio/web/deps.py | 8 +- .../autogenstudio/web/routes/models.py | 42 -- .../autogenstudio/web/routes/tools.py | 41 -- .../src/components/types/datamodel.ts | 19 +- .../src/components/views/team/types.ts | 100 +++- .../tests/test_component_factory.py | 397 -------------- 15 files changed, 129 insertions(+), 1691 deletions(-) delete mode 100644 python/packages/autogen-studio/autogenstudio/database/component_factory.py delete mode 100644 python/packages/autogen-studio/autogenstudio/database/config_manager.py delete mode 100644 python/packages/autogen-studio/autogenstudio/web/routes/models.py delete mode 100644 python/packages/autogen-studio/autogenstudio/web/routes/tools.py delete mode 100644 python/packages/autogen-studio/tests/test_component_factory.py diff --git a/python/packages/autogen-studio/autogenstudio/__init__.py b/python/packages/autogen-studio/autogenstudio/__init__.py index 137cbad5a834..271a86ad35c0 100644 --- a/python/packages/autogen-studio/autogenstudio/__init__.py +++ b/python/packages/autogen-studio/autogenstudio/__init__.py @@ -1,18 +1,8 @@ from .database.db_manager import DatabaseManager -from .datamodel import Agent, AgentConfig, Model, ModelConfig, Team, TeamConfig, Tool, ToolConfig +from .datamodel import Team from .teammanager import TeamManager from .version import __version__ __all__ = [ - "Tool", - "Model", - "DatabaseManager", - "Team", - "Agent", - "ToolConfig", - "ModelConfig", - "TeamConfig", - "AgentConfig", - "TeamManager", - "__version__", -] + "DatabaseManager", "Team", "TeamManager", "__version__" +] \ No newline at end of file diff --git a/python/packages/autogen-studio/autogenstudio/database/__init__.py b/python/packages/autogen-studio/autogenstudio/database/__init__.py index acdf583557c3..e1f3c9dc61db 100644 --- a/python/packages/autogen-studio/autogenstudio/database/__init__.py +++ b/python/packages/autogen-studio/autogenstudio/database/__init__.py @@ -1,3 +1,6 @@ -from .component_factory import Component, ComponentFactory -from .config_manager import ConfigurationManager + from .db_manager import DatabaseManager + +__all__ = [ + "DatabaseManager", +] diff --git a/python/packages/autogen-studio/autogenstudio/database/component_factory.py b/python/packages/autogen-studio/autogenstudio/database/component_factory.py deleted file mode 100644 index b954b39c0f4b..000000000000 --- a/python/packages/autogen-studio/autogenstudio/database/component_factory.py +++ /dev/null @@ -1,503 +0,0 @@ -import json -import logging -from datetime import datetime -from pathlib import Path -from typing import Callable, Dict, List, Literal, Optional, Union - -import aiofiles -import yaml -from autogen_agentchat.agents import AssistantAgent, UserProxyAgent -from autogen_agentchat.conditions import ( - ExternalTermination, - HandoffTermination, - MaxMessageTermination, - SourceMatchTermination, - StopMessageTermination, - TextMentionTermination, - TimeoutTermination, - TokenUsageTermination, -) -from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat -from autogen_core.tools import FunctionTool -from autogen_ext.agents.file_surfer import FileSurfer -from autogen_ext.agents.magentic_one import MagenticOneCoderAgent -from autogen_ext.agents.web_surfer import MultimodalWebSurfer -from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient - -from ..datamodel.types import ( - AgentConfig, - AgentTypes, - AssistantAgentConfig, - AzureOpenAIModelConfig, - CombinationTerminationConfig, - ComponentConfig, - ComponentConfigInput, - ComponentTypes, - MagenticOneTeamConfig, - MaxMessageTerminationConfig, - ModelConfig, - ModelTypes, - MultimodalWebSurferAgentConfig, - OpenAIModelConfig, - RoundRobinTeamConfig, - SelectorTeamConfig, - TeamConfig, - TeamTypes, - TerminationConfig, - TerminationTypes, - TextMentionTerminationConfig, - ToolConfig, - ToolTypes, - UserProxyAgentConfig, -) -from ..utils.utils import Version - -logger = logging.getLogger(__name__) - -TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat] -AgentComponent = Union[AssistantAgent, MultimodalWebSurfer, UserProxyAgent, FileSurfer, MagenticOneCoderAgent] -ModelComponent = Union[OpenAIChatCompletionClient, AzureOpenAIChatCompletionClient] -ToolComponent = Union[FunctionTool] # Will grow with more tool types -TerminationComponent = Union[ - MaxMessageTermination, - StopMessageTermination, - TextMentionTermination, - TimeoutTermination, - ExternalTermination, - TokenUsageTermination, - HandoffTermination, - SourceMatchTermination, - StopMessageTermination, -] - -Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent, TerminationComponent] - -ReturnType = Literal["object", "dict", "config"] - -DEFAULT_SELECTOR_PROMPT = """You are in a role play game. The following roles are available: -{roles}. -Read the following conversation. Then select the next role from {participants} to play. Only return the role. - -{history} - -Read the above conversation. Then select the next role from {participants} to play. Only return the role. -""" - -CONFIG_RETURN_TYPES = Literal["object", "dict", "config"] - - -class ComponentFactory: - """Creates and manages agent components with versioned configuration loading""" - - SUPPORTED_VERSIONS = { - ComponentTypes.TEAM: ["1.0.0"], - ComponentTypes.AGENT: ["1.0.0"], - ComponentTypes.MODEL: ["1.0.0"], - ComponentTypes.TOOL: ["1.0.0"], - ComponentTypes.TERMINATION: ["1.0.0"], - } - - def __init__(self): - self._model_cache: Dict[str, ModelComponent] = {} - self._tool_cache: Dict[str, FunctionTool] = {} - self._last_cache_clear = datetime.now() - - async def load( - self, component: ComponentConfigInput, input_func: Optional[Callable] = None, return_type: ReturnType = "object" - ) -> Union[Component, dict, ComponentConfig]: - """ - Universal loader for any component type - - Args: - component: Component configuration (file path, dict, or ComponentConfig) - input_func: Optional callable for user input handling - return_type: Type of return value ('object', 'dict', or 'config') - - Returns: - Component instance, config dict, or ComponentConfig based on return_type - """ - try: - # Load and validate config - if isinstance(component, (str, Path)): - component_dict = await self._load_from_file(component) - config = self._dict_to_config(component_dict) - elif isinstance(component, dict): - config = self._dict_to_config(component) - else: - config = component - - # Validate version - if not self._is_version_supported(config.component_type, config.version): - raise ValueError( - f"Unsupported version {config.version} for " - f"component type {config.component_type}. " - f"Supported versions: {self.SUPPORTED_VERSIONS[config.component_type]}" - ) - - # Return early if dict or config requested - if return_type == "dict": - return config.model_dump() - elif return_type == "config": - return config - - # Otherwise create and return component instance - handlers = { - ComponentTypes.TEAM: lambda c: self.load_team(c, input_func), - ComponentTypes.AGENT: lambda c: self.load_agent(c, input_func), - ComponentTypes.MODEL: self.load_model, - ComponentTypes.TOOL: self.load_tool, - ComponentTypes.TERMINATION: self.load_termination, - } - - handler = handlers.get(config.component_type) - if not handler: - raise ValueError(f"Unknown component type: {config.component_type}") - - return await handler(config) - - except Exception as e: - logger.error(f"Failed to load component: {str(e)}") - raise - - async def load_directory( - self, directory: Union[str, Path], return_type: ReturnType = "object" - ) -> List[Union[Component, dict, ComponentConfig]]: - """ - Import all component configurations from a directory. - """ - components = [] - try: - directory = Path(directory) - # Using Path.iterdir() instead of os.listdir - for path in list(directory.glob("*")): - if path.suffix.lower().endswith((".json", ".yaml", ".yml")): - try: - component = await self.load(path, return_type=return_type) - components.append(component) - except Exception as e: - logger.info(f"Failed to load component: {str(e)}, {path}") - - return components - except Exception as e: - logger.info(f"Failed to load directory: {str(e)}") - return components - - def _dict_to_config(self, config_dict: dict) -> ComponentConfig: - """Convert dictionary to appropriate config type based on component_type and type discriminator""" - if "component_type" not in config_dict: - raise ValueError("component_type is required in configuration") - - component_type = ComponentTypes(config_dict["component_type"]) - - # Define mapping structure - type_mappings = { - ComponentTypes.MODEL: { - "discriminator": "model_type", - ModelTypes.OPENAI.value: OpenAIModelConfig, - ModelTypes.AZUREOPENAI.value: AzureOpenAIModelConfig, - }, - ComponentTypes.AGENT: { - "discriminator": "agent_type", - AgentTypes.ASSISTANT.value: AssistantAgentConfig, - AgentTypes.USERPROXY.value: UserProxyAgentConfig, - AgentTypes.MULTIMODAL_WEBSURFER.value: MultimodalWebSurferAgentConfig, - }, - ComponentTypes.TEAM: { - "discriminator": "team_type", - TeamTypes.ROUND_ROBIN.value: RoundRobinTeamConfig, - TeamTypes.SELECTOR.value: SelectorTeamConfig, - TeamTypes.MAGENTIC_ONE.value: MagenticOneTeamConfig, - }, - ComponentTypes.TOOL: ToolConfig, - ComponentTypes.TERMINATION: { - "discriminator": "termination_type", - TerminationTypes.MAX_MESSAGES.value: MaxMessageTerminationConfig, - TerminationTypes.TEXT_MENTION.value: TextMentionTerminationConfig, - TerminationTypes.COMBINATION.value: CombinationTerminationConfig, - }, - } - - mapping = type_mappings.get(component_type) - if not mapping: - raise ValueError(f"Unknown component type: {component_type}") - - # Handle simple cases (no discriminator) - if isinstance(mapping, type): - return mapping(**config_dict) - - # Get discriminator field value - discriminator = mapping["discriminator"] - if discriminator not in config_dict: - raise ValueError(f"Missing {discriminator} in configuration") - - type_value = config_dict[discriminator] - config_class = mapping.get(type_value) - - if not config_class: - raise ValueError(f"Unknown {discriminator}: {type_value}") - - return config_class(**config_dict) - - async def load_termination(self, config: TerminationConfig) -> TerminationComponent: - """Create termination condition instance from configuration.""" - try: - if config.termination_type == TerminationTypes.COMBINATION: - if not config.conditions or len(config.conditions) < 2: - raise ValueError("Combination termination requires at least 2 conditions") - if not config.operator: - raise ValueError("Combination termination requires an operator (and/or)") - - # Load first two conditions - conditions = [await self.load_termination(cond) for cond in config.conditions[:2]] - result = conditions[0] & conditions[1] if config.operator == "and" else conditions[0] | conditions[1] - - # Process remaining conditions if any - for condition in config.conditions[2:]: - next_condition = await self.load_termination(condition) - result = result & next_condition if config.operator == "and" else result | next_condition - - return result - - elif config.termination_type == TerminationTypes.MAX_MESSAGES: - if config.max_messages is None: - raise ValueError("max_messages parameter required for MaxMessageTermination") - return MaxMessageTermination(max_messages=config.max_messages) - - elif config.termination_type == TerminationTypes.STOP_MESSAGE: - return StopMessageTermination() - - elif config.termination_type == TerminationTypes.TEXT_MENTION: - if not config.text: - raise ValueError("text parameter required for TextMentionTermination") - return TextMentionTermination(text=config.text) - - else: - raise ValueError(f"Unsupported termination type: {config.termination_type}") - - except Exception as e: - logger.error(f"Failed to create termination condition: {str(e)}") - raise ValueError(f"Termination condition creation failed: {str(e)}") from e - - async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = None) -> TeamComponent: - """Create team instance from configuration.""" - try: - # Load participants (agents) with input_func - participants = [] - for participant in config.participants: - agent = await self.load(participant, input_func=input_func) - participants.append(agent) - - # Load termination condition if specified - termination = None - if config.termination_condition: - termination = await self.load(config.termination_condition) - - # Create team based on type - if config.team_type == TeamTypes.ROUND_ROBIN: - return RoundRobinGroupChat(participants=participants, termination_condition=termination) - elif config.team_type == TeamTypes.SELECTOR: - model_client = await self.load(config.model_client) - if not model_client: - raise ValueError("SelectorGroupChat requires a model_client") - selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT - return SelectorGroupChat( - participants=participants, - model_client=model_client, - termination_condition=termination, - selector_prompt=selector_prompt, - ) - elif config.team_type == TeamTypes.MAGENTIC_ONE: - model_client = await self.load(config.model_client) - if not model_client: - raise ValueError("MagenticOneGroupChat requires a model_client") - return MagenticOneGroupChat( - participants=participants, - model_client=model_client, - termination_condition=termination if termination is not None else None, - max_turns=config.max_turns if config.max_turns is not None else 20, - ) - else: - raise ValueError(f"Unsupported team type: {config.team_type}") - - except Exception as e: - logger.error(f"Failed to create team {config.name}: {str(e)}") - raise ValueError(f"Team creation failed: {str(e)}") from e - - async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] = None) -> AgentComponent: - """Create agent instance from configuration.""" - - model_client = None - system_message = None - tools = [] - if hasattr(config, "system_message") and config.system_message: - system_message = config.system_message - if hasattr(config, "model_client") and config.model_client: - model_client = await self.load(config.model_client) - if hasattr(config, "tools") and config.tools: - for tool_config in config.tools: - tool = await self.load(tool_config) - tools.append(tool) - - try: - if config.agent_type == AgentTypes.USERPROXY: - return UserProxyAgent( - name=config.name, - description=config.description or "A human user", - input_func=input_func, # Pass through to UserProxyAgent - ) - elif config.agent_type == AgentTypes.ASSISTANT: - system_message = config.system_message if config.system_message else "You are a helpful assistant" - - return AssistantAgent( - name=config.name, - description=config.description or "A helpful assistant", - model_client=model_client, - tools=tools, - system_message=system_message, - ) - elif config.agent_type == AgentTypes.MULTIMODAL_WEBSURFER: - return MultimodalWebSurfer( - name=config.name, - model_client=model_client, - headless=config.headless if config.headless is not None else True, - debug_dir=config.logs_dir if config.logs_dir is not None else None, - downloads_folder=config.logs_dir if config.logs_dir is not None else None, - to_save_screenshots=config.to_save_screenshots if config.to_save_screenshots is not None else False, - use_ocr=config.use_ocr if config.use_ocr is not None else False, - animate_actions=config.animate_actions if config.animate_actions is not None else False, - ) - elif config.agent_type == AgentTypes.FILE_SURFER: - return FileSurfer( - name=config.name, - model_client=model_client, - ) - elif config.agent_type == AgentTypes.MAGENTIC_ONE_CODER: - return MagenticOneCoderAgent( - name=config.name, - model_client=model_client, - ) - else: - raise ValueError(f"Unsupported agent type: {config.agent_type}") - - except Exception as e: - logger.error(f"Failed to create agent {config.name}: {str(e)}") - raise ValueError(f"Agent creation failed: {str(e)}") from e - - async def load_model(self, config: ModelConfig) -> ModelComponent: - """Create model instance from configuration.""" - try: - # Check cache first - cache_key = str(config.model_dump()) - if cache_key in self._model_cache: - logger.debug(f"Using cached model for {config.model}") - return self._model_cache[cache_key] - - if config.model_type == ModelTypes.OPENAI: - args = { - "model": config.model, - "api_key": config.api_key, - "base_url": config.base_url, - } - - if hasattr(config, "model_capabilities") and config.model_capabilities is not None: - args["model_capabilities"] = config.model_capabilities - - model = OpenAIChatCompletionClient(**args) - self._model_cache[cache_key] = model - return model - elif config.model_type == ModelTypes.AZUREOPENAI: - model = AzureOpenAIChatCompletionClient( - azure_deployment=config.azure_deployment, - model=config.model, - api_version=config.api_version, - azure_endpoint=config.azure_endpoint, - api_key=config.api_key, - ) - self._model_cache[cache_key] = model - return model - else: - raise ValueError(f"Unsupported model type: {config.model_type}") - - except Exception as e: - logger.error(f"Failed to create model {config.model}: {str(e)}") - raise ValueError(f"Model creation failed: {str(e)}") from e - - async def load_tool(self, config: ToolConfig) -> ToolComponent: - """Create tool instance from configuration.""" - try: - # Validate required fields - if not all([config.name, config.description, config.content, config.tool_type]): - raise ValueError("Tool configuration missing required fields") - - # Check cache first - cache_key = str(config.model_dump()) - if cache_key in self._tool_cache: - logger.debug(f"Using cached tool '{config.name}'") - return self._tool_cache[cache_key] - - if config.tool_type == ToolTypes.PYTHON_FUNCTION: - tool = FunctionTool( - name=config.name, description=config.description, func=self._func_from_string(config.content) - ) - self._tool_cache[cache_key] = tool - return tool - else: - raise ValueError(f"Unsupported tool type: {config.tool_type}") - - except Exception as e: - logger.error(f"Failed to create tool '{config.name}': {str(e)}") - raise - - async def _load_from_file(self, path: Union[str, Path]) -> dict: - """Load configuration from JSON or YAML file.""" - path = Path(path) - if not path.exists(): - raise FileNotFoundError(f"Config file not found: {path}") - - try: - async with aiofiles.open(path) as f: - content = await f.read() - if path.suffix == ".json": - return json.loads(content) - elif path.suffix in (".yml", ".yaml"): - return yaml.safe_load(content) - else: - raise ValueError(f"Unsupported file format: {path.suffix}") - except Exception as e: - raise ValueError(f"Failed to load file {path}: {str(e)}") from e - - def _func_from_string(self, content: str) -> callable: - """Convert function string to callable.""" - try: - namespace = {} - exec(content, namespace) - for item in namespace.values(): - if callable(item) and not isinstance(item, type): - return item - raise ValueError("No function found in provided code") - except Exception as e: - raise ValueError(f"Failed to create function: {str(e)}") from e - - def _is_version_supported(self, component_type: ComponentTypes, ver: str) -> bool: - """Check if version is supported for component type.""" - try: - version = Version(ver) - supported = [Version(v) for v in self.SUPPORTED_VERSIONS[component_type]] - return any(version == v for v in supported) - except ValueError: - return False - - async def cleanup(self) -> None: - """Cleanup resources and clear caches.""" - for model in self._model_cache.values(): - if hasattr(model, "cleanup"): - await model.cleanup() - - for tool in self._tool_cache.values(): - if hasattr(tool, "cleanup"): - await tool.cleanup() - - self._model_cache.clear() - self._tool_cache.clear() - self._last_cache_clear = datetime.now() - logger.info("Cleared all component caches") diff --git a/python/packages/autogen-studio/autogenstudio/database/config_manager.py b/python/packages/autogen-studio/autogenstudio/database/config_manager.py deleted file mode 100644 index 3cd1b43d81ab..000000000000 --- a/python/packages/autogen-studio/autogenstudio/database/config_manager.py +++ /dev/null @@ -1,268 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from loguru import logger - -from ..datamodel.db import Agent, LinkTypes, Model, Team, Tool -from ..datamodel.types import ComponentConfigInput, ComponentTypes, Response -from .component_factory import ComponentFactory -from .db_manager import DatabaseManager - - -class ConfigurationManager: - """Manages persistence and relationships of components using ComponentFactory for validation""" - - DEFAULT_UNIQUENESS_FIELDS = { - ComponentTypes.MODEL: ["model_type", "model"], - ComponentTypes.TOOL: ["name"], - ComponentTypes.AGENT: ["agent_type", "name"], - ComponentTypes.TEAM: ["team_type", "name"], - } - - def __init__(self, db_manager: DatabaseManager, uniqueness_fields: Dict[ComponentTypes, List[str]] = None): - self.db_manager = db_manager - self.component_factory = ComponentFactory() - self.uniqueness_fields = uniqueness_fields or self.DEFAULT_UNIQUENESS_FIELDS - - async def import_component( - self, component_config: ComponentConfigInput, user_id: str, check_exists: bool = False - ) -> Response: - """ - Import a component configuration, validate it, and store the resulting component. - - Args: - component_config: Configuration for the component (file path, dict, or ComponentConfig) - user_id: User ID to associate with imported component - check_exists: Whether to check for existing components before storing (default: False) - - Returns: - Response containing import results or error - """ - try: - # Get validated config as dict - config = await self.component_factory.load(component_config, return_type="dict") - - # Get component type - component_type = self._determine_component_type(config) - if not component_type: - raise ValueError("Unable to determine component type from config") - - # Check existence if requested - if check_exists: - existing = self._check_exists(component_type, config, user_id) - if existing: - return Response( - message=self._format_exists_message(component_type, config), - status=True, - data={"id": existing.id}, - ) - - # Route to appropriate storage method - if component_type == ComponentTypes.TEAM: - return await self._store_team(config, user_id, check_exists) - elif component_type == ComponentTypes.AGENT: - return await self._store_agent(config, user_id, check_exists) - elif component_type == ComponentTypes.MODEL: - return await self._store_model(config, user_id) - elif component_type == ComponentTypes.TOOL: - return await self._store_tool(config, user_id) - else: - raise ValueError(f"Unsupported component type: {component_type}") - - except Exception as e: - logger.error(f"Failed to import component: {str(e)}") - return Response(message=str(e), status=False) - - async def import_directory(self, directory: Union[str, Path], user_id: str, check_exists: bool = False) -> Response: - """ - Import all component configurations from a directory. - - Args: - directory: Path to directory containing configuration files - user_id: User ID to associate with imported components - check_exists: Whether to check for existing components before storing (default: False) - - Returns: - Response containing import results for all files - """ - try: - configs = await self.component_factory.load_directory(directory, return_type="dict") - - results = [] - for config in configs: - result = await self.import_component(config, user_id, check_exists) - results.append( - { - "component": self._get_component_type(config), - "status": result.status, - "message": result.message, - "id": result.data.get("id") if result.status else None, - } - ) - - return Response(message="Directory import complete", status=True, data=results) - - except Exception as e: - logger.error(f"Failed to import directory: {str(e)}") - return Response(message=str(e), status=False) - - async def _store_team(self, config: dict, user_id: str, check_exists: bool = False) -> Response: - """Store team component and manage its relationships with agents""" - try: - # Store the team - team_db = Team(user_id=user_id, config=config) - team_result = self.db_manager.upsert(team_db) - if not team_result.status: - return team_result - - team_id = team_result.data["id"] - - # Handle participants (agents) - for participant in config.get("participants", []): - if check_exists: - # Check for existing agent - agent_type = self._determine_component_type(participant) - existing_agent = self._check_exists(agent_type, participant, user_id) - if existing_agent: - # Link existing agent - self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, existing_agent.id) - logger.info(f"Linked existing agent to team: {existing_agent}") - continue - - # Store and link new agent - agent_result = await self._store_agent(participant, user_id, check_exists) - if agent_result.status: - self.db_manager.link(LinkTypes.TEAM_AGENT, team_id, agent_result.data["id"]) - - return team_result - - except Exception as e: - logger.error(f"Failed to store team: {str(e)}") - return Response(message=str(e), status=False) - - async def _store_agent(self, config: dict, user_id: str, check_exists: bool = False) -> Response: - """Store agent component and manage its relationships with tools and model""" - try: - # Store the agent - agent_db = Agent(user_id=user_id, config=config) - agent_result = self.db_manager.upsert(agent_db) - if not agent_result.status: - return agent_result - - agent_id = agent_result.data["id"] - - # Handle model client - if "model_client" in config: - if check_exists: - # Check for existing model - model_type = self._determine_component_type(config["model_client"]) - existing_model = self._check_exists(model_type, config["model_client"], user_id) - if existing_model: - # Link existing model - self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, existing_model.id) - logger.info(f"Linked existing model to agent: {existing_model.config.model_type}") - else: - # Store and link new model - model_result = await self._store_model(config["model_client"], user_id) - if model_result.status: - self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"]) - else: - # Store and link new model without checking - model_result = await self._store_model(config["model_client"], user_id) - if model_result.status: - self.db_manager.link(LinkTypes.AGENT_MODEL, agent_id, model_result.data["id"]) - - # Handle tools - for tool_config in config.get("tools", []): - if check_exists: - # Check for existing tool - tool_type = self._determine_component_type(tool_config) - existing_tool = self._check_exists(tool_type, tool_config, user_id) - if existing_tool: - # Link existing tool - self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, existing_tool.id) - logger.info(f"Linked existing tool to agent: {existing_tool.config.name}") - continue - - # Store and link new tool - tool_result = await self._store_tool(tool_config, user_id) - if tool_result.status: - self.db_manager.link(LinkTypes.AGENT_TOOL, agent_id, tool_result.data["id"]) - - return agent_result - - except Exception as e: - logger.error(f"Failed to store agent: {str(e)}") - return Response(message=str(e), status=False) - - async def _store_model(self, config: dict, user_id: str) -> Response: - """Store model component (leaf node - no relationships)""" - try: - model_db = Model(user_id=user_id, config=config) - return self.db_manager.upsert(model_db) - - except Exception as e: - logger.error(f"Failed to store model: {str(e)}") - return Response(message=str(e), status=False) - - async def _store_tool(self, config: dict, user_id: str) -> Response: - """Store tool component (leaf node - no relationships)""" - try: - tool_db = Tool(user_id=user_id, config=config) - return self.db_manager.upsert(tool_db) - - except Exception as e: - logger.error(f"Failed to store tool: {str(e)}") - return Response(message=str(e), status=False) - - def _check_exists( - self, component_type: ComponentTypes, config: dict, user_id: str - ) -> Optional[Union[Model, Tool, Agent, Team]]: - """Check if component exists based on configured uniqueness fields.""" - fields = self.uniqueness_fields.get(component_type, []) - if not fields: - return None - - component_class = { - ComponentTypes.MODEL: Model, - ComponentTypes.TOOL: Tool, - ComponentTypes.AGENT: Agent, - ComponentTypes.TEAM: Team, - }.get(component_type) - - components = self.db_manager.get(component_class, {"user_id": user_id}).data - - for component in components: - matches = all(component.config.get(field) == config.get(field) for field in fields) - if matches: - return component - - return None - - def _format_exists_message(self, component_type: ComponentTypes, config: dict) -> str: - """Format existence message with identifying fields.""" - fields = self.uniqueness_fields.get(component_type, []) - field_values = [f"{field}='{config.get(field)}'" for field in fields] - return f"{component_type.value} with {' and '.join(field_values)} already exists" - - def _determine_component_type(self, config: dict) -> Optional[ComponentTypes]: - """Determine component type from configuration dictionary""" - if "team_type" in config: - return ComponentTypes.TEAM - elif "agent_type" in config: - return ComponentTypes.AGENT - elif "model_type" in config: - return ComponentTypes.MODEL - elif "tool_type" in config: - return ComponentTypes.TOOL - return None - - def _get_component_type(self, config: dict) -> str: - """Helper to get component type string from config""" - component_type = self._determine_component_type(config) - return component_type.value if component_type else "unknown" - - async def cleanup(self): - """Cleanup resources""" - await self.component_factory.cleanup() diff --git a/python/packages/autogen-studio/autogenstudio/database/db_manager.py b/python/packages/autogen-studio/autogenstudio/database/db_manager.py index d27764b8cd93..bf5358436679 100644 --- a/python/packages/autogen-studio/autogenstudio/database/db_manager.py +++ b/python/packages/autogen-studio/autogenstudio/database/db_manager.py @@ -4,14 +4,13 @@ from typing import Optional, Union from loguru import logger -from sqlalchemy import exc, func, inspect, text +from sqlalchemy import exc, inspect, text from sqlmodel import Session, SQLModel, and_, create_engine, select -from ..datamodel import LinkTypes, Response, Team +from ..datamodel import Response, Team from ..teammanager import TeamManager from .schema_manager import SchemaManager -# from .dbutils import init_db_samples class DatabaseManager: @@ -301,7 +300,7 @@ async def import_teams_from_directory( Returns: Response containing import results for all files - """ + """ try: # Load all configs from directory configs = await TeamManager.load_from_directory(directory) diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py b/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py index 0d46fb26334e..87b5775c4ef6 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py @@ -1,11 +1,10 @@ -from .db import Agent, LinkTypes, Message, Model, Run, RunStatus, Session, Team, Tool +from .db import Team, Run, RunStatus, Session, Team, Message from .types import ( - AgentConfig, - ComponentConfigInput, - MessageConfig, - ModelConfig, - Response, - TeamConfig, - TeamResult, - ToolConfig, + MessageConfig, MessageMeta, TeamResult, Response, SocketMessage ) + + +__all__ = [ + "Team", "Run", "RunStatus", "Session", "Team", + "MessageConfig", "MessageMeta", "TeamResult", "Response", "SocketMessage" +] \ No newline at end of file diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/db.py b/python/packages/autogen-studio/autogenstudio/datamodel/db.py index 45f439d33910..36905623b399 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/db.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/db.py @@ -2,142 +2,15 @@ from datetime import datetime from enum import Enum -from typing import List, Optional, Tuple, Type, Union +from typing import List, Optional, Union from uuid import UUID, uuid4 -from loguru import logger -from pydantic import BaseModel -from sqlalchemy import ForeignKey, Integer, UniqueConstraint -from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func - -from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig - -# added for python3.11 and sqlmodel 0.0.22 incompatibility -if hasattr(SQLModel, "model_config"): - SQLModel.model_config["protected_namespaces"] = () -elif hasattr(SQLModel, "Config"): - - class CustomSQLModel(SQLModel): - class Config: - protected_namespaces = () - - SQLModel = CustomSQLModel -else: - logger.warning("Unable to set protected_namespaces.") - -# pylint: disable=protected-access - - -class ComponentTypes(Enum): - TEAM = "team" - AGENT = "agent" - MODEL = "model" - TOOL = "tool" - - @property - def model_class(self) -> Type[SQLModel]: - return { - ComponentTypes.TEAM: Team, - ComponentTypes.AGENT: Agent, - ComponentTypes.MODEL: Model, - ComponentTypes.TOOL: Tool, - }[self] - - -class LinkTypes(Enum): - AGENT_MODEL = "agent_model" - AGENT_TOOL = "agent_tool" - TEAM_AGENT = "team_agent" - - @property - # type: ignore - def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]: - return { - LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink), - LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink), - LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink), - }[self] - - @property - def primary_class(self) -> Type[SQLModel]: # type: ignore - return self.link_config[0] - - @property - def secondary_class(self) -> Type[SQLModel]: # type: ignore - return self.link_config[1] - - @property - def link_table(self) -> Type[SQLModel]: # type: ignore - return self.link_config[2] - - -# link models -class AgentToolLink(SQLModel, table=True): - __table_args__ = ( - UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"), - {"sqlite_autoincrement": True}, - ) - agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id") - tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id") - sequence: Optional[int] = Field(default=0, primary_key=True) - - -class AgentModelLink(SQLModel, table=True): - __table_args__ = ( - UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"), - {"sqlite_autoincrement": True}, - ) - agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id") - model_id: int = Field(default=None, primary_key=True, foreign_key="model.id") - sequence: Optional[int] = Field(default=0, primary_key=True) - - -class TeamAgentLink(SQLModel, table=True): - __table_args__ = ( - UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"), - {"sqlite_autoincrement": True}, - ) - team_id: int = Field(default=None, primary_key=True, foreign_key="team.id") - agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id") - sequence: Optional[int] = Field(default=0, primary_key=True) - - -# database models - - -class Tool(SQLModel, table=True): - __table_args__ = {"sqlite_autoincrement": True} - id: Optional[int] = Field(default=None, primary_key=True) - created_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), server_default=func.now()), - ) # pylint: disable=not-callable - updated_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), onupdate=func.now()), - ) # pylint: disable=not-callable - user_id: Optional[str] = None - version: Optional[str] = "0.0.1" - config: Union[ToolConfig, dict] = Field(sa_column=Column(JSON)) - agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink) - - -class Model(SQLModel, table=True): - __table_args__ = {"sqlite_autoincrement": True} - id: Optional[int] = Field(default=None, primary_key=True) - created_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), server_default=func.now()), - ) # pylint: disable=not-callable - updated_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), onupdate=func.now()), - ) # pylint: disable=not-callable - user_id: Optional[str] = None - version: Optional[str] = "0.0.1" - config: Union[ModelConfig, dict] = Field(sa_column=Column(JSON)) - agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink) +from sqlalchemy import ForeignKey, Integer +from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func +from .types import MessageConfig, MessageMeta, TeamResult +from autogen_core import ComponentModel + class Team(SQLModel, table=True): __table_args__ = {"sqlite_autoincrement": True} @@ -152,27 +25,8 @@ class Team(SQLModel, table=True): ) # pylint: disable=not-callable user_id: Optional[str] = None version: Optional[str] = "0.0.1" - config: Union[TeamConfig, dict] = Field(sa_column=Column(JSON)) - agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink) - - -class Agent(SQLModel, table=True): - __table_args__ = {"sqlite_autoincrement": True} - id: Optional[int] = Field(default=None, primary_key=True) - created_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), server_default=func.now()), - ) # pylint: disable=not-callable - updated_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), onupdate=func.now()), - ) # pylint: disable=not-callable - user_id: Optional[str] = None - version: Optional[str] = "0.0.1" - config: Union[AgentConfig, dict] = Field(sa_column=Column(JSON)) - tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink) - models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink) - teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink) + config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON)) + class Message(SQLModel, table=True): @@ -251,31 +105,3 @@ class Run(SQLModel, table=True): class Config: json_encoders = {UUID: str, datetime: lambda v: v.isoformat()} - -class GalleryConfig(SQLModel, table=False): - id: UUID = Field(default_factory=uuid4, primary_key=True, index=True) - title: Optional[str] = None - description: Optional[str] = None - run: Run - team: TeamConfig = None - tags: Optional[List[str]] = None - visibility: str = "public" # public, private, shared - - class Config: - json_encoders = {UUID: str, datetime: lambda v: v.isoformat()} - - -class Gallery(SQLModel, table=True): - __table_args__ = {"sqlite_autoincrement": True} - id: Optional[int] = Field(default=None, primary_key=True) - created_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), server_default=func.now()), - ) - updated_at: datetime = Field( - default_factory=datetime.now, - sa_column=Column(DateTime(timezone=True), onupdate=func.now()), - ) - user_id: Optional[str] = None - version: Optional[str] = "0.0.1" - config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON)) diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/types.py b/python/packages/autogen-studio/autogenstudio/datamodel/types.py index eb02fb121ebe..d0b65f8a0bae 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/types.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/types.py @@ -1,56 +1,10 @@ from datetime import datetime -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from autogen_agentchat.base import TaskResult -from autogen_core.models import ModelCapabilities -from pydantic import BaseModel, Field +from pydantic import BaseModel -class ModelTypes(str, Enum): - OPENAI = "OpenAIChatCompletionClient" - AZUREOPENAI = "AzureOpenAIChatCompletionClient" - - -class ToolTypes(str, Enum): - PYTHON_FUNCTION = "PythonFunction" - - -class AgentTypes(str, Enum): - ASSISTANT = "AssistantAgent" - USERPROXY = "UserProxyAgent" - MULTIMODAL_WEBSURFER = "MultimodalWebSurfer" - FILE_SURFER = "FileSurfer" - MAGENTIC_ONE_CODER = "MagenticOneCoderAgent" - - -class TeamTypes(str, Enum): - ROUND_ROBIN = "RoundRobinGroupChat" - SELECTOR = "SelectorGroupChat" - MAGENTIC_ONE = "MagenticOneGroupChat" - - -class TerminationTypes(str, Enum): - MAX_MESSAGES = "MaxMessageTermination" - STOP_MESSAGE = "StopMessageTermination" - TEXT_MENTION = "TextMentionTermination" - COMBINATION = "CombinationTermination" - - -class ComponentTypes(str, Enum): - TEAM = "team" - AGENT = "agent" - MODEL = "model" - TOOL = "tool" - TERMINATION = "termination" - - -class BaseConfig(BaseModel): - model_config = {"protected_namespaces": ()} - version: str = "1.0.0" - component_type: ComponentTypes - class MessageConfig(BaseModel): source: str @@ -58,133 +12,6 @@ class MessageConfig(BaseModel): message_type: Optional[str] = "text" -class BaseModelConfig(BaseConfig): - model: str - model_type: ModelTypes - api_key: Optional[str] = None - base_url: Optional[str] = None - component_type: ComponentTypes = ComponentTypes.MODEL - model_capabilities: Optional[ModelCapabilities] = None - - -class OpenAIModelConfig(BaseModelConfig): - model_type: ModelTypes = ModelTypes.OPENAI - - -class AzureOpenAIModelConfig(BaseModelConfig): - azure_deployment: str - model: str - api_version: str - azure_endpoint: str - azure_ad_token_provider: Optional[str] = None - api_key: Optional[str] = None - model_type: ModelTypes = ModelTypes.AZUREOPENAI - - -ModelConfig = OpenAIModelConfig | AzureOpenAIModelConfig - - -class ToolConfig(BaseConfig): - name: str - description: str - content: str - tool_type: ToolTypes - component_type: ComponentTypes = ComponentTypes.TOOL - - -class BaseAgentConfig(BaseConfig): - name: str - agent_type: AgentTypes - description: Optional[str] = None - component_type: ComponentTypes = ComponentTypes.AGENT - - -class AssistantAgentConfig(BaseAgentConfig): - agent_type: AgentTypes = AgentTypes.ASSISTANT - model_client: ModelConfig - tools: Optional[List[ToolConfig]] = None - system_message: Optional[str] = None - - -class UserProxyAgentConfig(BaseAgentConfig): - agent_type: AgentTypes = AgentTypes.USERPROXY - - -class MultimodalWebSurferAgentConfig(BaseAgentConfig): - agent_type: AgentTypes = AgentTypes.MULTIMODAL_WEBSURFER - model_client: ModelConfig - headless: bool = True - logs_dir: str = None - to_save_screenshots: bool = False - use_ocr: bool = False - animate_actions: bool = False - tools: Optional[List[ToolConfig]] = None - - -AgentConfig = AssistantAgentConfig | UserProxyAgentConfig | MultimodalWebSurferAgentConfig - - -class BaseTerminationConfig(BaseConfig): - termination_type: TerminationTypes - component_type: ComponentTypes = ComponentTypes.TERMINATION - - -class MaxMessageTerminationConfig(BaseTerminationConfig): - termination_type: TerminationTypes = TerminationTypes.MAX_MESSAGES - max_messages: int - - -class TextMentionTerminationConfig(BaseTerminationConfig): - termination_type: TerminationTypes = TerminationTypes.TEXT_MENTION - text: str - - -class StopMessageTerminationConfig(BaseTerminationConfig): - termination_type: TerminationTypes = TerminationTypes.STOP_MESSAGE - - -class CombinationTerminationConfig(BaseTerminationConfig): - termination_type: TerminationTypes = TerminationTypes.COMBINATION - operator: str - conditions: List["TerminationConfig"] - - -TerminationConfig = ( - MaxMessageTerminationConfig - | TextMentionTerminationConfig - | CombinationTerminationConfig - | StopMessageTerminationConfig -) - - -class BaseTeamConfig(BaseConfig): - name: str - participants: List[AgentConfig] - team_type: TeamTypes - termination_condition: Optional[TerminationConfig] = None - component_type: ComponentTypes = ComponentTypes.TEAM - max_turns: Optional[int] = None - - -class RoundRobinTeamConfig(BaseTeamConfig): - team_type: TeamTypes = TeamTypes.ROUND_ROBIN - - -class SelectorTeamConfig(BaseTeamConfig): - team_type: TeamTypes = TeamTypes.SELECTOR - selector_prompt: Optional[str] = None - model_client: ModelConfig - - -class MagenticOneTeamConfig(BaseTeamConfig): - team_type: TeamTypes = TeamTypes.MAGENTIC_ONE - model_client: ModelConfig - max_stalls: int = 3 - final_answer_prompt: Optional[str] = None - - -TeamConfig = RoundRobinTeamConfig | SelectorTeamConfig | MagenticOneTeamConfig - class TeamResult(BaseModel): task_result: TaskResult @@ -216,7 +43,3 @@ class SocketMessage(BaseModel): data: Dict[str, Any] type: str - -ComponentConfig = Union[TeamConfig, AgentConfig, ModelConfig, ToolConfig, TerminationConfig] - -ComponentConfigInput = Union[str, Path, dict, ComponentConfig] diff --git a/python/packages/autogen-studio/autogenstudio/web/app.py b/python/packages/autogen-studio/autogenstudio/web/app.py index 2e2ad3337248..9fba6601ee07 100644 --- a/python/packages/autogen-studio/autogenstudio/web/app.py +++ b/python/packages/autogen-studio/autogenstudio/web/app.py @@ -13,7 +13,7 @@ from .config import settings from .deps import cleanup_managers, init_managers from .initialization import AppInitializer -from .routes import agents, models, runs, sessions, teams, tools, ws +from .routes import runs, sessions, teams, ws # Configure logging # logger = logging.getLogger(__name__) @@ -103,27 +103,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: tags=["teams"], responses={404: {"description": "Not found"}}, ) - -api.include_router( - agents.router, - prefix="/agents", - tags=["agents"], - responses={404: {"description": "Not found"}}, -) - -api.include_router( - models.router, - prefix="/models", - tags=["models"], - responses={404: {"description": "Not found"}}, -) - -api.include_router( - tools.router, - prefix="/tools", - tags=["tools"], - responses={404: {"description": "Not found"}}, -) + api.include_router( ws.router, diff --git a/python/packages/autogen-studio/autogenstudio/web/deps.py b/python/packages/autogen-studio/autogenstudio/web/deps.py index d2e5b6fabb4a..04dfb41366a1 100644 --- a/python/packages/autogen-studio/autogenstudio/web/deps.py +++ b/python/packages/autogen-studio/autogenstudio/web/deps.py @@ -5,7 +5,7 @@ from fastapi import Depends, HTTPException, status -from ..database import ConfigurationManager, DatabaseManager +from ..database import DatabaseManager from ..teammanager import TeamManager from .config import settings from .managers.connection import WebSocketManager @@ -93,10 +93,8 @@ async def init_managers(database_uri: str, config_dir: str, app_root: str) -> No _db_manager = DatabaseManager(engine_uri=database_uri, base_dir=app_root) _db_manager.initialize_database(auto_upgrade=settings.UPGRADE_DATABASE) - # init default team config - - _team_config_manager = ConfigurationManager(db_manager=_db_manager) - await _team_config_manager.import_directory(config_dir, settings.DEFAULT_USER_ID, check_exists=True) + # init default team config + await _db_manager.import_teams_from_directory(config_dir, settings.DEFAULT_USER_ID, check_exists=True) # Initialize connection manager _websocket_manager = WebSocketManager(db_manager=_db_manager) diff --git a/python/packages/autogen-studio/autogenstudio/web/routes/models.py b/python/packages/autogen-studio/autogenstudio/web/routes/models.py deleted file mode 100644 index f041e52cb93b..000000000000 --- a/python/packages/autogen-studio/autogenstudio/web/routes/models.py +++ /dev/null @@ -1,42 +0,0 @@ -# api/routes/models.py -from typing import Dict - -from fastapi import APIRouter, Depends, HTTPException -from openai import OpenAIError - -from ...datamodel import Model -from ..deps import get_db - -router = APIRouter() - - -@router.get("/") -async def list_models(user_id: str, db=Depends(get_db)) -> Dict: - """List all models for a user""" - response = db.get(Model, filters={"user_id": user_id}) - return {"status": True, "data": response.data} - - -@router.get("/{model_id}") -async def get_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict: - """Get a specific model""" - response = db.get(Model, filters={"id": model_id, "user_id": user_id}) - if not response.status or not response.data: - raise HTTPException(status_code=404, detail="Model not found") - return {"status": True, "data": response.data[0]} - - -@router.post("/") -async def create_model(model: Model, db=Depends(get_db)) -> Dict: - """Create a new model""" - response = db.upsert(model) - if not response.status: - raise HTTPException(status_code=400, detail=response.message) - return {"status": True, "data": response.data} - - -@router.delete("/{model_id}") -async def delete_model(model_id: int, user_id: str, db=Depends(get_db)) -> Dict: - """Delete a model""" - db.delete(filters={"id": model_id, "user_id": user_id}, model_class=Model) - return {"status": True, "message": "Model deleted successfully"} diff --git a/python/packages/autogen-studio/autogenstudio/web/routes/tools.py b/python/packages/autogen-studio/autogenstudio/web/routes/tools.py deleted file mode 100644 index da2ae7733b2b..000000000000 --- a/python/packages/autogen-studio/autogenstudio/web/routes/tools.py +++ /dev/null @@ -1,41 +0,0 @@ -# api/routes/tools.py -from typing import Dict - -from fastapi import APIRouter, Depends, HTTPException - -from ...datamodel import Tool -from ..deps import get_db - -router = APIRouter() - - -@router.get("/") -async def list_tools(user_id: str, db=Depends(get_db)) -> Dict: - """List all tools for a user""" - response = db.get(Tool, filters={"user_id": user_id}) - return {"status": True, "data": response.data} - - -@router.get("/{tool_id}") -async def get_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict: - """Get a specific tool""" - response = db.get(Tool, filters={"id": tool_id, "user_id": user_id}) - if not response.status or not response.data: - raise HTTPException(status_code=404, detail="Tool not found") - return {"status": True, "data": response.data[0]} - - -@router.post("/") -async def create_tool(tool: Tool, db=Depends(get_db)) -> Dict: - """Create a new tool""" - response = db.upsert(tool) - if not response.status: - raise HTTPException(status_code=400, detail=response.message) - return {"status": True, "data": response.data} - - -@router.delete("/{tool_id}") -async def delete_tool(tool_id: int, user_id: str, db=Depends(get_db)) -> Dict: - """Delete a tool""" - db.delete(filters={"id": tool_id, "user_id": user_id}, model_class=Tool) - return {"status": True, "message": "Tool deleted successfully"} diff --git a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts index b502f04c892c..b0d82f46e338 100644 --- a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts +++ b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts @@ -97,8 +97,11 @@ export interface SessionRuns { export interface BaseConfig { component_type: string; - version?: string; - description?: string; + version?: number; + description?: string | null; + label?: string; + component_version?: number; + provider: string; } export interface WebSocketMessage { @@ -273,8 +276,18 @@ export interface SelectorGroupChatConfig extends BaseTeamConfig { export type TeamConfig = RoundRobinGroupChatConfig | SelectorGroupChatConfig; +export interface ComponentModel { + provider: string; + component_type: ComponentTypes; + version: number; + component_version: number; + config: any; + description?: string | null; + label?: string; +} + export interface Team extends DBModel { - config: TeamConfig; + config: ComponentModel; } export interface TeamResult { diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts index fd29035eb3fd..a5fd5370c92c 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts @@ -1,4 +1,4 @@ -import type { Team, TeamConfig } from "../../types/datamodel"; +import type { ComponentModel, Team, TeamConfig } from "../../types/datamodel"; export interface TeamEditorProps { team?: Team; @@ -16,29 +16,87 @@ export interface TeamListProps { isLoading?: boolean; } -export const defaultTeamConfig: TeamConfig = { - version: "1.0.0", +export const defaultTeamConfig: ComponentModel = { + provider: "autogen_agentchat.teams.RoundRobinGroupChat", component_type: "team", - name: "default_team", - participants: [ - { - component_type: "agent", - name: "assistant_agent", - agent_type: "AssistantAgent", - system_message: - "You are a helpful assistant. Solve tasks carefully. When done respond with TERMINATE", - model_client: { - component_type: "model", - model: "gpt-4o-2024-08-06", - model_type: "OpenAIChatCompletionClient", + version: 1, + component_version: 1, + description: null, + config: { + participants: [ + { + provider: "autogen_agentchat.agents.AssistantAgent", + component_type: "agent", + version: 1, + component_version: 1, + config: { + name: "weather_agent", + model_client: { + provider: "autogen_ext.models.openai.OpenAIChatCompletionClient", + component_type: "model", + version: 1, + component_version: 1, + config: { model: "gpt-4o-mini" }, + }, + tools: [ + { + provider: "autogen_core.tools.FunctionTool", + component_type: "tool", + version: 1, + component_version: 1, + config: { + source_code: + 'async def get_weather(city: str) -> str:\n return f"The weather in {city} is 73 degrees and Sunny."\n', + name: "get_weather", + description: "", + global_imports: [], + has_cancellation_support: false, + }, + }, + ], + handoffs: [], + model_context: { + provider: + "autogen_core.model_context.UnboundedChatCompletionContext", + component_type: "chat_completion_context", + version: 1, + component_version: 1, + config: {}, + }, + description: + "An agent that provides assistance with ability to use tools.", + system_message: + "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", + reflect_on_tool_use: false, + tool_call_summary_format: "{result}", + }, + }, + ], + termination_condition: { + provider: "autogen_agentchat.base.OrTerminationCondition", + component_type: "termination", + version: 1, + component_version: 1, + config: { + conditions: [ + { + provider: "autogen_agentchat.conditions.MaxMessageTermination", + component_type: "termination", + version: 1, + component_version: 1, + config: { max_messages: 10 }, + }, + { + provider: "autogen_agentchat.conditions.TextMentionTermination", + component_type: "termination", + version: 1, + component_version: 1, + config: { text: "TERMINATE" }, + }, + ], }, }, - ], - team_type: "RoundRobinGroupChat", - termination_condition: { - component_type: "termination", - termination_type: "MaxMessageTermination", - max_messages: 3, + max_turns: 1, }, }; diff --git a/python/packages/autogen-studio/tests/test_component_factory.py b/python/packages/autogen-studio/tests/test_component_factory.py deleted file mode 100644 index c5339cc70721..000000000000 --- a/python/packages/autogen-studio/tests/test_component_factory.py +++ /dev/null @@ -1,397 +0,0 @@ -import pytest -from typing import List - -from autogen_agentchat.agents import AssistantAgent -from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat -from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination -from autogen_core.tools import FunctionTool - -from autogenstudio.datamodel.types import ( - AssistantAgentConfig, - OpenAIModelConfig, - RoundRobinTeamConfig, - SelectorTeamConfig, - MagenticOneTeamConfig, - ToolConfig, - MaxMessageTerminationConfig, - StopMessageTerminationConfig, - TextMentionTerminationConfig, - CombinationTerminationConfig, - ModelTypes, - AgentTypes, - TeamTypes, - TerminationTypes, - ToolTypes, - ComponentTypes, -) -from autogenstudio.database import ComponentFactory - - -@pytest.fixture -def component_factory(): - return ComponentFactory() - - -@pytest.fixture -def sample_tool_config(): - return ToolConfig( - name="calculator", - description="A simple calculator function", - content=""" -def calculator(a: int, b: int, operation: str = '+') -> int: - ''' - A simple calculator that performs basic operations - ''' - if operation == '+': - return a + b - elif operation == '-': - return a - b - elif operation == '*': - return a * b - elif operation == '/': - return a / b - else: - raise ValueError("Invalid operation") -""", - tool_type=ToolTypes.PYTHON_FUNCTION, - component_type=ComponentTypes.TOOL, - version="1.0.0", - ) - - -@pytest.fixture -def sample_model_config(): - return OpenAIModelConfig( - model_type=ModelTypes.OPENAI, - model="gpt-4", - api_key="test-key", - component_type=ComponentTypes.MODEL, - version="1.0.0", - ) - - -@pytest.fixture -def sample_agent_config(sample_model_config: OpenAIModelConfig, sample_tool_config: ToolConfig): - return AssistantAgentConfig( - name="test_agent", - agent_type=AgentTypes.ASSISTANT, - system_message="You are a helpful assistant", - model_client=sample_model_config, - tools=[sample_tool_config], - component_type=ComponentTypes.AGENT, - version="1.0.0", - ) - - -@pytest.fixture -def sample_termination_config(): - return MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=10, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - - -@pytest.fixture -def sample_team_config( - sample_agent_config: AssistantAgentConfig, sample_termination_config: MaxMessageTerminationConfig, sample_model_config: OpenAIModelConfig -): - return RoundRobinTeamConfig( - name="test_team", - team_type=TeamTypes.ROUND_ROBIN, - participants=[sample_agent_config], - termination_condition=sample_termination_config, - model_client=sample_model_config, - component_type=ComponentTypes.TEAM, - max_turns=10, - version="1.0.0", - ) - - -@pytest.mark.asyncio -async def test_load_tool(component_factory: ComponentFactory, sample_tool_config: ToolConfig): - # Test loading tool from ToolConfig - tool = await component_factory.load_tool(sample_tool_config) - assert isinstance(tool, FunctionTool) - assert tool.name == "calculator" - assert tool.description == "A simple calculator function" - - # Test tool functionality - result = tool._func(5, 3, "+") - assert result == 8 - - -@pytest.mark.asyncio -async def test_load_tool_invalid_config(component_factory: ComponentFactory): - # Test with missing required fields - with pytest.raises(ValueError): - await component_factory.load_tool( - ToolConfig( - name="test", - description="", - content="", - tool_type=ToolTypes.PYTHON_FUNCTION, - component_type=ComponentTypes.TOOL, - version="1.0.0", - ) - ) - - # Test with invalid Python code - invalid_config = ToolConfig( - name="invalid", - description="Invalid function", - content="def invalid_func(): return invalid syntax", - tool_type=ToolTypes.PYTHON_FUNCTION, - component_type=ComponentTypes.TOOL, - version="1.0.0", - ) - with pytest.raises(ValueError): - await component_factory.load_tool(invalid_config) - - -@pytest.mark.asyncio -async def test_load_model(component_factory: ComponentFactory, sample_model_config: OpenAIModelConfig): - # Test loading model from ModelConfig - model = await component_factory.load_model(sample_model_config) - assert model is not None - - -@pytest.mark.asyncio -async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AssistantAgentConfig): - # Test loading agent from AgentConfig - agent = await component_factory.load_agent(sample_agent_config) - assert isinstance(agent, AssistantAgent) - assert agent.name == "test_agent" - assert len(agent._tools) == 1 - - -@pytest.mark.asyncio -async def test_load_termination(component_factory: ComponentFactory): - - max_msg_config = MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=5, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - termination = await component_factory.load_termination(max_msg_config) - assert isinstance(termination, MaxMessageTermination) - assert termination._max_messages == 5 - - # Test StopMessageTermination - stop_msg_config = StopMessageTerminationConfig( - termination_type=TerminationTypes.STOP_MESSAGE, component_type=ComponentTypes.TERMINATION, version="1.0.0" - ) - termination = await component_factory.load_termination(stop_msg_config) - assert isinstance(termination, StopMessageTermination) - - # Test TextMentionTermination - text_mention_config = TextMentionTerminationConfig( - termination_type=TerminationTypes.TEXT_MENTION, - text="DONE", - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - termination = await component_factory.load_termination(text_mention_config) - assert isinstance(termination, TextMentionTermination) - assert termination._text == "DONE" - - # Test AND combination - and_combo_config = CombinationTerminationConfig( - termination_type=TerminationTypes.COMBINATION, - operator="and", - conditions=[ - MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=5, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - TextMentionTerminationConfig( - termination_type=TerminationTypes.TEXT_MENTION, - text="DONE", - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - ], - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - termination = await component_factory.load_termination(and_combo_config) - assert termination is not None - - # Test OR combination - or_combo_config = CombinationTerminationConfig( - termination_type=TerminationTypes.COMBINATION, - operator="or", - conditions=[ - MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=5, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - TextMentionTerminationConfig( - termination_type=TerminationTypes.TEXT_MENTION, - text="DONE", - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - ], - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - termination = await component_factory.load_termination(or_combo_config) - assert termination is not None - - # Test invalid combinations - with pytest.raises(ValueError): - await component_factory.load_termination( - CombinationTerminationConfig( - termination_type=TerminationTypes.COMBINATION, - conditions=[], # Empty conditions - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - ) - - with pytest.raises(ValueError): - await component_factory.load_termination( - CombinationTerminationConfig( - termination_type=TerminationTypes.COMBINATION, - operator="invalid", # type: ignore - conditions=[ - MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=5, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - ], - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - ) - - # Test missing operator - with pytest.raises(ValueError): - await component_factory.load_termination( - CombinationTerminationConfig( - termination_type=TerminationTypes.COMBINATION, - conditions=[ - MaxMessageTerminationConfig( - termination_type=TerminationTypes.MAX_MESSAGES, - max_messages=5, - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - TextMentionTerminationConfig( - termination_type=TerminationTypes.TEXT_MENTION, - text="DONE", - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ), - ], - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - ) - - -@pytest.mark.asyncio -async def test_load_team( - component_factory: ComponentFactory, sample_team_config: RoundRobinTeamConfig, sample_model_config: OpenAIModelConfig -): - # Test loading RoundRobinGroupChat team - team = await component_factory.load_team(sample_team_config) - assert isinstance(team, RoundRobinGroupChat) - assert len(team._participants) == 1 - - # Test loading SelectorGroupChat team with multiple participants - selector_team_config = SelectorTeamConfig( - name="selector_team", - team_type=TeamTypes.SELECTOR, - participants=[ # Add two participants - sample_team_config.participants[0], # First agent - AssistantAgentConfig( # Second agent - name="test_agent_2", - agent_type=AgentTypes.ASSISTANT, - system_message="You are another helpful assistant", - model_client=sample_model_config, - tools=sample_team_config.participants[0].tools, - component_type=ComponentTypes.AGENT, - version="1.0.0", - ), - ], - termination_condition=sample_team_config.termination_condition, - model_client=sample_model_config, - component_type=ComponentTypes.TEAM, - version="1.0.0", - ) - team = await component_factory.load_team(selector_team_config) - assert isinstance(team, SelectorGroupChat) - assert len(team._participants) == 2 - - # Test loading MagenticOneGroupChat team - magentic_one_config = MagenticOneTeamConfig( - name="magentic_one_team", - team_type=TeamTypes.MAGENTIC_ONE, - participants=[ # Add two participants - sample_team_config.participants[0], # First agent - AssistantAgentConfig( # Second agent - name="test_agent_2", - agent_type=AgentTypes.ASSISTANT, - system_message="You are another helpful assistant", - model_client=sample_model_config, - tools=sample_team_config.participants[0].tools, - component_type=ComponentTypes.AGENT, - max_turns=sample_team_config.max_turns, - version="1.0.0", - ), - ], - termination_condition=sample_team_config.termination_condition, - model_client=sample_model_config, - component_type=ComponentTypes.TEAM, - version="1.0.0", - ) - team = await component_factory.load_team(magentic_one_config) - assert isinstance(team, MagenticOneGroupChat) - assert len(team._participants) == 2 - - -@pytest.mark.asyncio -async def test_invalid_configs(component_factory: ComponentFactory): - # Test invalid agent type - with pytest.raises(ValueError): - await component_factory.load_agent( - AssistantAgentConfig( - name="test", - agent_type="InvalidAgent", # type: ignore - system_message="test", - component_type=ComponentTypes.AGENT, - version="1.0.0", - ) - ) - - # Test invalid team type - with pytest.raises(ValueError): - await component_factory.load_team( - RoundRobinTeamConfig( - name="test", - team_type="InvalidTeam", # type: ignore - participants=[], - component_type=ComponentTypes.TEAM, - version="1.0.0", - ) - ) - - # Test invalid termination type - with pytest.raises(ValueError): - await component_factory.load_termination( - MaxMessageTerminationConfig( - termination_type="InvalidTermination", # type: ignore - component_type=ComponentTypes.TERMINATION, - version="1.0.0", - ) - ) From ebe06f4d60a193880275382296f8f90b8750d725 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 26 Jan 2025 11:03:03 -0800 Subject: [PATCH 06/16] initial updates --- .../frontend/src/components/views/team/types.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts index a5fd5370c92c..399863cc789c 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts @@ -1,4 +1,4 @@ -import type { ComponentModel, Team, TeamConfig } from "../../types/datamodel"; +import type { Team, TeamConfig } from "../../types/datamodel"; export interface TeamEditorProps { team?: Team; @@ -16,7 +16,7 @@ export interface TeamListProps { isLoading?: boolean; } -export const defaultTeamConfig: ComponentModel = { +export const defaultTeamConfig: TeamConfig = { provider: "autogen_agentchat.teams.RoundRobinGroupChat", component_type: "team", version: 1, From d0377798b09cb438b715e5a0713f5cca2a3ad62d Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 26 Jan 2025 20:22:23 -0800 Subject: [PATCH 07/16] v2 on ags migration --- .../autogenstudio/datamodel/db.py | 4 +- .../src/components/types/datamodel.ts | 339 ++++++++-------- .../frontend/src/components/types/guards.ts | 165 ++++++++ .../src/components/views/gallery/store.tsx | 13 +- .../src/components/views/gallery/types.ts | 11 +- .../src/components/views/gallery/utils.ts | 375 ++++++++++++------ .../components/views/team/builder/library.tsx | 12 +- .../src/components/views/team/hooks.tsx | 39 ++ .../src/components/views/team/manager.tsx | 2 +- .../src/components/views/team/sidebar.tsx | 27 +- .../src/components/views/team/types.ts | 10 +- 11 files changed, 651 insertions(+), 346 deletions(-) create mode 100644 python/packages/autogen-studio/frontend/src/components/types/guards.ts create mode 100644 python/packages/autogen-studio/frontend/src/components/views/team/hooks.tsx diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/db.py b/python/packages/autogen-studio/autogenstudio/datamodel/db.py index 36905623b399..f3ca24b82744 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/db.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/db.py @@ -25,7 +25,7 @@ class Team(SQLModel, table=True): ) # pylint: disable=not-callable user_id: Optional[str] = None version: Optional[str] = "0.0.1" - config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON)) + component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON)) @@ -42,7 +42,7 @@ class Message(SQLModel, table=True): ) # pylint: disable=not-callable user_id: Optional[str] = None version: Optional[str] = "0.0.1" - config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON)) + component: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON)) session_id: Optional[int] = Field( default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE")) ) diff --git a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts index b0d82f46e338..76717271885b 100644 --- a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts +++ b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts @@ -1,3 +1,21 @@ +// Base Component System +export interface Component { + provider: string; + component_type: + | "team" + | "agent" + | "model" + | "tool" + | "termination" + | "chat_completion_context"; + version?: number; + component_version?: number; + description?: string | null; + config: T; + label?: string; +} + +// Message Types export interface RequestUsage { prompt_tokens: number; completion_tokens: number; @@ -11,7 +29,7 @@ export interface ImageContent { export interface FunctionCall { id: string; - arguments: string; // JSON string + arguments: string; name: string; } @@ -20,13 +38,11 @@ export interface FunctionExecutionResult { content: string; } -// Base message configuration (maps to Python BaseMessage) export interface BaseMessageConfig { source: string; models_usage?: RequestUsage; } -// Message configurations (mapping directly to Python classes) export interface TextMessageConfig extends BaseMessageConfig { content: string; } @@ -52,17 +68,6 @@ export interface ToolCallResultMessageConfig extends BaseMessageConfig { content: FunctionExecutionResult[]; } -// Message type unions (matching Python type aliases) -export type InnerMessageConfig = - | ToolCallMessageConfig - | ToolCallResultMessageConfig; - -export type ChatMessageConfig = - | TextMessageConfig - | MultiModalMessageConfig - | StopMessageConfig - | HandoffMessageConfig; - export type AgentMessageConfig = | TextMessageConfig | MultiModalMessageConfig @@ -71,223 +76,191 @@ export type AgentMessageConfig = | ToolCallMessageConfig | ToolCallResultMessageConfig; -// Database model -export interface DBModel { - id?: number; - user_id?: string; - created_at?: string; - updated_at?: string; - version?: number; +// Tool Configs +export interface FunctionToolConfig { + source_code: string; + name: string; + description: string; + global_imports: any[]; // Sequence[Import] equivalent + has_cancellation_support: boolean; } -export interface Message extends DBModel { - config: AgentMessageConfig; - session_id: number; - run_id: string; +// Provider-based Configs +export interface SelectorGroupChatConfig { + participants: Component[]; + model_client: Component; + termination_condition?: Component; + max_turns?: number; + selector_prompt: string; + allow_repeated_speaker: boolean; } -export interface Session extends DBModel { - name: string; - team_id?: number; +export interface RoundRobinGroupChatConfig { + participants: Component[]; + termination_condition?: Component; + max_turns?: number; } -export interface SessionRuns { - runs: Run[]; +export interface MultimodalWebSurferConfig { + name: string; + model_client: Component; + downloads_folder?: string; + description?: string; + debug_dir?: string; + headless?: boolean; + start_page?: string; + animate_actions?: boolean; + to_save_screenshots?: boolean; + use_ocr?: boolean; + browser_channel?: string; + browser_data_dir?: string; + to_resize_viewport?: boolean; +} + +export interface AssistantAgentConfig { + name: string; + model_client: Component; + tools?: Component[]; + handoffs?: any[]; // HandoffBase | str equivalent + model_context?: Component; + description: string; + system_message?: string; + reflect_on_tool_use: boolean; + tool_call_summary_format: string; } -export interface BaseConfig { - component_type: string; - version?: number; - description?: string | null; - label?: string; - component_version?: number; - provider: string; +export interface UserProxyAgentConfig { + name: string; + description: string; } -export interface WebSocketMessage { - type: "message" | "result" | "completion" | "input_request" | "error"; - data?: AgentMessageConfig | TaskResult; - status?: RunStatus; - error?: string; - timestamp?: string; +// Model Configs +export interface ModelInfo { + vision: boolean; + function_calling: boolean; + json_output: boolean; + family: string; } -export interface TaskResult { - messages: AgentMessageConfig[]; - stop_reason?: string; +export interface CreateArgumentsConfig { + frequency_penalty?: number; + logit_bias?: Record; + max_tokens?: number; + n?: number; + presence_penalty?: number; + response_format?: any; // ResponseFormat equivalent + seed?: number; + stop?: string | string[]; + temperature?: number; + top_p?: number; + user?: string; } -export type ModelTypes = - | "OpenAIChatCompletionClient" - | "AzureOpenAIChatCompletionClient"; - -export type AgentTypes = - | "AssistantAgent" - | "UserProxyAgent" - | "MultimodalWebSurfer" - | "FileSurfer" - | "MagenticOneCoderAgent"; - -export type ToolTypes = "PythonFunction"; - -export type TeamTypes = - | "RoundRobinGroupChat" - | "SelectorGroupChat" - | "MagenticOneGroupChat"; - -export type TerminationTypes = - | "MaxMessageTermination" - | "StopMessageTermination" - | "TextMentionTermination" - | "TimeoutTermination" - | "CombinationTermination"; - -export type ComponentTypes = - | "team" - | "agent" - | "model" - | "tool" - | "termination"; - -export type ComponentConfigTypes = - | TeamConfig - | AgentConfig - | ModelConfig - | ToolConfig - | TerminationConfig; - -export interface BaseModelConfig extends BaseConfig { +export interface BaseOpenAIClientConfig extends CreateArgumentsConfig { model: string; - model_type: ModelTypes; api_key?: string; - base_url?: string; + timeout?: number; + max_retries?: number; + model_capabilities?: any; // ModelCapabilities equivalent + model_info?: ModelInfo; } -export interface AzureOpenAIModelConfig extends BaseModelConfig { - model_type: "AzureOpenAIChatCompletionClient"; - azure_deployment: string; - api_version: string; - azure_endpoint: string; - azure_ad_token_provider: string; +export interface OpenAIClientConfig extends BaseOpenAIClientConfig { + organization?: string; + base_url?: string; } -export interface OpenAIModelConfig extends BaseModelConfig { - model_type: "OpenAIChatCompletionClient"; +export interface AzureOpenAIClientConfig extends BaseOpenAIClientConfig { + azure_endpoint: string; + azure_deployment?: string; + api_version: string; + azure_ad_token?: string; + azure_ad_token_provider?: Component; } -export type ModelConfig = AzureOpenAIModelConfig | OpenAIModelConfig; - -export interface BaseToolConfig extends BaseConfig { - name: string; - description: string; - content: string; - tool_type: ToolTypes; +export interface UnboundedChatCompletionContextConfig { + // Empty in example but could have props } -export interface PythonFunctionToolConfig extends BaseToolConfig { - tool_type: "PythonFunction"; +export interface OrTerminationConfig { + conditions: Component[]; } -export type ToolConfig = PythonFunctionToolConfig; - -export interface BaseAgentConfig extends BaseConfig { - name: string; - agent_type: AgentTypes; - system_message?: string; - model_client?: ModelConfig; - tools?: ToolConfig[]; - description?: string; +export interface MaxMessageTerminationConfig { + max_messages: number; } -export interface AssistantAgentConfig extends BaseAgentConfig { - agent_type: "AssistantAgent"; +export interface TextMentionTerminationConfig { + text: string; } -export interface UserProxyAgentConfig extends BaseAgentConfig { - agent_type: "UserProxyAgent"; -} +// Config type unions based on provider +export type TeamConfig = SelectorGroupChatConfig | RoundRobinGroupChatConfig; -export interface MultimodalWebSurferAgentConfig extends BaseAgentConfig { - agent_type: "MultimodalWebSurfer"; -} +export type AgentConfig = + | MultimodalWebSurferConfig + | AssistantAgentConfig + | UserProxyAgentConfig; -export interface FileSurferAgentConfig extends BaseAgentConfig { - agent_type: "FileSurfer"; -} +export type ModelConfig = OpenAIClientConfig | AzureOpenAIClientConfig; -export interface MagenticOneCoderAgentConfig extends BaseAgentConfig { - agent_type: "MagenticOneCoderAgent"; -} +export type ToolConfig = FunctionToolConfig; -export type AgentConfig = - | AssistantAgentConfig - | UserProxyAgentConfig - | MultimodalWebSurferAgentConfig - | FileSurferAgentConfig - | MagenticOneCoderAgentConfig; +export type ChatCompletionContextConfig = UnboundedChatCompletionContextConfig; -// export interface TerminationConfig extends BaseConfig { -// termination_type: TerminationTypes; -// max_messages?: number; -// text?: string; -// } +export type TerminationConfig = + | OrTerminationConfig + | MaxMessageTerminationConfig + | TextMentionTerminationConfig; -export interface BaseTerminationConfig extends BaseConfig { - termination_type: TerminationTypes; -} +export type ComponentConfig = + | TeamConfig + | AgentConfig + | ModelConfig + | ToolConfig + | TerminationConfig + | ChatCompletionContextConfig; -export interface MaxMessageTerminationConfig extends BaseTerminationConfig { - termination_type: "MaxMessageTermination"; - max_messages: number; +// DB Models +export interface DBModel { + id?: number; + user_id?: string; + created_at?: string; + updated_at?: string; + version?: number; } -export interface TextMentionTerminationConfig extends BaseTerminationConfig { - termination_type: "TextMentionTermination"; - text: string; +export interface Message extends DBModel { + config: AgentMessageConfig; + session_id: number; + run_id: string; } -export interface CombinationTerminationConfig extends BaseTerminationConfig { - termination_type: "CombinationTermination"; - operator: "and" | "or"; - conditions: TerminationConfig[]; +export interface Team extends DBModel { + component: Component; } -export type TerminationConfig = - | MaxMessageTerminationConfig - | TextMentionTerminationConfig - | CombinationTerminationConfig; - -export interface BaseTeamConfig extends BaseConfig { +export interface Session extends DBModel { name: string; - participants: AgentConfig[]; - team_type: TeamTypes; - termination_condition?: TerminationConfig; -} - -export interface RoundRobinGroupChatConfig extends BaseTeamConfig { - team_type: "RoundRobinGroupChat"; + team_id?: number; } -export interface SelectorGroupChatConfig extends BaseTeamConfig { - team_type: "SelectorGroupChat"; - selector_prompt: string; - model_client: ModelConfig; +// Runtime Types +export interface SessionRuns { + runs: Run[]; } -export type TeamConfig = RoundRobinGroupChatConfig | SelectorGroupChatConfig; - -export interface ComponentModel { - provider: string; - component_type: ComponentTypes; - version: number; - component_version: number; - config: any; - description?: string | null; - label?: string; +export interface WebSocketMessage { + type: "message" | "result" | "completion" | "input_request" | "error"; + data?: AgentMessageConfig | TaskResult; + status?: RunStatus; + error?: string; + timestamp?: string; } -export interface Team extends DBModel { - config: ComponentModel; +export interface TaskResult { + messages: AgentMessageConfig[]; + stop_reason?: string; } export interface TeamResult { @@ -303,13 +276,13 @@ export interface Run { status: RunStatus; task: AgentMessageConfig; team_result: TeamResult | null; - messages: Message[]; // Change to Message[] + messages: Message[]; error_message?: string; } export type RunStatus = | "created" - | "active" // covers 'streaming' + | "active" | "awaiting_input" | "timeout" | "complete" diff --git a/python/packages/autogen-studio/frontend/src/components/types/guards.ts b/python/packages/autogen-studio/frontend/src/components/types/guards.ts new file mode 100644 index 000000000000..b01c56838018 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/types/guards.ts @@ -0,0 +1,165 @@ +import type { + Component, + ComponentConfig, + TeamConfig, + AgentConfig, + ModelConfig, + ToolConfig, + TerminationConfig, + ChatCompletionContextConfig, +} from "./datamodel"; + +// Provider constants +const PROVIDERS = { + // Teams + ROUND_ROBIN_TEAM: "autogen_agentchat.teams.RoundRobinGroupChat", + SELECTOR_TEAM: "autogen_agentchat.teams.SelectorGroupChat", + + // Agents + ASSISTANT_AGENT: "autogen_agentchat.agents.AssistantAgent", + USER_PROXY: "autogen_agentchat.agents.UserProxyAgent", + WEB_SURFER: "autogen_ext.agents.web_surfer.MultimodalWebSurfer", + + // Models + OPENAI: "autogen_ext.models.openai.OpenAIChatCompletionClient", + AZURE_OPENAI: + "autogen_ext.models.azure_openai.AzureOpenAIChatCompletionClient", + + // Tools + FUNCTION_TOOL: "autogen_core.tools.FunctionTool", + + // Termination + OR_TERMINATION: "autogen_agentchat.base.OrTerminationCondition", + MAX_MESSAGE: "autogen_agentchat.conditions.MaxMessageTermination", + TEXT_MENTION: "autogen_agentchat.conditions.TextMentionTermination", + + // Contexts + UNBOUNDED_CONTEXT: + "autogen_core.model_context.UnboundedChatCompletionContext", +} as const; + +// Base component type guards +export function isTeamComponent( + component: Component +): component is Component { + return component.component_type === "team"; +} + +export function isAgentComponent( + component: Component +): component is Component { + return component.component_type === "agent"; +} + +export function isModelComponent( + component: Component +): component is Component { + return component.component_type === "model"; +} + +export function isToolComponent( + component: Component +): component is Component { + return component.component_type === "tool"; +} + +export function isTerminationComponent( + component: Component +): component is Component { + return component.component_type === "termination"; +} + +export function isChatCompletionContextComponent( + component: Component +): component is Component { + return component.component_type === "chat_completion_context"; +} + +// Team provider guards +export function isRoundRobinTeam( + component: Component +): boolean { + return component.provider === PROVIDERS.ROUND_ROBIN_TEAM; +} + +export function isSelectorTeam(component: Component): boolean { + return component.provider === PROVIDERS.SELECTOR_TEAM; +} + +// Agent provider guards +export function isAssistantAgent( + component: Component +): boolean { + return component.provider === PROVIDERS.ASSISTANT_AGENT; +} + +export function isUserProxyAgent( + component: Component +): boolean { + return component.provider === PROVIDERS.USER_PROXY; +} + +export function isWebSurferAgent( + component: Component +): boolean { + return component.provider === PROVIDERS.WEB_SURFER; +} + +// Model provider guards +export function isOpenAIModel(component: Component): boolean { + return component.provider === PROVIDERS.OPENAI; +} + +export function isAzureOpenAIModel( + component: Component +): boolean { + return component.provider === PROVIDERS.AZURE_OPENAI; +} + +// Tool provider guards +export function isFunctionTool(component: Component): boolean { + return component.provider === PROVIDERS.FUNCTION_TOOL; +} + +// Termination provider guards +export function isOrTermination( + component: Component +): boolean { + return component.provider === PROVIDERS.OR_TERMINATION; +} + +export function isMaxMessageTermination( + component: Component +): boolean { + return component.provider === PROVIDERS.MAX_MESSAGE; +} + +export function isTextMentionTermination( + component: Component +): boolean { + return component.provider === PROVIDERS.TEXT_MENTION; +} + +// Context provider guards +export function isUnboundedContext( + component: Component +): boolean { + return component.provider === PROVIDERS.UNBOUNDED_CONTEXT; +} + +// Helper function for type narrowing +export function assertComponent( + component: Component, + providerCheck: (component: Component) => boolean +): asserts component is Component { + if (!providerCheck(component)) { + throw new Error( + `Component provider ${component.provider} does not match expected type` + ); + } +} + +// Example usage: +// const component: Component = someComponent; +// assertComponent(component, isRoundRobinTeam); +// Now TypeScript knows component is Component diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx index c09096f73c2d..5cd2ffdfb00b 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx @@ -3,6 +3,7 @@ import { persist } from "zustand/middleware"; import { Gallery } from "./types"; import { AgentConfig, + Component, ModelConfig, TeamConfig, TerminationConfig, @@ -25,12 +26,12 @@ interface GalleryStore { syncGallery: (id: string) => Promise; getLastSyncTime: (id: string) => string | null; getGalleryComponents: () => { - teams: TeamConfig[]; + teams: Component[]; components: { - agents: AgentConfig[]; - models: ModelConfig[]; - tools: ToolConfig[]; - terminations: TerminationConfig[]; + agents: Component[]; + models: Component[]; + tools: Component[]; + terminations: Component[]; }; }; } @@ -150,7 +151,7 @@ export const useGalleryStore = create()( }, }), { - name: "gallery-storage", + name: "gallery-storage-v1", } ) ); diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/types.ts b/python/packages/autogen-studio/frontend/src/components/views/gallery/types.ts index 015eb961c926..7db306b904d9 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/types.ts @@ -1,5 +1,6 @@ import { AgentConfig, + Component, ModelConfig, TeamConfig, TerminationConfig, @@ -25,12 +26,12 @@ export interface Gallery { url?: string; metadata: GalleryMetadata; items: { - teams: TeamConfig[]; + teams: Component[]; components: { - agents: AgentConfig[]; - models: ModelConfig[]; - tools: ToolConfig[]; - terminations: TerminationConfig[]; + agents: Component[]; + models: Component[]; + tools: Component[]; + terminations: Component[]; }; }; } diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/utils.ts b/python/packages/autogen-studio/frontend/src/components/views/gallery/utils.ts index 2b0e4ea83e8d..f00958078553 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/utils.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/utils.ts @@ -1,13 +1,15 @@ import { + Component, + RoundRobinGroupChatConfig, AssistantAgentConfig, - CombinationTerminationConfig, + UserProxyAgentConfig, + OpenAIClientConfig, + FunctionToolConfig, MaxMessageTerminationConfig, - OpenAIModelConfig, - PythonFunctionToolConfig, - RoundRobinGroupChatConfig, TextMentionTerminationConfig, - UserProxyAgentConfig, + OrTerminationConfig, } from "../../types/datamodel"; +import { Gallery } from "./types"; export const defaultGallery = { id: "gallery_default", @@ -26,169 +28,290 @@ export const defaultGallery = { items: { teams: [ { + provider: "autogen_agentchat.teams.RoundRobinGroupChat", component_type: "team", + version: 1, + component_version: 1, description: "A team with an assistant agent and a user agent to enable human-in-loop task completion in a round-robin fashion", - name: "huma_in_loop_team", - participants: [ - { - component_type: "agent", - description: - "An assistant agent that can help users complete tasks", - name: "assistant_agent", - agent_type: "AssistantAgent", - system_message: - "You are a helpful assistant. Solve tasks carefully. You also have a calculator tool which you can use if needed. When the task is done respond with TERMINATE.", - model_client: { - component_type: "model", - description: "A GPT-4o mini model", - model: "gpt-4o-mini", - model_type: "OpenAIChatCompletionClient", - }, - tools: [ - { - component_type: "tool", - name: "calculator", - description: - "A simple calculator that performs basic arithmetic operations between two numbers", - content: - "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", - tool_type: "PythonFunction", - }, - ], - }, - { - component_type: "agent", - description: "A user agent that is driven by a human user", - name: "user_agent", - agent_type: "UserProxyAgent", - tools: [], - }, - ], - team_type: "RoundRobinGroupChat", - termination_condition: { - description: - "Terminate the conversation when the user mentions 'TERMINATE' or after 10 messages", - component_type: "termination", - termination_type: "CombinationTermination", - operator: "or", - conditions: [ + label: "RoundRobinGroupChat", + config: { + participants: [ { - component_type: "termination", + provider: "autogen_agentchat.agents.AssistantAgent", + component_type: "agent", + version: 1, + component_version: 1, description: - "Terminate the conversation when the user mentions 'TERMINATE'", - termination_type: "TextMentionTermination", - text: "TERMINATE", + "An assistant agent that can help users complete tasks", + label: "AssistantAgent", + config: { + name: "assistant_agent", + model_client: { + provider: + "autogen_ext.models.openai.OpenAIChatCompletionClient", + component_type: "model", + version: 1, + component_version: 1, + description: "A GPT-4o mini model", + label: "OpenAIChatCompletionClient", + config: { + model: "gpt-4o-mini", + }, + }, + tools: [ + { + provider: "autogen_core.tools.FunctionTool", + component_type: "tool", + version: 1, + component_version: 1, + description: + "Create custom tools by wrapping standard Python functions", + label: "FunctionTool", + config: { + name: "calculator", + description: + "A simple calculator that performs basic arithmetic operations between two numbers", + source_code: + "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", + global_imports: [], + has_cancellation_support: false, + }, + }, + ], + description: + "An agent that provides assistance with ability to use tools", + system_message: + "You are a helpful assistant. Solve tasks carefully. You also have a calculator tool which you can use if needed. When the task is done respond with TERMINATE.", + reflect_on_tool_use: false, + tool_call_summary_format: "{result}", + }, }, { - component_type: "termination", - description: "Terminate the conversation after 10 messages", - termination_type: "MaxMessageTermination", - max_messages: 10, + provider: "autogen_agentchat.agents.UserProxyAgent", + component_type: "agent", + version: 1, + component_version: 1, + description: "A user agent that is driven by a human user", + label: "UserProxyAgent", + config: { + name: "user_agent", + description: "A user agent that is driven by a human user", + }, }, ], - }, - } as RoundRobinGroupChatConfig, + termination_condition: { + provider: "autogen_agentchat.base.OrTerminationCondition", + component_type: "termination", + version: 1, + component_version: 1, + label: "OrTerminationCondition", + config: { + conditions: [ + { + provider: + "autogen_agentchat.conditions.TextMentionTermination", + component_type: "termination", + version: 1, + component_version: 1, + description: + "Terminate the conversation when the user mentions 'TERMINATE'", + label: "TextMentionTermination", + config: { + text: "TERMINATE", + }, + }, + { + provider: + "autogen_agentchat.conditions.MaxMessageTermination", + component_type: "termination", + version: 1, + component_version: 1, + description: "Terminate the conversation after 10 messages", + label: "MaxMessageTermination", + config: { + max_messages: 10, + }, + }, + ], + }, + }, + max_turns: 1, + } as RoundRobinGroupChatConfig, + }, ], components: { agents: [ { + provider: "autogen_agentchat.agents.AssistantAgent", component_type: "agent", + version: 1, + component_version: 1, description: "An assistant agent that can help users complete tasks", - name: "assistant_agent", - agent_type: "AssistantAgent", - system_message: - "You are a helpful assistant. Solve tasks carefully. You also have a calculator tool which you can use if needed. When the task is done respond with TERMINATE.", - model_client: { - component_type: "model", - description: "A GPT-4o mini model", - model: "gpt-4o-mini", - model_type: "OpenAIChatCompletionClient", - }, - tools: [ - { - component_type: "tool", - name: "calculator", - description: - "A simple calculator that performs basic arithmetic operations between two numbers", - content: - "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", - tool_type: "PythonFunction", + label: "AssistantAgent", + config: { + name: "assistant_agent", + model_client: { + provider: "autogen_ext.models.openai.OpenAIChatCompletionClient", + component_type: "model", + version: 1, + component_version: 1, + description: "A GPT-4o mini model", + label: "OpenAIChatCompletionClient", + config: { + model: "gpt-4o-mini", + }, }, - ], - } as AssistantAgentConfig, + tools: [ + { + provider: "autogen_core.tools.FunctionTool", + component_type: "tool", + version: 1, + component_version: 1, + description: + "Create custom tools by wrapping standard Python functions", + label: "FunctionTool", + config: { + name: "calculator", + description: + "A simple calculator that performs basic arithmetic operations", + source_code: + "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", + global_imports: [], + has_cancellation_support: false, + }, + }, + ], + description: + "An agent that provides assistance with ability to use tools", + system_message: + "You are a helpful assistant. Solve tasks carefully. When the task is done respond with TERMINATE.", + reflect_on_tool_use: false, + tool_call_summary_format: "{result}", + } as AssistantAgentConfig, + }, { + provider: "autogen_agentchat.agents.UserProxyAgent", component_type: "agent", + version: 1, + component_version: 1, description: "A user agent that is driven by a human user", - name: "user_agent", - agent_type: "UserProxyAgent", - tools: [], - } as UserProxyAgentConfig, + label: "UserProxyAgent", + config: { + name: "user_agent", + description: "A user agent that is driven by a human user", + } as UserProxyAgentConfig, + }, ], models: [ { + provider: "autogen_ext.models.openai.OpenAIChatCompletionClient", component_type: "model", + version: 1, + component_version: 1, description: "A GPT-4o mini model", - model: "gpt-4o-mini", - model_type: "OpenAIChatCompletionClient", - } as OpenAIModelConfig, + label: "OpenAIChatCompletionClient", + config: { + model: "gpt-4o-mini", + } as OpenAIClientConfig, + }, ], tools: [ { + provider: "autogen_core.tools.FunctionTool", component_type: "tool", - name: "calculator", + version: 1, + component_version: 1, description: - "A simple calculator that performs basic arithmetic operations between two numbers", - content: - "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", - tool_type: "PythonFunction", - } as PythonFunctionToolConfig, + "Create custom tools by wrapping standard Python functions", + label: "FunctionTool", + config: { + name: "calculator", + description: + "A simple calculator that performs basic arithmetic operations", + source_code: + "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == '+':\n return str(a + b)\n elif operator == '-':\n return str(a - b)\n elif operator == '*':\n return str(a * b)\n elif operator == '/':\n if b == 0:\n return 'Error: Division by zero'\n return str(a / b)\n else:\n return 'Error: Invalid operator. Please use +, -, *, or /'\n except Exception as e:\n return f'Error: {str(e)}'", + global_imports: [], + has_cancellation_support: false, + } as FunctionToolConfig, + }, { + provider: "autogen_core.tools.FunctionTool", component_type: "tool", - name: "fetch_website", - description: "Fetch and return the content of a website URL", - content: - "async def fetch_website(url: str) -> str:\n try:\n import requests\n from urllib.parse import urlparse\n \n # Validate URL format\n parsed = urlparse(url)\n if not parsed.scheme or not parsed.netloc:\n return \"Error: Invalid URL format. Please include http:// or https://\"\n \n # Add scheme if not present\n if not url.startswith(('http://', 'https://')): \n url = 'https://' + url\n \n # Set headers to mimic a browser request\n headers = {\n 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'\n }\n \n # Make the request with a timeout\n response = requests.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n \n # Return the text content\n return response.text\n \n except requests.exceptions.Timeout:\n return \"Error: Request timed out\"\n except requests.exceptions.ConnectionError:\n return \"Error: Failed to connect to the website\"\n except requests.exceptions.HTTPError as e:\n return f\"Error: HTTP {e.response.status_code} - {e.response.reason}\"\n except Exception as e:\n return f\"Error: {str(e)}\"", - tool_type: "PythonFunction", - } as PythonFunctionToolConfig, + version: 1, + component_version: 1, + description: + "Create custom tools by wrapping standard Python functions", + label: "FunctionTool", + config: { + name: "fetch_website", + description: "Fetch and return the content of a website URL", + source_code: + "async def fetch_website(url: str) -> str:\n try:\n import requests\n from urllib.parse import urlparse\n \n # Validate URL format\n parsed = urlparse(url)\n if not parsed.scheme or not parsed.netloc:\n return \"Error: Invalid URL format. Please include http:// or https://\"\n \n # Add scheme if not present\n if not url.startswith(('http://', 'https://')): \n url = 'https://' + url\n \n # Set headers to mimic a browser request\n headers = {\n 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'\n }\n \n # Make the request with a timeout\n response = requests.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n \n # Return the text content\n return response.text\n \n except requests.exceptions.Timeout:\n return \"Error: Request timed out\"\n except requests.exceptions.ConnectionError:\n return \"Error: Failed to connect to the website\"\n except requests.exceptions.HTTPError as e:\n return f\"Error: HTTP {e.response.status_code} - {e.response.reason}\"\n except Exception as e:\n return f\"Error: {str(e)}\"", + global_imports: [], + has_cancellation_support: false, + }, + }, ], terminations: [ { + provider: "autogen_agentchat.conditions.TextMentionTermination", component_type: "termination", + version: 1, + component_version: 1, description: "Terminate the conversation when the user mentions 'TERMINATE'", - termination_type: "TextMentionTermination", - text: "TERMINATE", - } as TextMentionTerminationConfig, + label: "TextMentionTermination", + config: { + text: "TERMINATE", + } as TextMentionTerminationConfig, + }, { + provider: "autogen_agentchat.conditions.MaxMessageTermination", component_type: "termination", + version: 1, + component_version: 1, description: "Terminate the conversation after 10 messages", - termination_type: "MaxMessageTermination", - max_messages: 10, - } as MaxMessageTerminationConfig, + label: "MaxMessageTermination", + config: { + max_messages: 10, + } as MaxMessageTerminationConfig, + }, { + provider: "autogen_agentchat.base.OrTerminationCondition", component_type: "termination", - description: - "Terminate the conversation when the user mentions 'TERMINATE' or after 10 messages", - termination_type: "CombinationTermination", - operator: "or", - conditions: [ - { - component_type: "termination", - description: - "Terminate the conversation when the user mentions 'TERMINATE'", - termination_type: "TextMentionTermination", - text: "TERMINATE", - }, - { - component_type: "termination", - description: "Terminate the conversation after 10 messages", - termination_type: "MaxMessageTermination", - max_messages: 10, - }, - ], - } as CombinationTerminationConfig, + version: 1, + component_version: 1, + description: "Terminate on either condition", + label: "OrTerminationCondition", + config: { + conditions: [ + { + provider: "autogen_agentchat.conditions.TextMentionTermination", + component_type: "termination", + version: 1, + component_version: 1, + description: "Terminate on TERMINATE", + label: "TextMentionTermination", + config: { + text: "TERMINATE", + }, + }, + { + provider: "autogen_agentchat.conditions.MaxMessageTermination", + component_type: "termination", + version: 1, + component_version: 1, + description: "Terminate after 10 messages", + label: "MaxMessageTermination", + config: { + max_messages: 10, + }, + }, + ], + } as OrTerminationConfig, + }, ], }, }, -}; +} as Gallery; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx index c5bdfceb1269..a1118d6657c4 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx @@ -91,7 +91,7 @@ export const ComponentLibrary: React.FC = () => { title: "Agents", type: "agent" as ComponentTypes, items: defaultGallery.items.components.agents.map((agent) => ({ - label: agent.name, + label: agent.label, config: agent, })), icon: , @@ -100,7 +100,7 @@ export const ComponentLibrary: React.FC = () => { title: "Models", type: "model" as ComponentTypes, items: defaultGallery.items.components.models.map((model) => ({ - label: `${model.model_type} - ${model.model}`, + label: `${model.component_type} - ${model.config.model}`, config: model, })), icon: , @@ -109,7 +109,7 @@ export const ComponentLibrary: React.FC = () => { title: "Tools", type: "tool" as ComponentTypes, items: defaultGallery.items.components.tools.map((tool) => ({ - label: tool.name, + label: tool.label, config: tool, })), icon: , @@ -119,7 +119,7 @@ export const ComponentLibrary: React.FC = () => { type: "termination" as ComponentTypes, items: defaultGallery.items.components.terminations.map( (termination) => ({ - label: `${termination.termination_type}`, + label: `${termination.label}`, config: termination, }) ), @@ -131,7 +131,7 @@ export const ComponentLibrary: React.FC = () => { const items: CollapseProps["items"] = sections.map((section) => { const filteredItems = section.items.filter((item) => - item.label.toLowerCase().includes(searchTerm.toLowerCase()) + item.label?.toLowerCase().includes(searchTerm.toLowerCase()) ); return { @@ -153,7 +153,7 @@ export const ComponentLibrary: React.FC = () => { id={`${section.title.toLowerCase()}-${itemIndex}`} type={section.type} config={item.config} - label={item.label} + label={item.label || ""} icon={section.icon} /> ))} diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/hooks.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/hooks.tsx new file mode 100644 index 000000000000..8079a6068262 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/team/hooks.tsx @@ -0,0 +1,39 @@ +import { config } from "process"; +import { Team } from "../../types/datamodel"; + +export const useTeam = (team: Team) => { + return { + id: team.id, + name: team.config.config.name, + type: team.config.config.team_type, + participants: team.config.config.participants, + updated: team.updated_at, + config: team.config, + + // Helper methods + setName: (name: string) => { + team.config.config.name = name; + }, + + // Computed properties + agentCount: team.config.config.participants.length, + + // Type guards + isRoundRobin: () => team.config.config.team_type === "RoundRobinGroupChat", + isSelector: () => team.config.config.team_type === "SelectorGroupChat", + }; +}; + +// For creating new teams +export const useTeamCreation = () => { + const createTeamName = () => `new_team_${new Date().getTime()}`; + + return { + createTeamName, + initNewTeam: (baseTeam: Team) => { + const team = Object.assign({}, baseTeam); + team.config.config.name = createTeamName(); + return team; + }, + }; +}; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx index dca1fafe6b02..6191c553a00f 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx @@ -201,7 +201,7 @@ export const TeamManager: React.FC = () => { <> - {currentTeam.config.name} + {currentTeam.component.label} {currentTeam.id ? ( "" ) : ( diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx index 42f6a9c3b3ae..39cac6bc321d 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx @@ -42,10 +42,11 @@ export const TeamSidebar: React.FC = ({ isLoading = false, }) => { const defaultGallery = useGalleryStore((state) => state.getDefaultGallery()); + console.log(defaultGallery); const createTeam = () => { const newTeam = Object.assign({}, defaultTeam); - newTeam.config.name = "new_team_" + new Date().getTime(); + newTeam.component.label = "new_team_" + new Date().getTime(); onCreateTeam(newTeam); }; // Render collapsed state @@ -161,7 +162,7 @@ export const TeamSidebar: React.FC = ({ {/* Team Name and Actions Row */}
- {team.config.name} + {team.component.label}
{/* @@ -195,13 +196,13 @@ export const TeamSidebar: React.FC = ({ {/* Team Metadata Row */}
- {team.config.team_type} + {team.component.component_type}
- {team.config.participants.length}{" "} - {team.config.participants.length === 1 + {team.component.config.participants.length}{" "} + {team.component.config.participants.length === 1 ? "agent" : "agents"} @@ -232,7 +233,7 @@ export const TeamSidebar: React.FC = ({
{defaultGallery?.items.teams.map((galleryTeam) => (
= ({ {/* Team Name and Use Template Action */}
- {galleryTeam.name} + {galleryTeam.label}
@@ -254,10 +255,10 @@ export const TeamSidebar: React.FC = ({ icon={} onClick={(e) => { e.stopPropagation(); - galleryTeam.name = - galleryTeam.name + "_" + new Date().getTime(); + galleryTeam.label = + galleryTeam.label + "_" + new Date().getTime(); onCreateTeam({ - config: galleryTeam, + component: galleryTeam, }); }} /> @@ -268,13 +269,13 @@ export const TeamSidebar: React.FC = ({ {/* Team Metadata Row */}
- {galleryTeam.team_type} + {galleryTeam.component_type}
- {galleryTeam.participants.length}{" "} - {galleryTeam.participants.length === 1 + {galleryTeam.config.participants.length}{" "} + {galleryTeam.config.participants.length === 1 ? "agent" : "agents"} diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts index 399863cc789c..fe4a5fa28928 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts @@ -1,4 +1,4 @@ -import type { Team, TeamConfig } from "../../types/datamodel"; +import type { Component, Team, TeamConfig } from "../../types/datamodel"; export interface TeamEditorProps { team?: Team; @@ -16,12 +16,14 @@ export interface TeamListProps { isLoading?: boolean; } -export const defaultTeamConfig: TeamConfig = { +export const defaultTeamConfig: Component = { provider: "autogen_agentchat.teams.RoundRobinGroupChat", component_type: "team", version: 1, component_version: 1, - description: null, + description: + "A team of agents that chat with users in a round-robin fashion.", + label: "General Team", config: { participants: [ { @@ -101,5 +103,5 @@ export const defaultTeamConfig: TeamConfig = { }; export const defaultTeam: Team = { - config: defaultTeamConfig, + component: defaultTeamConfig, }; From 0c022fe5fce74e91744e207bfb11130cded1aa00 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Mon, 27 Jan 2025 12:59:18 -0800 Subject: [PATCH 08/16] inititial ui updates --- .../frontend/src/components/types/guards.ts | 137 +++++++++++++----- .../components/views/team/builder/builder.tsx | 2 +- .../components/views/team/builder/store.tsx | 65 ++++----- .../components/views/team/builder/types.ts | 8 +- .../src/components/views/team/sidebar.tsx | 2 - 5 files changed, 131 insertions(+), 83 deletions(-) diff --git a/python/packages/autogen-studio/frontend/src/components/types/guards.ts b/python/packages/autogen-studio/frontend/src/components/types/guards.ts index b01c56838018..5264ff556993 100644 --- a/python/packages/autogen-studio/frontend/src/components/types/guards.ts +++ b/python/packages/autogen-studio/frontend/src/components/types/guards.ts @@ -7,6 +7,18 @@ import type { ToolConfig, TerminationConfig, ChatCompletionContextConfig, + SelectorGroupChatConfig, + RoundRobinGroupChatConfig, + MultimodalWebSurferConfig, + AssistantAgentConfig, + UserProxyAgentConfig, + OpenAIClientConfig, + AzureOpenAIClientConfig, + FunctionToolConfig, + OrTerminationConfig, + MaxMessageTerminationConfig, + TextMentionTerminationConfig, + UnboundedChatCompletionContextConfig, } from "./datamodel"; // Provider constants @@ -38,6 +50,48 @@ const PROVIDERS = { "autogen_core.model_context.UnboundedChatCompletionContext", } as const; +// Provider type and mapping +export type Provider = (typeof PROVIDERS)[keyof typeof PROVIDERS]; + +type ProviderToConfig = { + // Teams + [PROVIDERS.SELECTOR_TEAM]: SelectorGroupChatConfig; + [PROVIDERS.ROUND_ROBIN_TEAM]: RoundRobinGroupChatConfig; + + // Agents + [PROVIDERS.ASSISTANT_AGENT]: AssistantAgentConfig; + [PROVIDERS.USER_PROXY]: UserProxyAgentConfig; + [PROVIDERS.WEB_SURFER]: MultimodalWebSurferConfig; + + // Models + [PROVIDERS.OPENAI]: OpenAIClientConfig; + [PROVIDERS.AZURE_OPENAI]: AzureOpenAIClientConfig; + + // Tools + [PROVIDERS.FUNCTION_TOOL]: FunctionToolConfig; + + // Termination + [PROVIDERS.OR_TERMINATION]: OrTerminationConfig; + [PROVIDERS.MAX_MESSAGE]: MaxMessageTerminationConfig; + [PROVIDERS.TEXT_MENTION]: TextMentionTerminationConfig; + + // Contexts + [PROVIDERS.UNBOUNDED_CONTEXT]: UnboundedChatCompletionContextConfig; +}; + +// Helper type to get config type from provider +type ConfigForProvider

= P extends keyof ProviderToConfig + ? ProviderToConfig[P] + : never; + +// Generic component type guard +function isComponentOfType

( + component: Component, + provider: P +): component is Component> { + return component.provider === provider; +} + // Base component type guards export function isTeamComponent( component: Component @@ -75,91 +129,94 @@ export function isChatCompletionContextComponent( return component.component_type === "chat_completion_context"; } -// Team provider guards +// Team provider guards with proper type narrowing export function isRoundRobinTeam( component: Component -): boolean { - return component.provider === PROVIDERS.ROUND_ROBIN_TEAM; +): component is Component { + return isComponentOfType(component, PROVIDERS.ROUND_ROBIN_TEAM); } -export function isSelectorTeam(component: Component): boolean { - return component.provider === PROVIDERS.SELECTOR_TEAM; +export function isSelectorTeam( + component: Component +): component is Component { + return isComponentOfType(component, PROVIDERS.SELECTOR_TEAM); } -// Agent provider guards +// Agent provider guards with proper type narrowing export function isAssistantAgent( component: Component -): boolean { - return component.provider === PROVIDERS.ASSISTANT_AGENT; +): component is Component { + return isComponentOfType(component, PROVIDERS.ASSISTANT_AGENT); } export function isUserProxyAgent( component: Component -): boolean { - return component.provider === PROVIDERS.USER_PROXY; +): component is Component { + return isComponentOfType(component, PROVIDERS.USER_PROXY); } export function isWebSurferAgent( component: Component -): boolean { - return component.provider === PROVIDERS.WEB_SURFER; +): component is Component { + return isComponentOfType(component, PROVIDERS.WEB_SURFER); } -// Model provider guards -export function isOpenAIModel(component: Component): boolean { - return component.provider === PROVIDERS.OPENAI; +// Model provider guards with proper type narrowing +export function isOpenAIModel( + component: Component +): component is Component { + return isComponentOfType(component, PROVIDERS.OPENAI); } export function isAzureOpenAIModel( component: Component -): boolean { - return component.provider === PROVIDERS.AZURE_OPENAI; +): component is Component { + return isComponentOfType(component, PROVIDERS.AZURE_OPENAI); } -// Tool provider guards -export function isFunctionTool(component: Component): boolean { - return component.provider === PROVIDERS.FUNCTION_TOOL; +// Tool provider guards with proper type narrowing +export function isFunctionTool( + component: Component +): component is Component { + return isComponentOfType(component, PROVIDERS.FUNCTION_TOOL); } -// Termination provider guards +// Termination provider guards with proper type narrowing export function isOrTermination( component: Component -): boolean { - return component.provider === PROVIDERS.OR_TERMINATION; +): component is Component { + return isComponentOfType(component, PROVIDERS.OR_TERMINATION); } export function isMaxMessageTermination( component: Component -): boolean { - return component.provider === PROVIDERS.MAX_MESSAGE; +): component is Component { + return isComponentOfType(component, PROVIDERS.MAX_MESSAGE); } export function isTextMentionTermination( component: Component -): boolean { - return component.provider === PROVIDERS.TEXT_MENTION; +): component is Component { + return isComponentOfType(component, PROVIDERS.TEXT_MENTION); } -// Context provider guards +// Context provider guards with proper type narrowing export function isUnboundedContext( component: Component -): boolean { - return component.provider === PROVIDERS.UNBOUNDED_CONTEXT; +): component is Component { + return isComponentOfType(component, PROVIDERS.UNBOUNDED_CONTEXT); } -// Helper function for type narrowing -export function assertComponent( +// Runtime assertions +export function assertComponentType

( component: Component, - providerCheck: (component: Component) => boolean -): asserts component is Component { - if (!providerCheck(component)) { + provider: P +): asserts component is Component> { + if (!isComponentOfType(component, provider)) { throw new Error( - `Component provider ${component.provider} does not match expected type` + `Expected component with provider ${provider}, got ${component.provider}` ); } } -// Example usage: -// const component: Component = someComponent; -// assertComponent(component, isRoundRobinTeam); -// Now TypeScript knows component is Component +export { PROVIDERS }; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx index 6abcb4cfcf69..4369916c53a6 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx @@ -113,7 +113,7 @@ export const TeamBuilder: React.FC = ({ // Load initial config React.useEffect(() => { - if (team?.config) { + if (team?.component) { const { nodes: initialNodes, edges: initialEdges } = loadFromJson( team.config ); diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx index 9c591bdd8bac..a3fbff57be0c 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx @@ -15,38 +15,27 @@ import { ComponentTypes, ComponentConfigTypes, TerminationConfig, + Component, } from "../../../types/datamodel"; import { convertTeamConfigToGraph, getLayoutedElements } from "./utils"; +import { + isTeamComponent, + isAgentComponent, + isToolComponent, + isTerminationComponent, + isModelComponent, + isSelectorTeam, +} from "../../../types/guards"; const MAX_HISTORY = 50; -const isTeamConfig = (config: any): config is TeamConfig => { - return "team_type" in config; -}; - -const isAgentConfig = (config: any): config is AgentConfig => { - return "agent_type" in config; -}; - -const isModelConfig = (config: any): config is ModelConfig => { - return "model_type" in config; -}; - -const isToolConfig = (config: any): config is ToolConfig => { - return "tool_type" in config; -}; - -const isTerminationConfig = (config: any): config is TerminationConfig => { - return "termination_type" in config; -}; - export interface TeamBuilderState { nodes: CustomNode[]; edges: CustomEdge[]; selectedNodeId: string | null; history: Array<{ nodes: CustomNode[]; edges: CustomEdge[] }>; currentHistoryIndex: number; - originalConfig: TeamConfig | null; + originalComponent: Component | null; addNode: ( type: ComponentTypes, position: Position, @@ -66,20 +55,20 @@ export interface TeamBuilderState { redo: () => void; // Sync with JSON - syncToJson: () => TeamConfig | null; - loadFromJson: (config: TeamConfig) => GraphState; + syncToJson: () => Component | null; + loadFromJson: (config: Component) => GraphState; layoutNodes: () => void; resetHistory: () => void; } -const buildTeamConfig = ( +const buildTeamComponent = ( teamNode: CustomNode, nodes: CustomNode[], edges: CustomEdge[] -): TeamConfig | null => { - if (!isTeamConfig(teamNode.data.config)) return null; +): Component | null => { + if (!isTeamComponent(teamNode.data.component)) return null; - const config = { ...teamNode.data.config }; + const component = { ...teamNode.data.component }; // Use edge queries instead of connections const modelEdge = edges.find( @@ -89,10 +78,10 @@ const buildTeamConfig = ( const modelNode = nodes.find((n) => n.id === modelEdge.source); if ( modelNode && - isModelConfig(modelNode.data.config) && - config.team_type === "SelectorGroupChat" + isModelComponent(modelNode.data.component) && + isSelectorTeam(component) ) { - config.model_client = modelNode.data.config; + component.config.model_client = modelNode.data.component; } } @@ -102,8 +91,14 @@ const buildTeamConfig = ( ); if (terminationEdge) { const terminationNode = nodes.find((n) => n.id === terminationEdge.source); - if (terminationNode && isTerminationConfig(terminationNode.data.config)) { - config.termination_condition = terminationNode.data.config; + // if (terminationNode && isTerminationConfig(terminationNode.data.config)) { + // config.termination_condition = terminationNode.data.config; + // } + if ( + terminationNode && + isTerminationComponent(terminationNode.data.component) + ) { + component.config.termination_condition = terminationNode.data.component; } } @@ -153,7 +148,7 @@ export const useTeamBuilderStore = create((set, get) => ({ selectedNodeId: null, history: [], currentHistoryIndex: -1, - originalConfig: null, + originalComponent: null, addNode: ( type: ComponentTypes, @@ -623,7 +618,7 @@ export const useTeamBuilderStore = create((set, get) => ({ if (teamNodes.length === 0) return null; const teamNode = teamNodes[0]; - return buildTeamConfig(teamNode, state.nodes, state.edges); + return buildTeamComponent(teamNode, state.nodes, state.edges); }, layoutNodes: () => { @@ -655,7 +650,7 @@ export const useTeamBuilderStore = create((set, get) => ({ return { nodes: layoutedNodes, edges: layoutedEdges, - originalConfig: config, + originalComponent: config, history: [{ nodes: layoutedNodes, edges: layoutedEdges }], // Reset history currentHistoryIndex: 0, // Reset to 0 selectedNodeId: null, diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts index ca96204c9549..0312cf8049cc 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts @@ -1,10 +1,9 @@ import { Node, Edge } from "@xyflow/react"; -import { ComponentConfigTypes, ComponentTypes } from "../../../types/datamodel"; +import { Component, ComponentConfig } from "../../../types/datamodel"; export interface NodeData extends Record { label: string; - type: ComponentTypes; - config: ComponentConfigTypes; + component: Component; } // Define our node type that extends the XYFlow Node type @@ -41,8 +40,7 @@ export interface FormFieldMapping { } export interface DragItem { - type: ComponentTypes; - config: ComponentConfigTypes; + config: ComponentConfig; } export interface NodeComponentProps { diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx index 39cac6bc321d..93c52f856024 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx @@ -42,8 +42,6 @@ export const TeamSidebar: React.FC = ({ isLoading = false, }) => { const defaultGallery = useGalleryStore((state) => state.getDefaultGallery()); - console.log(defaultGallery); - const createTeam = () => { const newTeam = Object.assign({}, defaultTeam); newTeam.component.label = "new_team_" + new Date().getTime(); From c9bed36d6da0ebe210017f8facfd617ccecabea0 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Tue, 28 Jan 2025 10:08:17 -0800 Subject: [PATCH 09/16] checkpoint complex graph --- .../src/components/types/datamodel.ts | 15 +- .../frontend/src/components/types/guards.ts | 10 +- .../frontend/src/components/views/atoms.tsx | 2 +- .../components/views/team/builder/builder.tsx | 30 +- .../components/views/team/builder/nodes.tsx | 341 +++++++------ .../components/views/team/builder/store.tsx | 461 +++++++++++------- .../components/views/team/builder/types.ts | 7 +- .../components/views/team/builder/utils.ts | 137 +++--- .../src/components/views/team/manager.tsx | 3 +- .../src/components/views/team/sidebar.tsx | 3 +- 10 files changed, 592 insertions(+), 417 deletions(-) diff --git a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts index 76717271885b..195950f08100 100644 --- a/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts +++ b/python/packages/autogen-studio/frontend/src/components/types/datamodel.ts @@ -1,13 +1,14 @@ // Base Component System + +export type ComponentTypes = + | "team" + | "agent" + | "model" + | "tool" + | "termination"; export interface Component { provider: string; - component_type: - | "team" - | "agent" - | "model" - | "tool" - | "termination" - | "chat_completion_context"; + component_type: ComponentTypes; version?: number; component_version?: number; description?: string | null; diff --git a/python/packages/autogen-studio/frontend/src/components/types/guards.ts b/python/packages/autogen-studio/frontend/src/components/types/guards.ts index 5264ff556993..d2f64ff99ace 100644 --- a/python/packages/autogen-studio/frontend/src/components/types/guards.ts +++ b/python/packages/autogen-studio/frontend/src/components/types/guards.ts @@ -123,11 +123,11 @@ export function isTerminationComponent( return component.component_type === "termination"; } -export function isChatCompletionContextComponent( - component: Component -): component is Component { - return component.component_type === "chat_completion_context"; -} +// export function isChatCompletionContextComponent( +// component: Component +// ): component is Component { +// return component.component_type === "chat_completion_context"; +// } // Team provider guards with proper type narrowing export function isRoundRobinTeam( diff --git a/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx b/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx index fe6199083e9c..8fac8ecdc2af 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/atoms.tsx @@ -81,7 +81,7 @@ export const TruncatableText = memo( ${className} `} > - {displayContent} + {/* {displayContent} */} {displayContent} {shouldTruncate && !isExpanded && (

diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx index 4369916c53a6..8c22e132ffda 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx @@ -1,3 +1,4 @@ +//team/builder/builder.tsx import React, { useCallback, useRef, useState } from "react"; import { DndContext, @@ -115,7 +116,7 @@ export const TeamBuilder: React.FC = ({ React.useEffect(() => { if (team?.component) { const { nodes: initialNodes, edges: initialEdges } = loadFromJson( - team.config + team.component ); setNodes(initialNodes); setEdges(initialEdges); @@ -139,21 +140,21 @@ export const TeamBuilder: React.FC = ({ // Handle save const handleSave = useCallback(async () => { try { - const config = syncToJson(); - if (!config) { + const component = syncToJson(); + if (!component) { throw new Error("Unable to generate valid configuration"); } if (onChange) { - console.log("Saving team configuration", config); + console.log("Saving team configuration", component); const teamData: Partial = team ? { ...team, - config, + component, created_at: undefined, updated_at: undefined, } - : { config }; + : { component }; await onChange(teamData); resetHistory(); } @@ -212,7 +213,10 @@ export const TeamBuilder: React.FC = ({ const targetNode = nodes.find((node) => node.id === over.id); if (!targetNode) return; - const isValid = validateDropTarget(draggedType, targetNode.data.type); + const isValid = validateDropTarget( + draggedType, + targetNode.data.component.component_type + ); // Add visual feedback class to target node if (isValid) { targetNode.className = "drop-target-valid"; @@ -235,7 +239,10 @@ export const TeamBuilder: React.FC = ({ if (!targetNode) return; // Validate drop - const isValid = validateDropTarget(draggedItem.type, targetNode.data.type); + const isValid = validateDropTarget( + draggedItem.type, + targetNode.data.component.component_type + ); if (!isValid) return; const position = { @@ -244,12 +251,7 @@ export const TeamBuilder: React.FC = ({ }; // Pass both new node data AND target node id - addNode( - draggedItem.type as ComponentTypes, - position, - draggedItem.config, - nodeId - ); + addNode(position, draggedItem.config, nodeId); }; const onDragStart = (item: DragItem) => { diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx index 7701e11d26a5..2c212b709f63 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx @@ -26,13 +26,28 @@ import { ToolConfig, TerminationConfig, ComponentTypes, + Component, + ComponentConfig, } from "../../../types/datamodel"; import { useDroppable } from "@dnd-kit/core"; import { TruncatableText } from "../../atoms"; import { useTeamBuilderStore } from "./store"; +import { + isAssistantAgent, + isAzureOpenAIModel, + isFunctionTool, + isMaxMessageTermination, + isOpenAIModel, + isOrTermination, + isSelectorTeam, + isTextMentionTermination, +} from "../../../types/guards"; // Icon mapping for different node types -const iconMap: Record = { +const iconMap: Record< + Component["component_type"], + LucideIcon +> = { team: Users, agent: Bot, tool: Wrench, @@ -123,7 +138,7 @@ const BaseNode: React.FC = ({
- {data.type} + {data.component.component_type}
- {data.config.description && ( -
- -
- )} +
+ {descriptionContent} +
{children}
@@ -194,10 +204,9 @@ const ConnectionBadge: React.FC<{ // Team Node export const TeamNode: React.FC> = (props) => { - const config = props.data.config as TeamConfig; - const hasModel = - config.team_type === "SelectorGroupChat" && !!config.model_client; - const participantCount = config.participants?.length || 0; + const component = props.data.component as Component; + const hasModel = isSelectorTeam(component) && !!component.config.model_client; + const participantCount = component.config.participants?.length || 0; return ( > = (props) => { 0} - label={`${participantCount} Agent ${ + label={`${participantCount} Agent${ participantCount > 1 ? "s" : "" }`} /> @@ -216,21 +225,25 @@ export const TeamNode: React.FC> = (props) => { } descriptionContent={
-
Type: {config.team_type}
- {config.team_type === "SelectorGroupChat" && - config.selector_prompt && ( -
- Selector:{" "} - -
- )} +
+ +
+ {isSelectorTeam(component) && component.config.selector_prompt && ( +
+ Selector:{" "} + +
+ )}
} > - {config.team_type === "SelectorGroupChat" && ( + {isSelectorTeam(component) && ( > = (props) => {
{hasModel && ( -
{config.model_client.model}
+
+ {component.config.model_client.config.model} +
)}
@@ -260,22 +275,20 @@ export const TeamNode: React.FC> = (props) => {
} > - {true && ( - - )} +
- {config.participants?.map((participant, index) => ( + {component.config.participants?.map((participant, index) => (
- {participant.name} + {participant.config.name}
))} @@ -287,19 +300,21 @@ export const TeamNode: React.FC> = (props) => { - {config.termination_condition && ( + { - )} + }
- {config.termination_condition && ( + {component.config.termination_condition && (
- {config.termination_condition.termination_type} + + {component.config.termination_condition.component_type} +
)} > = (props) => { }; export const AgentNode: React.FC> = (props) => { - const config = props.data.config as AgentConfig; - const hasModel = !!config.model_client; - const toolCount = config.tools?.length || 0; + const component = props.data.component as Component; + const hasModel = + isAssistantAgent(component) && !!component.config.model_client; + const toolCount = isAssistantAgent(component) + ? component.config.tools?.length || 0 + : 0; return ( > = (props) => { icon={iconMap.agent} headerContent={
- - 0} - label={`${toolCount} Tools`} - /> + {isAssistantAgent(component) && ( + <> + + 0} + label={`${toolCount} Tools`} + /> + + )}
} descriptionContent={
-
Type: {config.agent_type}
- {config.system_message && ( -
- -
- )} +
+ {" "} + {component.config.name} +
+
{component.description}
} > @@ -355,66 +373,69 @@ export const AgentNode: React.FC> = (props) => { className="my-left-handle" /> - - - -
- {config.model_client && ( - <> - {" "} -
{config.model_client.model}
- - )} - -
- Drop model here + {isAssistantAgent(component) && ( + <> + + + +
+ {component.config.model_client && ( +
+ {component.config.model_client.config.model} +
+ )} + +
+ Drop model here +
+
- -
- - - - { - - } -
- {config.tools && toolCount > 0 && ( + + + +
- {config.tools.map((tool, index) => ( -
- - {tool.name} + {component.config.tools && toolCount > 0 && ( +
+ {component.config.tools.map((tool, index) => ( +
+ + {tool.config.name} +
+ ))}
- ))} -
- )} - -
- Drop tools here + )} + +
+ Drop tools here +
+
-
-
-
+ + + )} ); }; // Model Node export const ModelNode: React.FC> = (props) => { - const config = props.data.config as ModelConfig; + const component = props.data.component as Component; + const isOpenAI = isOpenAIModel(component); + const isAzure = isAzureOpenAIModel(component); return ( > = (props) => { icon={iconMap.model} descriptionContent={
-
Type: {config.model_type}
- {config.base_url && ( -
URL: {config.base_url}
+
{component.description}
+ {isOpenAI && component.config.base_url && ( +
URL: {component.config.base_url}
+ )} + {isAzure && ( +
+ Endpoint: {component.config.azure_endpoint} +
)}
} > -
Model: {config.model}
+
Model: {component.config.model}
); @@ -444,29 +470,43 @@ export const ModelNode: React.FC> = (props) => { // Tool Node export const ToolNode: React.FC> = (props) => { - const config = props.data.config as ToolConfig; + const component = props.data.component as Component; + const isFunctionToolType = isFunctionTool(component); return ( Tool Type: {config.tool_type}
} + descriptionContent={ +
+ {" "} + {component.config.name} +
+ } > -
{config.description}
+
{component.config.description}
- -
- -
-
+ {isFunctionToolType && ( + +
+ +
+
+ )} ); }; @@ -475,13 +515,18 @@ export const ToolNode: React.FC> = (props) => { // First, let's add the Termination Node component export const TerminationNode: React.FC> = (props) => { - const config = props.data.config as TerminationConfig; + const component = props.data.component as Component; + const isMaxMessages = isMaxMessageTermination(component); + const isTextMention = isTextMentionTermination(component); + const isOr = isOrTermination(component); return ( Type: {config.termination_type}
} + descriptionContent={ +
{component.description || component.label}
+ } > > = (props) => {
- {config.termination_type === "MaxMessageTermination" && ( -
Max Messages: {config.max_messages}
+ {isMaxMessages && ( +
Max Messages: {component.config.max_messages}
)} - {config.termination_type === "TextMentionTermination" && ( -
Text: {config.text}
+ {isTextMention &&
Text: {component.config.text}
} + {isOr && ( +
OR Conditions: {component.config.conditions.length}
)}
@@ -514,23 +560,46 @@ export const nodeTypes = { }; const EDGE_STYLES = { - "model-connection": { stroke: "rgb(59, 130, 246)" }, - "tool-connection": { stroke: "rgb(34, 197, 94)" }, - "agent-connection": { stroke: "rgb(168, 85, 247)" }, - "termination-connection": { stroke: "rgb(255, 159, 67)" }, + "model-connection": { stroke: "rgb(220,220,220)" }, + "tool-connection": { stroke: "rgb(220,220,220)" }, + "agent-connection": { stroke: "rgb(220,220,220)" }, + "termination-connection": { stroke: "rgb(220,220,220)" }, } as const; type EdgeType = keyof typeof EDGE_STYLES; +type CustomEdgeProps = EdgeProps & { + type: EdgeType; +}; -export const CustomEdge = ({ data, ...props }: EdgeProps) => { +export const CustomEdge = ({ + type, + data, + deletable, + ...props +}: CustomEdgeProps) => { const [edgePath] = getBezierPath(props); - const edgeType = (data?.type as EdgeType) || "model-connection"; + const edgeType = type || "model-connection"; + + // Extract only the SVG path properties we want to pass + const { style: baseStyle, ...pathProps } = props; + const { + // Filter out the problematic props + sourceX, + sourceY, + sourcePosition, + targetPosition, + sourceHandleId, + targetHandleId, + pathOptions, + selectable, + ...validPathProps + } = pathProps; return ( ); }; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx index a3fbff57be0c..240b569e8f46 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx @@ -1,4 +1,6 @@ +// builder/store.tsx import { create } from "zustand"; +import { isEqual } from "lodash"; import { CustomNode, CustomEdge, @@ -10,14 +12,15 @@ import { nanoid } from "nanoid"; import { TeamConfig, AgentConfig, - ModelConfig, ToolConfig, - ComponentTypes, - ComponentConfigTypes, - TerminationConfig, Component, + ComponentConfig, } from "../../../types/datamodel"; -import { convertTeamConfigToGraph, getLayoutedElements } from "./utils"; +import { + convertTeamConfigToGraph, + getLayoutedElements, + getUniqueName, +} from "./utils"; import { isTeamComponent, isAgentComponent, @@ -25,6 +28,8 @@ import { isTerminationComponent, isModelComponent, isSelectorTeam, + isAssistantAgent, + isFunctionTool, } from "../../../types/guards"; const MAX_HISTORY = 50; @@ -37,9 +42,8 @@ export interface TeamBuilderState { currentHistoryIndex: number; originalComponent: Component | null; addNode: ( - type: ComponentTypes, position: Position, - config: ComponentConfigTypes, + component: Component, targetNodeId?: string ) => void; @@ -106,21 +110,26 @@ const buildTeamComponent = ( const participantEdges = edges.filter( (e) => e.source === teamNode.id && e.type === "agent-connection" ); - config.participants = participantEdges + component.config.participants = participantEdges .map((edge) => { const agentNode = nodes.find((n) => n.id === edge.target); - if (!agentNode || !isAgentConfig(agentNode.data.config)) return null; - - const agentConfig = { ...agentNode.data.config }; + if (!agentNode || !isAgentComponent(agentNode.data.component)) + return null; + const agentComponent = { ...agentNode.data.component }; // Get agent's model using edges const agentModelEdge = edges.find( (e) => e.target === edge.target && e.type === "model-connection" ); if (agentModelEdge) { const modelNode = nodes.find((n) => n.id === agentModelEdge.source); - if (modelNode && isModelConfig(modelNode.data.config)) { - agentConfig.model_client = modelNode.data.config; + if ( + modelNode && + isModelComponent(modelNode.data.component) && + isAssistantAgent(agentComponent) + ) { + // Check specific agent type + agentComponent.config.model_client = modelNode.data.component; } } @@ -128,18 +137,24 @@ const buildTeamComponent = ( const toolEdges = edges.filter( (e) => e.target === edge.target && e.type === "tool-connection" ); - agentConfig.tools = toolEdges - .map((toolEdge) => { - const toolNode = nodes.find((n) => n.id === toolEdge.source); - return toolNode?.data.config as ToolConfig; - }) - .filter((tool): tool is ToolConfig => tool !== null); - - return agentConfig; + + if (isAssistantAgent(agentComponent)) { + agentComponent.config.tools = toolEdges + .map((toolEdge) => { + const toolNode = nodes.find((n) => n.id === toolEdge.source); + if (toolNode && isToolComponent(toolNode.data.component)) { + return toolNode.data.component; + } + return null; + }) + .filter((tool): tool is Component => tool !== null); + } + + return agentComponent; }) - .filter((agent): agent is AgentConfig => agent !== null); + .filter((agent): agent is Component => agent !== null); - return config; + return component; }; export const useTeamBuilderStore = create((set, get) => ({ @@ -151,35 +166,26 @@ export const useTeamBuilderStore = create((set, get) => ({ originalComponent: null, addNode: ( - type: ComponentTypes, position: Position, - config: ComponentConfigTypes, + component: Component, targetNodeId?: string ) => { set((state) => { - // Determine label based on config type - let label = ""; - if (isTeamConfig(config)) { - label = config.name || "Team"; - } else if (isAgentConfig(config)) { - label = config.name || "Agent"; - } else if (isModelConfig(config)) { - label = config.model || "Model"; - } else if (isToolConfig(config)) { - label = config.name || "Tool"; - } else if (isTerminationConfig(config)) { - label = config.termination_type || "Termination"; - } + // Deep clone the incoming component to avoid reference issues + const clonedComponent = JSON.parse(JSON.stringify(component)); + + // Determine label based on component type + let label = clonedComponent.label || clonedComponent.component_type; - // Create new node without connections object + // Create new node const newNode: CustomNode = { id: nanoid(), - type, position, + type: clonedComponent.component_type, data: { label: label || "Node", - type, - config, + component: clonedComponent, + type: clonedComponent.component_type as NodeData["type"], }, }; @@ -191,8 +197,9 @@ export const useTeamBuilderStore = create((set, get) => ({ if (targetNode) { if ( - type === "model" && - ["team", "agent"].includes(targetNode.data.type) + clonedComponent.component_type === "model" && + (isTeamComponent(targetNode.data.component) || + isAgentComponent(targetNode.data.component)) ) { // Find existing model connection and node const existingModelEdge = newEdges.find( @@ -229,23 +236,18 @@ export const useTeamBuilderStore = create((set, get) => ({ }); // Update config - if ( - isTeamConfig(targetNode.data.config) && - isModelConfig(config) && - targetNode.data.config.team_type === "SelectorGroupChat" - ) { - targetNode.data.config.model_client = config; - } else if ( - isAgentConfig(targetNode.data.config) && - isModelConfig(config) - ) { - targetNode.data.config.model_client = config; + if (isModelComponent(clonedComponent)) { + if (isSelectorTeam(targetNode.data.component)) { + targetNode.data.component.config.model_client = clonedComponent; + } else if (isAssistantAgent(targetNode.data.component)) { + targetNode.data.component.config.model_client = clonedComponent; + } } } else if ( - type === "termination" && + clonedComponent.component_type === "termination" && targetNode.data.type === "team" ) { - // Find existing termination connection and node + // Handle termination connection const existingTerminationEdge = newEdges.find( (edge) => edge.target === targetNodeId && @@ -253,13 +255,10 @@ export const useTeamBuilderStore = create((set, get) => ({ ); if (existingTerminationEdge) { - // Remove the existing termination node const existingTerminationNodeId = existingTerminationEdge.source; newNodes = newNodes.filter( (node) => node.id !== existingTerminationNodeId ); - - // Remove all edges connected to the old termination node newEdges = newEdges.filter( (edge) => edge.source !== existingTerminationNodeId && @@ -267,10 +266,7 @@ export const useTeamBuilderStore = create((set, get) => ({ ); } - // Add the new termination node newNodes.push(newNode); - - // Add new termination connection newEdges.push({ id: nanoid(), source: newNode.id, @@ -280,15 +276,31 @@ export const useTeamBuilderStore = create((set, get) => ({ type: "termination-connection", }); - // Update config if ( - isTeamConfig(targetNode.data.config) && - isTerminationConfig(config) + isTeamComponent(targetNode.data.component) && + isTerminationComponent(clonedComponent) ) { - targetNode.data.config.termination_condition = config; + targetNode.data.component.config.termination_condition = + clonedComponent; } - } else if (type === "tool" && targetNode.data.type === "agent") { - // Add tool connection + } else if ( + clonedComponent.component_type === "tool" && + targetNode.data.type === "agent" + ) { + // Handle tool connection with unique name + if ( + isAssistantAgent(targetNode.data.component) && + isAssistantAgent(newNode.data.component) + ) { + const existingTools = + targetNode.data.component.config.tools || []; + const existingNames = existingTools.map((t) => t.config.name); + newNode.data.component.config.name = getUniqueName( + clonedComponent.config.name, + existingNames + ); + } + newNodes.push(newNode); newEdges.push({ id: nanoid(), @@ -299,15 +311,34 @@ export const useTeamBuilderStore = create((set, get) => ({ type: "tool-connection", }); - // Update config - if (isAgentConfig(targetNode.data.config) && isToolConfig(config)) { - if (!targetNode.data.config.tools) { - targetNode.data.config.tools = []; + if ( + isAssistantAgent(targetNode.data.component) && + isToolComponent(newNode.data.component) + ) { + if (!targetNode.data.component.config.tools) { + targetNode.data.component.config.tools = []; } - targetNode.data.config.tools.push(config); + targetNode.data.component.config.tools.push( + newNode.data.component + ); } - } else if (type === "agent" && targetNode.data.type === "team") { - // Add agent connection + } else if ( + clonedComponent.component_type === "agent" && + isTeamComponent(targetNode.data.component) && + isAssistantAgent(newNode.data.component) + ) { + // Handle agent connection with unique name + const existingParticipants = + targetNode.data.component.config.participants || []; + const existingNames = existingParticipants.map( + (p) => p.config.name + ); + + newNode.data.component.config.name = getUniqueName( + clonedComponent.config.name, + existingNames + ); + newNodes.push(newNode); newEdges.push({ id: nanoid(), @@ -318,12 +349,16 @@ export const useTeamBuilderStore = create((set, get) => ({ type: "agent-connection", }); - // Update config - if (isTeamConfig(targetNode.data.config) && isAgentConfig(config)) { - if (!targetNode.data.config.participants) { - targetNode.data.config.participants = []; + if ( + isTeamComponent(targetNode.data.component) && + isAgentComponent(newNode.data.component) + ) { + if (!targetNode.data.component.config.participants) { + targetNode.data.component.config.participants = []; } - targetNode.data.config.participants.push(config); + targetNode.data.component.config.participants.push( + newNode.data.component + ); } } else { // For all other cases, just add the new node @@ -353,72 +388,85 @@ export const useTeamBuilderStore = create((set, get) => ({ updateNode: (nodeId: string, updates: Partial) => { set((state) => { const newNodes = state.nodes.map((node) => { - if (node.id !== nodeId) return node; + if (node.id !== nodeId) { + // If this isn't the directly updated node, check if it needs related updates + const isTeamWithUpdatedAgent = + isTeamComponent(node.data.component) && + state.edges.some( + (e) => + e.type === "agent-connection" && + e.target === nodeId && + e.source === node.id + ); - // Update the node with new data - const updatedNode = { + if (isTeamWithUpdatedAgent && isTeamComponent(node.data.component)) { + return { + ...node, + data: { + ...node.data, + component: { + ...node.data.component, + config: { + ...node.data.component.config, + participants: node.data.component.config.participants.map( + (participant) => + participant === + state.nodes.find((n) => n.id === nodeId)?.data.component + ? updates.component + : participant + ), + }, + }, + }, + }; + } + + const isAgentWithUpdatedTool = + isAssistantAgent(node.data.component) && + state.edges.some( + (e) => + e.type === "tool-connection" && + e.source === nodeId && + e.target === node.id + ); + + if (isAgentWithUpdatedTool && isAssistantAgent(node.data.component)) { + return { + ...node, + data: { + ...node.data, + component: { + ...node.data.component, + config: { + ...node.data.component.config, + tools: (node.data.component.config.tools || []).map( + (tool) => + tool === + state.nodes.find((n) => n.id === nodeId)?.data.component + ? updates.component + : tool + ), + }, + }, + }, + }; + } + + return node; + } + + // This is the directly updated node + const updatedComponent = updates.component || node.data.component; + return { ...node, data: { ...node.data, ...updates, - // Update label based on config type - label: (() => { - const config = { ...node.data.config, ...updates.config }; - if (isTeamConfig(config)) return config.name || "Team"; - if (isAgentConfig(config)) return config.name || "Agent"; - if (isModelConfig(config)) return config.model || "Model"; - if (isToolConfig(config)) return config.name || "Tool"; - if (isTerminationConfig(config)) - return config.termination_type || "Termination"; - return node.data.label; - })(), + component: updatedComponent, }, }; - - return updatedNode; }); - // Update related nodes' configs - const updatedNode = newNodes.find((n) => n.id === nodeId); - if (!updatedNode) return state; - - // If an agent was updated, update its parent team's participants - if (updatedNode.data.type === "agent") { - const teamEdge = state.edges.find( - (e) => e.type === "agent-connection" && e.target === nodeId - ); - if (teamEdge) { - newNodes.forEach((node) => { - if (node.id === teamEdge.source && isTeamConfig(node.data.config)) { - const agentConfig = updatedNode.data.config as AgentConfig; - node.data.config.participants = node.data.config.participants.map( - (p) => (p.name === agentConfig.name ? agentConfig : p) - ); - } - }); - } - } - - // If a tool was updated, update its parent agent's tools - if (updatedNode.data.type === "tool") { - const agentEdge = state.edges.find( - (e) => e.type === "tool-connection" && e.source === nodeId - ); - if (agentEdge) { - newNodes.forEach((node) => { - if ( - node.id === agentEdge.target && - isAgentConfig(node.data.config) - ) { - const toolConfig = updatedNode.data.config as ToolConfig; - node.data.config.tools = node.data.config.tools?.map((t) => - t.name === toolConfig.name ? toolConfig : t - ); - } - }); - } - } - return { nodes: newNodes, history: [ @@ -433,6 +481,7 @@ export const useTeamBuilderStore = create((set, get) => ({ removeNode: (nodeId: string) => { set((state) => { const nodesToRemove = new Set(); + const updatedNodes = new Map(); const collectNodesToRemove = (id: string) => { const node = state.nodes.find((n) => n.id === id); @@ -445,9 +494,9 @@ export const useTeamBuilderStore = create((set, get) => ({ (edge) => edge.source === id || edge.target === id ); - // Handle cascading deletes based on node type - if (node.data.type === "team") { - // Find and remove all agents + // Handle cascading deletes based on component type + if (isTeamComponent(node.data.component)) { + // Find and remove all connected agents connectedEdges .filter((e) => e.type === "agent-connection") .forEach((e) => collectNodesToRemove(e.target)); @@ -461,7 +510,7 @@ export const useTeamBuilderStore = create((set, get) => ({ connectedEdges .filter((e) => e.type === "termination-connection") .forEach((e) => collectNodesToRemove(e.source)); - } else if (node.data.type === "agent") { + } else if (isAgentComponent(node.data.component)) { // Remove agent's model if exists connectedEdges .filter((e) => e.type === "model-connection") @@ -472,22 +521,43 @@ export const useTeamBuilderStore = create((set, get) => ({ .filter((e) => e.type === "tool-connection") .forEach((e) => collectNodesToRemove(e.source)); - // Also need to remove agent from team's config + // Update team's participants if agent is connected to a team const teamEdge = connectedEdges.find( (e) => e.type === "agent-connection" ); if (teamEdge) { const teamNode = state.nodes.find((n) => n.id === teamEdge.source); - if (teamNode && isTeamConfig(teamNode.data.config)) { - const agentConfig = node.data.config as AgentConfig; - teamNode.data.config.participants = - teamNode.data.config.participants.filter( - (p) => p.name !== agentConfig.name - ); + if (teamNode && isTeamComponent(teamNode.data.component)) { + // Create updated team node with filtered participants + const updatedTeamNode = { + ...teamNode, + data: { + ...teamNode.data, + component: { + ...teamNode.data.component, + config: { + ...teamNode.data.component.config, + participants: + teamNode.data.component.config.participants.filter( + (p) => { + // Find a node that matches this participant but isn't being deleted + const participantNode = state.nodes.find( + (n) => + !nodesToRemove.has(n.id) && + isEqual(n.data.component, p) + ); + return participantNode !== undefined; + } + ), + }, + }, + }, + }; + updatedNodes.set(teamNode.id, updatedTeamNode); } } - } else if (node.data.type === "tool") { - // Update agent's tools array when removing a tool + } else if (isToolComponent(node.data.component)) { + // Update connected agent's tools array const agentEdge = connectedEdges.find( (e) => e.type === "tool-connection" ); @@ -495,22 +565,56 @@ export const useTeamBuilderStore = create((set, get) => ({ const agentNode = state.nodes.find( (n) => n.id === agentEdge.target ); - if (agentNode && isAgentConfig(agentNode.data.config)) { - const toolConfig = node.data.config as ToolConfig; - agentNode.data.config.tools = agentNode.data.config.tools?.filter( - (t) => t.name !== toolConfig.name - ); + if (agentNode && isAssistantAgent(agentNode.data.component)) { + // Create updated agent node with filtered tools + const updatedAgentNode = { + ...agentNode, + data: { + ...agentNode.data, + component: { + ...agentNode.data.component, + config: { + ...agentNode.data.component.config, + tools: ( + agentNode.data.component.config.tools || [] + ).filter((t) => { + // Find a node that matches this tool but isn't being deleted + const toolNode = state.nodes.find( + (n) => + !nodesToRemove.has(n.id) && + isEqual(n.data.component, t) + ); + return toolNode !== undefined; + }), + }, + }, + }, + }; + updatedNodes.set(agentNode.id, updatedAgentNode); } } - } else if (node.data.type === "termination") { - // Update team's termination condition when removing it + } else if (isTerminationComponent(node.data.component)) { + // Update connected team's termination condition const teamEdge = connectedEdges.find( (e) => e.type === "termination-connection" ); if (teamEdge) { const teamNode = state.nodes.find((n) => n.id === teamEdge.target); - if (teamNode && isTeamConfig(teamNode.data.config)) { - teamNode.data.config.termination_condition = undefined; + if (teamNode && isTeamComponent(teamNode.data.component)) { + const updatedTeamNode = { + ...teamNode, + data: { + ...teamNode.data, + component: { + ...teamNode.data.component, + config: { + ...teamNode.data.component.config, + termination_condition: undefined, + }, + }, + }, + }; + updatedNodes.set(teamNode.id, updatedTeamNode); } } } @@ -519,10 +623,10 @@ export const useTeamBuilderStore = create((set, get) => ({ // Start the cascade deletion from the initial node collectNodesToRemove(nodeId); - // Remove all collected nodes - const newNodes = state.nodes.filter( - (node) => !nodesToRemove.has(node.id) - ); + // Create new nodes array with both removals and updates + const newNodes = state.nodes + .filter((node) => !nodesToRemove.has(node.id)) + .map((node) => updatedNodes.get(node.id) || node); // Remove all affected edges const newEdges = state.edges.filter( @@ -614,7 +718,9 @@ export const useTeamBuilderStore = create((set, get) => ({ syncToJson: () => { const state = get(); - const teamNodes = state.nodes.filter((node) => node.data.type === "team"); + const teamNodes = state.nodes.filter( + (node) => node.data.component.component_type === "team" + ); if (teamNodes.length === 0) return null; const teamNode = teamNodes[0]; @@ -639,24 +745,27 @@ export const useTeamBuilderStore = create((set, get) => ({ }); }, - loadFromJson: (config: TeamConfig) => { - const { nodes, edges } = convertTeamConfigToGraph(config); + loadFromJson: (component: Component) => { + // Get graph representation of team config + const { nodes, edges } = convertTeamConfigToGraph(component); + + // Apply layout to elements const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements( nodes, edges ); - set((state) => { - return { - nodes: layoutedNodes, - edges: layoutedEdges, - originalComponent: config, - history: [{ nodes: layoutedNodes, edges: layoutedEdges }], // Reset history - currentHistoryIndex: 0, // Reset to 0 - selectedNodeId: null, - }; + // Update store with new state and reset history + set({ + nodes: layoutedNodes, + edges: layoutedEdges, + originalComponent: component, // Store original component for reference + history: [{ nodes: layoutedNodes, edges: layoutedEdges }], // Reset history with initial state + currentHistoryIndex: 0, + selectedNodeId: null, }); + // Return final graph state return { nodes: layoutedNodes, edges: layoutedEdges }; }, resetHistory: () => { diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts index 0312cf8049cc..15468d145b9f 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/types.ts @@ -1,5 +1,10 @@ +// builder/types.ts import { Node, Edge } from "@xyflow/react"; -import { Component, ComponentConfig } from "../../../types/datamodel"; +import { + Component, + ComponentConfig, + ComponentTypes, +} from "../../../types/datamodel"; export interface NodeData extends Record { label: string; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts b/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts index 6e5c70931097..2d840ca41b8e 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts @@ -1,14 +1,12 @@ import dagre from "@dagrejs/dagre"; -import { CustomNode, CustomEdge } from "./types"; +import { CustomNode, CustomEdge, EdgeTypes } from "./types"; import { nanoid } from "nanoid"; import { TeamConfig, - ModelConfig, - AgentConfig, - ToolConfig, - ComponentTypes, - TerminationConfig, + Component, + ComponentConfig, } from "../../../types/datamodel"; +import { isAssistantAgent, isSelectorTeam } from "../../../types/guards"; interface ConversionResult { nodes: CustomNode[]; @@ -39,29 +37,16 @@ const calculateParticipantPosition = ( // Helper to create nodes with consistent structure const createNode = ( - type: ComponentTypes, position: Position, - config: - | TeamConfig - | ModelConfig - | AgentConfig - | ToolConfig - | TerminationConfig, + component: Component, label?: string ): CustomNode => ({ id: nanoid(), - type, position, + type: component.component_type, data: { - label: label || type, - type, - config, - connections: { - modelClient: null, - tools: [], - participants: [], - termination: null, - }, + label: label || component.label || component.component_type, + component, }, }); @@ -69,12 +54,7 @@ const createNode = ( const createEdge = ( source: string, target: string, - type: - | "model-connection" - | "tool-connection" - | "agent-connection" - | "team-connection" - | "termination-connection" + type: EdgeTypes ): CustomEdge => ({ id: `e${source}-${target}`, source, @@ -83,29 +63,21 @@ const createEdge = ( }); export const convertTeamConfigToGraph = ( - config: TeamConfig + teamComponent: Component ): ConversionResult => { const nodes: CustomNode[] = []; const edges: CustomEdge[] = []; // Create team node - const teamNode = createNode( - "team", - { x: 400, y: 50 }, - { - ...config, - // participants: [], // Clear participants as we'll rebuild from edges - } - ); + const teamNode = createNode({ x: 400, y: 50 }, teamComponent); nodes.push(teamNode); // Add model client if present - if (config.team_type === "SelectorGroupChat" && config.model_client) { + if (isSelectorTeam(teamComponent) && teamComponent.config.model_client) { const modelNode = createNode( - "model", { x: 200, y: 50 }, - config.model_client, - config.model_client.model + teamComponent.config.model_client, + teamComponent.config.model_client.config.model ); nodes.push(modelNode); edges.push({ @@ -119,15 +91,12 @@ export const convertTeamConfigToGraph = ( } // Add participants (agents) - config.participants.forEach((participant, index) => { + teamComponent.config.participants.forEach((participant, index) => { const position = calculateParticipantPosition( index, - config.participants.length + teamComponent.config.participants.length ); - const agentNode = createNode("agent", position, { - ...participant, - // tools: [], // Clear tools as we'll rebuild from edges - }); + const agentNode = createNode(position, participant); nodes.push(agentNode); // Connect to team @@ -141,15 +110,14 @@ export const convertTeamConfigToGraph = ( }); // Add agent's model client if present - if (participant.model_client) { + if (isAssistantAgent(participant) && participant.config.model_client) { const agentModelNode = createNode( - "model", { x: position.x - 150, y: position.y, }, - participant.model_client, - participant.model_client.model + participant.config.model_client, + participant.config.model_client.config.model ); nodes.push(agentModelNode); edges.push({ @@ -163,33 +131,33 @@ export const convertTeamConfigToGraph = ( } // Add agent's tools - participant.tools?.forEach((tool, toolIndex) => { - const toolNode = createNode( - "tool", - { - x: position.x + 150, - y: position.y + toolIndex * 100, - }, - tool - ); - nodes.push(toolNode); - edges.push({ - id: nanoid(), - source: toolNode.id, - target: agentNode.id, - sourceHandle: `${toolNode.id}-tool-output-handle`, - targetHandle: `${agentNode.id}-tool-input-handle`, - type: "tool-connection", + if (isAssistantAgent(participant) && participant.config.tools) { + participant.config.tools.forEach((tool, toolIndex) => { + const toolNode = createNode( + { + x: position.x + 150, + y: position.y + toolIndex * 100, + }, + tool + ); + nodes.push(toolNode); + edges.push({ + id: nanoid(), + source: toolNode.id, + target: agentNode.id, + sourceHandle: `${toolNode.id}-tool-output-handle`, + targetHandle: `${agentNode.id}-tool-input-handle`, + type: "tool-connection", + }); }); - }); + } }); // Add termination condition if present - if (config?.termination_condition) { + if (teamComponent.config.termination_condition) { const terminationNode = createNode( - "termination", { x: 600, y: 50 }, - config.termination_condition + teamComponent.config.termination_condition ); nodes.push(terminationNode); edges.push({ @@ -205,6 +173,7 @@ export const convertTeamConfigToGraph = ( return { nodes, edges }; }; +// Rest of the file remains the same since it deals with layout calculations const NODE_WIDTH = 272; const NODE_HEIGHT = 200; @@ -285,13 +254,31 @@ export const getNodeConnections = (nodeId: string, edges: CustomEdge[]) => { modelClient: edges.find((e) => e.target === nodeId && e.type === "model-connection") ?.source || null, - tools: edges .filter((e) => e.target === nodeId && e.type === "tool-connection") .map((e) => e.source), - participants: edges .filter((e) => e.source === nodeId && e.type === "agent-connection") .map((e) => e.target), }; }; + +export const getUniqueName = ( + baseName: string, + existingNames: string[] +): string => { + // Convert baseName to valid identifier format + let validBaseName = baseName + // Replace spaces and special characters with underscore + .replace(/[^a-zA-Z0-9_$]/g, "_") + // Ensure it starts with a letter, underscore, or dollar sign + .replace(/^([^a-zA-Z_$])/, "_$1"); + + if (!existingNames.includes(validBaseName)) return validBaseName; + + let counter = 1; + while (existingNames.includes(`${validBaseName}_${counter}`)) { + counter++; + } + return `${validBaseName}_${counter}`; +}; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx index 6191c553a00f..b8a72a2be058 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/manager.tsx @@ -37,6 +37,7 @@ export const TeamManager: React.FC = () => { setIsLoading(true); const data = await teamAPI.listTeams(user.email); setTeams(data); + console.log("team data", data); if (!currentTeam && data.length > 0) { setCurrentTeam(data[0]); } @@ -201,7 +202,7 @@ export const TeamManager: React.FC = () => { <> - {currentTeam.component.label} + {currentTeam.component?.label} {currentTeam.id ? ( "" ) : ( diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx index 93c52f856024..d64917a60e21 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx @@ -47,6 +47,7 @@ export const TeamSidebar: React.FC = ({ newTeam.component.label = "new_team_" + new Date().getTime(); onCreateTeam(newTeam); }; + // Render collapsed state if (!isOpen) { return ( @@ -160,7 +161,7 @@ export const TeamSidebar: React.FC = ({ {/* Team Name and Actions Row */}
- {team.component.label} + {team.component?.label}
{/* From 5e208a4916c6b12f9ed29577dbb2ef76bbb458ad Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Tue, 28 Jan 2025 19:32:27 -0800 Subject: [PATCH 10/16] update impl --- .../autogenstudio/database/db_manager.py | 3 +- .../autogenstudio/datamodel/db.py | 4 +- .../autogenstudio/web/managers/connection.py | 5 + .../session/chat/agentflow/agentflow.tsx | 17 +- .../components/views/session/chat/chat.tsx | 6 +- .../components/views/session/chat/runview.tsx | 5 +- .../src/components/views/session/editor.tsx | 2 +- .../components/views/team/builder/nodes.tsx | 5 +- .../components/views/team/builder/store.tsx | 584 ++++++------------ .../components/views/team/builder/utils.ts | 267 +++----- 10 files changed, 281 insertions(+), 617 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/database/db_manager.py b/python/packages/autogen-studio/autogenstudio/database/db_manager.py index bf5358436679..0c4f2063d108 100644 --- a/python/packages/autogen-studio/autogenstudio/database/db_manager.py +++ b/python/packages/autogen-studio/autogenstudio/database/db_manager.py @@ -348,8 +348,7 @@ async def _check_team_exists( teams = self.get(Team, {"user_id": user_id}).data - for team in teams: - print(team.config, "******" ,config) + for team in teams: if team.config == config: return team diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/db.py b/python/packages/autogen-studio/autogenstudio/datamodel/db.py index f3ca24b82744..b8fbf0748143 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/db.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/db.py @@ -26,7 +26,7 @@ class Team(SQLModel, table=True): user_id: Optional[str] = None version: Optional[str] = "0.0.1" component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON)) - + class Message(SQLModel, table=True): @@ -42,7 +42,7 @@ class Message(SQLModel, table=True): ) # pylint: disable=not-callable user_id: Optional[str] = None version: Optional[str] = "0.0.1" - component: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON)) + config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON)) session_id: Optional[int] = Field( default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE")) ) diff --git a/python/packages/autogen-studio/autogenstudio/web/managers/connection.py b/python/packages/autogen-studio/autogenstudio/web/managers/connection.py index 271b53d87675..51c6e903882f 100644 --- a/python/packages/autogen-studio/autogenstudio/web/managers/connection.py +++ b/python/packages/autogen-studio/autogenstudio/web/managers/connection.py @@ -1,6 +1,7 @@ import asyncio import logging from datetime import datetime, timezone +import traceback from typing import Any, Callable, Dict, Optional, Union from uuid import UUID @@ -89,6 +90,8 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None input_func = self.create_input_func(run_id) + + async for message in team_manager.run_stream( task=task, team_config=team_config, input_func=input_func, cancellation_token=cancellation_token ): @@ -137,12 +140,14 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None except Exception as e: logger.error(f"Stream error for run {run_id}: {e}") + traceback.print_exc() await self._handle_stream_error(run_id, e) finally: self._cancellation_tokens.pop(run_id, None) async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None: """Save a message to the database""" + run = await self._get_run(run_id) if run: db_message = Message( diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx index 9b18038c6102..e569b517a23d 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx @@ -25,6 +25,7 @@ import { AgentConfig, TeamConfig, Run, + Component, } from "../../../../types/datamodel"; import { CustomEdge, CustomEdgeData } from "./edge"; import { useConfigStore } from "../../../../../hooks/store"; @@ -32,7 +33,7 @@ import { AgentFlowToolbar } from "./toolbar"; import { EdgeMessageModal } from "./edgemessagemodal"; interface AgentFlowProps { - teamConfig: TeamConfig; + teamConfig: Component; run: Run; } @@ -151,7 +152,7 @@ const getLayoutedElements = ( const createNode = ( id: string, type: "user" | "agent" | "end", - agentConfig?: AgentConfig, + agentConfig?: Component, isActive: boolean = false, run?: Run ): Node => { @@ -218,7 +219,7 @@ const createNode = ( data: { type: "agent", label: id, - agentType: agentConfig?.agent_type || "", + agentType: agentConfig?.label || "", description: agentConfig?.description || "", isActive, status: "", @@ -281,8 +282,8 @@ const AgentFlow: React.FC = ({ teamConfig, run }) => { // Add first message node if it exists if (messages.length > 0) { - const firstAgentConfig = teamConfig.participants.find( - (p) => p.name === messages[0].source + const firstAgentConfig = teamConfig.config.participants.find( + (p) => p.config.name === messages[0].source ); nodeMap.set( messages[0].source, @@ -322,8 +323,8 @@ const AgentFlow: React.FC = ({ teamConfig, run }) => { } if (!nodeMap.has(nextMsg.source)) { - const agentConfig = teamConfig.participants.find( - (p) => p.name === nextMsg.source + const agentConfig = teamConfig.config.participants.find( + (p) => p.config.name === nextMsg.source ); nodeMap.set( nextMsg.source, @@ -479,7 +480,7 @@ const AgentFlow: React.FC = ({ teamConfig, run }) => { return { nodes: Array.from(nodeMap.values()), edges: newEdges }; }, - [teamConfig.participants, run, settings] + [teamConfig.config.participants, run, settings] ); const handleToggleFullscreen = useCallback(() => { diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx index 65d9452bbf49..8f5ea8bcd39a 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx @@ -11,6 +11,7 @@ import { RunStatus, TeamResult, Session, + Component, } from "../../../types/datamodel"; import { appContext } from "../../../../hooks/provider"; import ChatInput from "./chatinput"; @@ -46,7 +47,8 @@ export default function ChatView({ session }: ChatViewProps) { const [activeSocket, setActiveSocket] = React.useState( null ); - const [teamConfig, setTeamConfig] = React.useState(null); + const [teamConfig, setTeamConfig] = + React.useState | null>(null); const inputTimeoutRef = React.useRef(null); const activeSocketRef = React.useRef(null); @@ -94,7 +96,7 @@ export default function ChatView({ session }: ChatViewProps) { teamAPI .getTeam(session.team_id, user.email) .then((team) => { - setTeamConfig(team.config); + setTeamConfig(team.component); }) .catch((error) => { console.error("Error loading team config:", error); diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx index 57557ec1527a..f6ba05e02553 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx @@ -11,7 +11,7 @@ import { ChevronUp, Bot, } from "lucide-react"; -import { Run, Message, TeamConfig } from "../../../types/datamodel"; +import { Run, Message, TeamConfig, Component } from "../../../types/datamodel"; import AgentFlow from "./agentflow/agentflow"; import { RenderMessage } from "./rendermessage"; import InputRequestView from "./inputrequest"; @@ -24,7 +24,7 @@ import { interface RunViewProps { run: Run; - teamConfig?: TeamConfig; + teamConfig?: Component; onInputResponse?: (response: string) => void; onCancel?: () => void; isFirstRun?: boolean; @@ -54,6 +54,7 @@ const RunView: React.FC = ({ }, [run.messages]); // Only depend on messages changing const calculateThreadTokens = (messages: Message[]) => { + console.log("messages", messages); return messages.reduce((total, msg) => { if (!msg.config.models_usage) return total; return ( diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx index c8358695521c..c89ee39a70f6 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx @@ -132,7 +132,7 @@ export const SessionEditor: React.FC = ({ } options={teams.map((team) => ({ value: team.id, - label: `${team.config.name} (${team.config.team_type})`, + label: `${team.component.label} (${team.component.component_type})`, }))} notFoundContent={loading ? : null} /> diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx index 2c212b709f63..f65d6e24bf16 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/nodes.tsx @@ -313,7 +313,8 @@ export const TeamNode: React.FC> = (props) => {
- {component.config.termination_condition.component_type} + {component.config.termination_condition.label || + component.config.termination_condition.component_type}
)} @@ -370,7 +371,7 @@ export const AgentNode: React.FC> = (props) => { type="target" position={Position.Left} id={`${props.id}-agent-input-handle`} - className="my-left-handle" + className="my-left-handle z-100" /> {isAssistantAgent(component) && ( diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx index 240b569e8f46..ef145ef01dfb 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/store.tsx @@ -1,4 +1,3 @@ -// builder/store.tsx import { create } from "zustand"; import { isEqual } from "lodash"; import { @@ -29,7 +28,6 @@ import { isModelComponent, isSelectorTeam, isAssistantAgent, - isFunctionTool, } from "../../../types/guards"; const MAX_HISTORY = 50; @@ -41,10 +39,12 @@ export interface TeamBuilderState { history: Array<{ nodes: CustomNode[]; edges: CustomEdge[] }>; currentHistoryIndex: number; originalComponent: Component | null; + + // Simplified actions addNode: ( position: Position, component: Component, - targetNodeId?: string + targetNodeId: string ) => void; updateNode: (nodeId: string, updates: Partial) => void; @@ -74,38 +74,6 @@ const buildTeamComponent = ( const component = { ...teamNode.data.component }; - // Use edge queries instead of connections - const modelEdge = edges.find( - (e) => e.target === teamNode.id && e.type === "model-connection" - ); - if (modelEdge) { - const modelNode = nodes.find((n) => n.id === modelEdge.source); - if ( - modelNode && - isModelComponent(modelNode.data.component) && - isSelectorTeam(component) - ) { - component.config.model_client = modelNode.data.component; - } - } - - // Add termination connection handling - const terminationEdge = edges.find( - (e) => e.target === teamNode.id && e.type === "termination-connection" - ); - if (terminationEdge) { - const terminationNode = nodes.find((n) => n.id === terminationEdge.source); - // if (terminationNode && isTerminationConfig(terminationNode.data.config)) { - // config.termination_condition = terminationNode.data.config; - // } - if ( - terminationNode && - isTerminationComponent(terminationNode.data.component) - ) { - component.config.termination_condition = terminationNode.data.component; - } - } - // Get participants using edges const participantEdges = edges.filter( (e) => e.source === teamNode.id && e.type === "agent-connection" @@ -115,42 +83,7 @@ const buildTeamComponent = ( const agentNode = nodes.find((n) => n.id === edge.target); if (!agentNode || !isAgentComponent(agentNode.data.component)) return null; - - const agentComponent = { ...agentNode.data.component }; - // Get agent's model using edges - const agentModelEdge = edges.find( - (e) => e.target === edge.target && e.type === "model-connection" - ); - if (agentModelEdge) { - const modelNode = nodes.find((n) => n.id === agentModelEdge.source); - if ( - modelNode && - isModelComponent(modelNode.data.component) && - isAssistantAgent(agentComponent) - ) { - // Check specific agent type - agentComponent.config.model_client = modelNode.data.component; - } - } - - // Get agent's tools using edges - const toolEdges = edges.filter( - (e) => e.target === edge.target && e.type === "tool-connection" - ); - - if (isAssistantAgent(agentComponent)) { - agentComponent.config.tools = toolEdges - .map((toolEdge) => { - const toolNode = nodes.find((n) => n.id === toolEdge.source); - if (toolNode && isToolComponent(toolNode.data.component)) { - return toolNode.data.component; - } - return null; - }) - .filter((tool): tool is Component => tool !== null); - } - - return agentComponent; + return agentNode.data.component; }) .filter((agent): agent is Component => agent !== null); @@ -168,206 +101,183 @@ export const useTeamBuilderStore = create((set, get) => ({ addNode: ( position: Position, component: Component, - targetNodeId?: string + targetNodeId: string ) => { set((state) => { // Deep clone the incoming component to avoid reference issues const clonedComponent = JSON.parse(JSON.stringify(component)); - - // Determine label based on component type - let label = clonedComponent.label || clonedComponent.component_type; - - // Create new node - const newNode: CustomNode = { - id: nanoid(), - position, - type: clonedComponent.component_type, - data: { - label: label || "Node", - component: clonedComponent, - type: clonedComponent.component_type as NodeData["type"], - }, - }; - let newNodes = [...state.nodes]; let newEdges = [...state.edges]; + console.log( + "Adding node", + clonedComponent, + isTerminationComponent(clonedComponent), + targetNodeId + ); + if (targetNodeId) { const targetNode = state.nodes.find((n) => n.id === targetNodeId); - if (targetNode) { + console.log("Target node", targetNode); + if (!targetNode) return state; + + // Handle configuration updates based on component type + if (isModelComponent(clonedComponent)) { if ( - clonedComponent.component_type === "model" && - (isTeamComponent(targetNode.data.component) || - isAgentComponent(targetNode.data.component)) + isTeamComponent(targetNode.data.component) && + isSelectorTeam(targetNode.data.component) ) { - // Find existing model connection and node - const existingModelEdge = newEdges.find( - (edge) => - edge.target === targetNodeId && edge.type === "model-connection" - ); - - if (existingModelEdge) { - // Remove the existing model node - const existingModelNodeId = existingModelEdge.source; - newNodes = newNodes.filter( - (node) => node.id !== existingModelNodeId - ); - - // Remove all edges connected to the old model node - newEdges = newEdges.filter( - (edge) => - edge.source !== existingModelNodeId && - edge.target !== existingModelNodeId - ); - } - - // Add the new model node - newNodes.push(newNode); - - // Add new model connection - newEdges.push({ - id: nanoid(), - source: newNode.id, - target: targetNodeId, - sourceHandle: `${newNode.id}-model-output-handle`, - targetHandle: `${targetNodeId}-model-input-handle`, - type: "model-connection", - }); - - // Update config - if (isModelComponent(clonedComponent)) { - if (isSelectorTeam(targetNode.data.component)) { - targetNode.data.component.config.model_client = clonedComponent; - } else if (isAssistantAgent(targetNode.data.component)) { - targetNode.data.component.config.model_client = clonedComponent; - } - } + targetNode.data.component.config.model_client = clonedComponent; + return { + nodes: newNodes, + edges: newEdges, + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { nodes: newNodes, edges: newEdges }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + }; } else if ( - clonedComponent.component_type === "termination" && - targetNode.data.type === "team" + isAgentComponent(targetNode.data.component) && + isAssistantAgent(targetNode.data.component) ) { - // Handle termination connection - const existingTerminationEdge = newEdges.find( - (edge) => - edge.target === targetNodeId && - edge.type === "termination-connection" - ); - - if (existingTerminationEdge) { - const existingTerminationNodeId = existingTerminationEdge.source; - newNodes = newNodes.filter( - (node) => node.id !== existingTerminationNodeId - ); - newEdges = newEdges.filter( - (edge) => - edge.source !== existingTerminationNodeId && - edge.target !== existingTerminationNodeId - ); - } - - newNodes.push(newNode); - newEdges.push({ - id: nanoid(), - source: newNode.id, - target: targetNodeId, - sourceHandle: `${newNode.id}-termination-output-handle`, - targetHandle: `${targetNodeId}-termination-input-handle`, - type: "termination-connection", - }); - - if ( - isTeamComponent(targetNode.data.component) && - isTerminationComponent(clonedComponent) - ) { - targetNode.data.component.config.termination_condition = - clonedComponent; - } - } else if ( - clonedComponent.component_type === "tool" && - targetNode.data.type === "agent" + targetNode.data.component.config.model_client = clonedComponent; + return { + nodes: newNodes, + edges: newEdges, + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { nodes: newNodes, edges: newEdges }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + }; + } + } else if (isToolComponent(clonedComponent)) { + if ( + isAgentComponent(targetNode.data.component) && + isAssistantAgent(targetNode.data.component) ) { - // Handle tool connection with unique name - if ( - isAssistantAgent(targetNode.data.component) && - isAssistantAgent(newNode.data.component) - ) { - const existingTools = - targetNode.data.component.config.tools || []; - const existingNames = existingTools.map((t) => t.config.name); - newNode.data.component.config.name = getUniqueName( - clonedComponent.config.name, - existingNames - ); + if (!targetNode.data.component.config.tools) { + targetNode.data.component.config.tools = []; } - - newNodes.push(newNode); - newEdges.push({ - id: nanoid(), - source: newNode.id, - target: targetNodeId, - sourceHandle: `${newNode.id}-tool-output-handle`, - targetHandle: `${targetNodeId}-tool-input-handle`, - type: "tool-connection", + const toolName = getUniqueName( + clonedComponent.config.name, + targetNode.data.component.config.tools.map((t) => t.config.name) + ); + clonedComponent.config.name = toolName; + targetNode.data.component.config.tools.push(clonedComponent); + return { + nodes: newNodes, + edges: newEdges, + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { nodes: newNodes, edges: newEdges }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + }; + } + } else if (isTerminationComponent(clonedComponent)) { + console.log("Termination component added", clonedComponent); + if (isTeamComponent(targetNode.data.component)) { + newNodes = state.nodes.map((node) => { + if (node.id === targetNodeId) { + return { + ...node, + data: { + ...node.data, + component: { + ...node.data.component, + config: { + ...node.data.component.config, + termination_condition: clonedComponent, + }, + }, + }, + }; + } + return node; }); - if ( - isAssistantAgent(targetNode.data.component) && - isToolComponent(newNode.data.component) - ) { - if (!targetNode.data.component.config.tools) { - targetNode.data.component.config.tools = []; - } - targetNode.data.component.config.tools.push( - newNode.data.component - ); - } - } else if ( - clonedComponent.component_type === "agent" && - isTeamComponent(targetNode.data.component) && - isAssistantAgent(newNode.data.component) - ) { - // Handle agent connection with unique name - const existingParticipants = - targetNode.data.component.config.participants || []; - const existingNames = existingParticipants.map( - (p) => p.config.name - ); + return { + nodes: newNodes, + edges: newEdges, + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { nodes: newNodes, edges: newEdges }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + }; + } + } + } - newNode.data.component.config.name = getUniqueName( + // Handle team and agent nodes + if (isTeamComponent(clonedComponent)) { + const newNode: CustomNode = { + id: nanoid(), + position, + type: clonedComponent.component_type, + data: { + label: clonedComponent.label || "Team", + component: clonedComponent, + type: clonedComponent.component_type as NodeData["type"], + }, + }; + newNodes.push(newNode); + } else if (isAgentComponent(clonedComponent)) { + // Find the team node to connect to + const teamNode = newNodes.find((n) => + isTeamComponent(n.data.component) + ); + if (teamNode) { + // Ensure unique agent name + if ( + isAssistantAgent(clonedComponent) && + isTeamComponent(teamNode.data.component) + ) { + const existingAgents = + teamNode.data.component.config.participants || []; + const existingNames = existingAgents.map((p) => p.config.name); + clonedComponent.config.name = getUniqueName( clonedComponent.config.name, existingNames ); + } - newNodes.push(newNode); - newEdges.push({ - id: nanoid(), - source: targetNodeId, - target: newNode.id, - sourceHandle: `${targetNodeId}-agent-output-handle`, - targetHandle: `${newNode.id}-agent-input-handle`, - type: "agent-connection", - }); - - if ( - isTeamComponent(targetNode.data.component) && - isAgentComponent(newNode.data.component) - ) { - if (!targetNode.data.component.config.participants) { - targetNode.data.component.config.participants = []; - } - targetNode.data.component.config.participants.push( - newNode.data.component - ); + const newNode: CustomNode = { + id: nanoid(), + position, + type: clonedComponent.component_type, + data: { + label: clonedComponent.label || clonedComponent.config.name, + component: clonedComponent, + type: clonedComponent.component_type as NodeData["type"], + }, + }; + + newNodes.push(newNode); + + // Add connection to team + newEdges.push({ + id: nanoid(), + source: teamNode.id, + target: newNode.id, + sourceHandle: `${teamNode.id}-agent-output-handle`, + targetHandle: `${newNode.id}-agent-input-handle`, + type: "agent-connection", + }); + + // Update team's participants + if (isTeamComponent(teamNode.data.component)) { + if (!teamNode.data.component.config.participants) { + teamNode.data.component.config.participants = []; } - } else { - // For all other cases, just add the new node - newNodes.push(newNode); + teamNode.data.component.config.participants.push( + newNode.data.component as Component + ); } } - } else { - // If no target node, just add the new node - newNodes.push(newNode); } const { nodes: layoutedNodes, edges: layoutedEdges } = @@ -420,38 +330,6 @@ export const useTeamBuilderStore = create((set, get) => ({ }, }; } - - const isAgentWithUpdatedTool = - isAssistantAgent(node.data.component) && - state.edges.some( - (e) => - e.type === "tool-connection" && - e.source === nodeId && - e.target === node.id - ); - - if (isAgentWithUpdatedTool && isAssistantAgent(node.data.component)) { - return { - ...node, - data: { - ...node.data, - component: { - ...node.data.component, - config: { - ...node.data.component.config, - tools: (node.data.component.config.tools || []).map( - (tool) => - tool === - state.nodes.find((n) => n.id === nodeId)?.data.component - ? updates.component - : tool - ), - }, - }, - }, - }; - } - return node; } @@ -500,27 +378,7 @@ export const useTeamBuilderStore = create((set, get) => ({ connectedEdges .filter((e) => e.type === "agent-connection") .forEach((e) => collectNodesToRemove(e.target)); - - // Remove team's model if exists - connectedEdges - .filter((e) => e.type === "model-connection") - .forEach((e) => collectNodesToRemove(e.source)); - - // Remove termination condition if exists - connectedEdges - .filter((e) => e.type === "termination-connection") - .forEach((e) => collectNodesToRemove(e.source)); } else if (isAgentComponent(node.data.component)) { - // Remove agent's model if exists - connectedEdges - .filter((e) => e.type === "model-connection") - .forEach((e) => collectNodesToRemove(e.source)); - - // Remove all agent's tools - connectedEdges - .filter((e) => e.type === "tool-connection") - .forEach((e) => collectNodesToRemove(e.source)); - // Update team's participants if agent is connected to a team const teamEdge = connectedEdges.find( (e) => e.type === "agent-connection" @@ -528,7 +386,6 @@ export const useTeamBuilderStore = create((set, get) => ({ if (teamEdge) { const teamNode = state.nodes.find((n) => n.id === teamEdge.source); if (teamNode && isTeamComponent(teamNode.data.component)) { - // Create updated team node with filtered participants const updatedTeamNode = { ...teamNode, data: { @@ -539,15 +396,7 @@ export const useTeamBuilderStore = create((set, get) => ({ ...teamNode.data.component.config, participants: teamNode.data.component.config.participants.filter( - (p) => { - // Find a node that matches this participant but isn't being deleted - const participantNode = state.nodes.find( - (n) => - !nodesToRemove.has(n.id) && - isEqual(n.data.component, p) - ); - return participantNode !== undefined; - } + (p) => !isEqual(p, node.data.component) ), }, }, @@ -556,67 +405,6 @@ export const useTeamBuilderStore = create((set, get) => ({ updatedNodes.set(teamNode.id, updatedTeamNode); } } - } else if (isToolComponent(node.data.component)) { - // Update connected agent's tools array - const agentEdge = connectedEdges.find( - (e) => e.type === "tool-connection" - ); - if (agentEdge) { - const agentNode = state.nodes.find( - (n) => n.id === agentEdge.target - ); - if (agentNode && isAssistantAgent(agentNode.data.component)) { - // Create updated agent node with filtered tools - const updatedAgentNode = { - ...agentNode, - data: { - ...agentNode.data, - component: { - ...agentNode.data.component, - config: { - ...agentNode.data.component.config, - tools: ( - agentNode.data.component.config.tools || [] - ).filter((t) => { - // Find a node that matches this tool but isn't being deleted - const toolNode = state.nodes.find( - (n) => - !nodesToRemove.has(n.id) && - isEqual(n.data.component, t) - ); - return toolNode !== undefined; - }), - }, - }, - }, - }; - updatedNodes.set(agentNode.id, updatedAgentNode); - } - } - } else if (isTerminationComponent(node.data.component)) { - // Update connected team's termination condition - const teamEdge = connectedEdges.find( - (e) => e.type === "termination-connection" - ); - if (teamEdge) { - const teamNode = state.nodes.find((n) => n.id === teamEdge.target); - if (teamNode && isTeamComponent(teamNode.data.component)) { - const updatedTeamNode = { - ...teamNode, - data: { - ...teamNode.data, - component: { - ...teamNode.data.component, - config: { - ...teamNode.data.component.config, - termination_condition: undefined, - }, - }, - }, - }; - updatedNodes.set(teamNode.id, updatedTeamNode); - } - } } }; @@ -647,41 +435,28 @@ export const useTeamBuilderStore = create((set, get) => ({ }, addEdge: (edge: CustomEdge) => { - set((state) => { - let newEdges = [...state.edges]; - - if (edge.type === "model-connection") { - newEdges = newEdges.filter( - (e) => !(e.target === edge.target && e.type === "model-connection") - ); - } - - newEdges.push(edge); - - return { - edges: newEdges, - history: [ - ...state.history.slice(0, state.currentHistoryIndex + 1), - { nodes: state.nodes, edges: newEdges }, - ].slice(-MAX_HISTORY), - currentHistoryIndex: state.currentHistoryIndex + 1, - }; - }); + set((state) => ({ + edges: [...state.edges, edge], + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { nodes: state.nodes, edges: [...state.edges, edge] }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + })); }, removeEdge: (edgeId: string) => { - set((state) => { - const newEdges = state.edges.filter((edge) => edge.id !== edgeId); - - return { - edges: newEdges, - history: [ - ...state.history.slice(0, state.currentHistoryIndex + 1), - { nodes: state.nodes, edges: newEdges }, - ].slice(-MAX_HISTORY), - currentHistoryIndex: state.currentHistoryIndex + 1, - }; - }); + set((state) => ({ + edges: state.edges.filter((edge) => edge.id !== edgeId), + history: [ + ...state.history.slice(0, state.currentHistoryIndex + 1), + { + nodes: state.nodes, + edges: state.edges.filter((edge) => edge.id !== edgeId), + }, + ].slice(-MAX_HISTORY), + currentHistoryIndex: state.currentHistoryIndex + 1, + })); }, setSelectedNode: (nodeId: string | null) => { @@ -745,9 +520,9 @@ export const useTeamBuilderStore = create((set, get) => ({ }); }, - loadFromJson: (component: Component) => { + loadFromJson: (config: Component) => { // Get graph representation of team config - const { nodes, edges } = convertTeamConfigToGraph(component); + const { nodes, edges } = convertTeamConfigToGraph(config); // Apply layout to elements const { nodes: layoutedNodes, edges: layoutedEdges } = getLayoutedElements( @@ -759,8 +534,8 @@ export const useTeamBuilderStore = create((set, get) => ({ set({ nodes: layoutedNodes, edges: layoutedEdges, - originalComponent: component, // Store original component for reference - history: [{ nodes: layoutedNodes, edges: layoutedEdges }], // Reset history with initial state + originalComponent: config, + history: [{ nodes: layoutedNodes, edges: layoutedEdges }], currentHistoryIndex: 0, selectedNodeId: null, }); @@ -768,6 +543,7 @@ export const useTeamBuilderStore = create((set, get) => ({ // Return final graph state return { nodes: layoutedNodes, edges: layoutedEdges }; }, + resetHistory: () => { set((state) => ({ history: [{ nodes: state.nodes, edges: state.edges }], diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts b/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts index 2d840ca41b8e..cab952865cc3 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/utils.ts @@ -1,37 +1,54 @@ -import dagre from "@dagrejs/dagre"; -import { CustomNode, CustomEdge, EdgeTypes } from "./types"; import { nanoid } from "nanoid"; import { TeamConfig, Component, ComponentConfig, } from "../../../types/datamodel"; -import { isAssistantAgent, isSelectorTeam } from "../../../types/guards"; - -interface ConversionResult { - nodes: CustomNode[]; - edges: CustomEdge[]; -} +import { CustomNode, CustomEdge } from "./types"; interface Position { x: number; y: number; } -// Calculate positions for participants in a grid layout -const calculateParticipantPosition = ( +// Layout configuration +const LAYOUT_CONFIG = { + TEAM_NODE: { + X_POSITION: 100, + MIN_Y_POSITION: 200, + }, + AGENT: { + START_X: 600, // Starting X position for first agent + START_Y: 200, // Starting Y position for first agent + X_STAGGER: 100, // X offset for each subsequent agent + Y_STAGGER: 200, // Y offset for each subsequent agent + }, + NODE: { + WIDTH: 272, + HEIGHT: 200, + }, +}; + +// Calculate staggered position for agents +const calculateAgentPosition = ( index: number, - totalParticipants: number + totalAgents: number ): Position => { - const GRID_SPACING = 250; - const PARTICIPANTS_PER_ROW = 3; - - const row = Math.floor(index / PARTICIPANTS_PER_ROW); - const col = index % PARTICIPANTS_PER_ROW; + return { + x: LAYOUT_CONFIG.AGENT.START_X + index * LAYOUT_CONFIG.AGENT.X_STAGGER, + y: LAYOUT_CONFIG.AGENT.START_Y + index * LAYOUT_CONFIG.AGENT.Y_STAGGER, + }; +}; +// Calculate team node position based on agent positions +const calculateTeamPosition = (totalAgents: number): Position => { + const centerY = ((totalAgents - 1) * LAYOUT_CONFIG.AGENT.Y_STAGGER) / 2; return { - x: col * GRID_SPACING, - y: (row + 1) * GRID_SPACING, + x: LAYOUT_CONFIG.TEAM_NODE.X_POSITION, + y: Math.max( + LAYOUT_CONFIG.TEAM_NODE.MIN_Y_POSITION, + LAYOUT_CONFIG.AGENT.START_Y + centerY + ), }; }; @@ -47,6 +64,7 @@ const createNode = ( data: { label: label || component.label || component.component_type, component, + type: component.component_type, }, }); @@ -54,224 +72,85 @@ const createNode = ( const createEdge = ( source: string, target: string, - type: EdgeTypes + type: "agent-connection" ): CustomEdge => ({ id: `e${source}-${target}`, source, target, + sourceHandle: `${source}-agent-output-handle`, + targetHandle: `${target}-agent-input-handle`, type, }); +// Convert team configuration to graph structure export const convertTeamConfigToGraph = ( teamComponent: Component -): ConversionResult => { +): { nodes: CustomNode[]; edges: CustomEdge[] } => { const nodes: CustomNode[] = []; const edges: CustomEdge[] = []; + const totalAgents = teamComponent.config.participants.length; // Create team node - const teamNode = createNode({ x: 400, y: 50 }, teamComponent); + const teamNode = createNode( + calculateTeamPosition(totalAgents), + teamComponent + ); nodes.push(teamNode); - // Add model client if present - if (isSelectorTeam(teamComponent) && teamComponent.config.model_client) { - const modelNode = createNode( - { x: 200, y: 50 }, - teamComponent.config.model_client, - teamComponent.config.model_client.config.model - ); - nodes.push(modelNode); - edges.push({ - id: nanoid(), - source: modelNode.id, - target: teamNode.id, - sourceHandle: `${modelNode.id}-model-output-handle`, - targetHandle: `${teamNode.id}-model-input-handle`, - type: "model-connection", - }); - } - - // Add participants (agents) + // Create agent nodes with staggered layout teamComponent.config.participants.forEach((participant, index) => { - const position = calculateParticipantPosition( - index, - teamComponent.config.participants.length - ); + const position = calculateAgentPosition(index, totalAgents); const agentNode = createNode(position, participant); nodes.push(agentNode); // Connect to team - edges.push({ - id: nanoid(), - source: teamNode.id, - target: agentNode.id, - sourceHandle: `${teamNode.id}-agent-output-handle`, - targetHandle: `${agentNode.id}-agent-input-handle`, - type: "agent-connection", - }); - - // Add agent's model client if present - if (isAssistantAgent(participant) && participant.config.model_client) { - const agentModelNode = createNode( - { - x: position.x - 150, - y: position.y, - }, - participant.config.model_client, - participant.config.model_client.config.model - ); - nodes.push(agentModelNode); - edges.push({ - id: nanoid(), - source: agentModelNode.id, - target: agentNode.id, - sourceHandle: `${agentModelNode.id}-model-output-handle`, - targetHandle: `${agentNode.id}-model-input-handle`, - type: "model-connection", - }); - } - - // Add agent's tools - if (isAssistantAgent(participant) && participant.config.tools) { - participant.config.tools.forEach((tool, toolIndex) => { - const toolNode = createNode( - { - x: position.x + 150, - y: position.y + toolIndex * 100, - }, - tool - ); - nodes.push(toolNode); - edges.push({ - id: nanoid(), - source: toolNode.id, - target: agentNode.id, - sourceHandle: `${toolNode.id}-tool-output-handle`, - targetHandle: `${agentNode.id}-tool-input-handle`, - type: "tool-connection", - }); - }); - } + edges.push(createEdge(teamNode.id, agentNode.id, "agent-connection")); }); - // Add termination condition if present - if (teamComponent.config.termination_condition) { - const terminationNode = createNode( - { x: 600, y: 50 }, - teamComponent.config.termination_condition - ); - nodes.push(terminationNode); - edges.push({ - id: nanoid(), - source: terminationNode.id, - target: teamNode.id, - sourceHandle: `${terminationNode.id}-termination-output-handle`, - targetHandle: `${teamNode.id}-termination-input-handle`, - type: "termination-connection", - }); - } - return { nodes, edges }; }; -// Rest of the file remains the same since it deals with layout calculations -const NODE_WIDTH = 272; -const NODE_HEIGHT = 200; - +// This is the function expected by the store export const getLayoutedElements = ( nodes: CustomNode[], edges: CustomEdge[] -) => { - const g = new dagre.graphlib.Graph(); - const calculateRank = (node: CustomNode) => { - if (node.data.type === "model") { - // Check if this model is connected to a team or agent - const isTeamModel = edges.some( - (e) => - e.source === node.id && - nodes.find((n) => n.id === e.target)?.data.type === "team" - ); - return isTeamModel ? 0 : 2; - } +): { nodes: CustomNode[]; edges: CustomEdge[] } => { + // Find team node and count agents + const teamNode = nodes.find((n) => n.data.type === "team"); + if (!teamNode) return { nodes, edges }; - switch (node.data.type) { - case "team": - return 1; - case "agent": - return 3; - case "tool": - return 4; - case "termination": - return 1; // Same rank as team - default: - return 0; - } - }; - - g.setGraph({ - rankdir: "LR", - nodesep: 250, - ranksep: 150, - ranker: "network-simplex", // or "tight-tree" depending on needs - align: "DL", - }); - g.setDefaultEdgeLabel(() => ({})); - - // Add nodes to the graph with their dimensions - nodes.forEach((node) => { - const rank = calculateRank(node); - g.setNode(node.id, { - width: NODE_WIDTH, - height: NODE_HEIGHT, - rank, - }); - }); + // Count agent nodes + const agentNodes = nodes.filter((n) => n.data.type !== "team"); + const totalAgents = agentNodes.length; - // Add edges to the graph - edges.forEach((edge) => { - g.setEdge(edge.source, edge.target); - }); - - // Apply the layout - dagre.layout(g); - - // Get the laid out nodes with their new positions + // Calculate new positions const layoutedNodes = nodes.map((node) => { - const nodeWithPosition = g.node(node.id); - return { - ...node, - position: { - x: nodeWithPosition.x - NODE_WIDTH / 2, - y: nodeWithPosition.y - NODE_HEIGHT / 2, - }, - }; + if (node.data.type === "team") { + // Position team node + return { + ...node, + position: calculateTeamPosition(totalAgents), + }; + } else { + // Position agent node + const agentIndex = agentNodes.findIndex((n) => n.id === node.id); + return { + ...node, + position: calculateAgentPosition(agentIndex, totalAgents), + }; + } }); return { nodes: layoutedNodes, edges }; }; -export const getNodeConnections = (nodeId: string, edges: CustomEdge[]) => { - return { - modelClient: - edges.find((e) => e.target === nodeId && e.type === "model-connection") - ?.source || null, - tools: edges - .filter((e) => e.target === nodeId && e.type === "tool-connection") - .map((e) => e.source), - participants: edges - .filter((e) => e.source === nodeId && e.type === "agent-connection") - .map((e) => e.target), - }; -}; - +// Generate unique names (unchanged) export const getUniqueName = ( baseName: string, existingNames: string[] ): string => { - // Convert baseName to valid identifier format let validBaseName = baseName - // Replace spaces and special characters with underscore .replace(/[^a-zA-Z0-9_$]/g, "_") - // Ensure it starts with a letter, underscore, or dollar sign .replace(/^([^a-zA-Z_$])/, "_$1"); if (!existingNames.includes(validBaseName)) return validBaseName; From 527f492b07470afc50ceb2054c4b42401cdc8a2d Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 29 Jan 2025 12:11:13 -0800 Subject: [PATCH 11/16] fix save agent bug and other items. --- .../autogenstudio/teammanager/teammanager.py | 4 + .../components/views/session/chat/runview.tsx | 2 +- .../components/views/team/builder/builder.tsx | 2 + .../components/views/team/builder/library.tsx | 17 +- .../views/team/builder/node-editor.tsx | 1020 +++++++---------- .../components/views/team/builder/nodes.tsx | 18 +- .../components/views/team/builder/store.tsx | 2 +- .../components/views/team/builder/types.ts | 2 +- .../src/components/views/team/manager.tsx | 4 +- 9 files changed, 423 insertions(+), 648 deletions(-) diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py index 663a5442627c..36d6379d42a6 100644 --- a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py +++ b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py @@ -74,6 +74,10 @@ async def _create_team( # Use Component.load_component directly team = Team.load_component(config) + + for agent in team._participants: + if hasattr(agent, "input_func"): + agent.input_func = input_func # TBD - set input function return team diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx index f6ba05e02553..73afef90e415 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx @@ -54,7 +54,7 @@ const RunView: React.FC = ({ }, [run.messages]); // Only depend on messages changing const calculateThreadTokens = (messages: Message[]) => { - console.log("messages", messages); + // console.log("messages", messages); return messages.reduce((total, msg) => { if (!msg.config.models_usage) return total; return ( diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx index 8c22e132ffda..8684d73d4f5f 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx @@ -67,6 +67,7 @@ export const TeamBuilder: React.FC = ({ history, updateNode, selectedNodeId, + setSelectedNode, } = useTeamBuilderStore(); const currentHistoryIndex = useTeamBuilderStore( @@ -410,6 +411,7 @@ export const TeamBuilder: React.FC = ({ handleSave(); } }} + onClose={() => setSelectedNode(null)} /> diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx index a1118d6657c4..34118174e464 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx @@ -10,22 +10,16 @@ import { Timer, Maximize2, Minimize2, + GripVertical, } from "lucide-react"; -import type { - AgentConfig, - ModelConfig, - TerminationConfig, - ToolConfig, -} from "../../../types/datamodel"; import Sider from "antd/es/layout/Sider"; import { useGalleryStore } from "../../gallery/store"; +import { ComponentTypes } from "../../../types/datamodel"; interface ComponentConfigTypes { [key: string]: any; } -type ComponentTypes = "agent" | "model" | "tool" | "termination"; - interface LibraryProps {} interface PresetItemProps { @@ -66,11 +60,12 @@ const PresetItem: React.FC = ({ style={style} {...attributes} {...listeners} - className="p-2 text-primary mb-2 border border-secondary rounded cursor-move hover:bg-secondary transition-colors" + className="p-2 text-primary mb-2 border border-secondary rounded cursor-move hover:bg-secondary transition-colors " >
+ {icon} - {label} + {label}
); @@ -109,7 +104,7 @@ export const ComponentLibrary: React.FC = () => { title: "Tools", type: "tool" as ComponentTypes, items: defaultGallery.items.components.tools.map((tool) => ({ - label: tool.label, + label: tool.config.name, config: tool, })), icon: , diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/node-editor.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/node-editor.tsx index 07520e47a2d8..192e4b881f57 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/node-editor.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/node-editor.tsx @@ -1,665 +1,439 @@ -import React, { useEffect, useState } from "react"; -import { Drawer, Button, Space, message, Select, Input } from "antd"; +import React from "react"; +import { Input, Select, Switch, InputNumber, Form, Button, Drawer } from "antd"; import { NodeEditorProps } from "./types"; -import { useTeamBuilderStore } from "./store"; import { - TeamConfig, - ComponentTypes, - TeamTypes, - ModelTypes, - SelectorGroupChatConfig, - RoundRobinGroupChatConfig, - ModelConfig, - AzureOpenAIModelConfig, - OpenAIModelConfig, - ComponentConfigTypes, - AgentConfig, - ToolConfig, - AgentTypes, - ToolTypes, - TerminationConfig, - TerminationTypes, - MaxMessageTerminationConfig, - TextMentionTerminationConfig, - CombinationTerminationConfig, -} from "../../../types/datamodel"; + isTeamComponent, + isAgentComponent, + isModelComponent, + isToolComponent, + isTerminationComponent, + isSelectorTeam, + isRoundRobinTeam, + isAssistantAgent, + isUserProxyAgent, + isWebSurferAgent, + isOpenAIModel, + isAzureOpenAIModel, + isFunctionTool, + isOrTermination, + isMaxMessageTermination, + isTextMentionTermination, +} from "../../../types/guards"; const { TextArea } = Input; +const { Option } = Select; -interface EditorProps { - value: T; - onChange: (value: T) => void; - disabled?: boolean; -} +export const NodeEditor: React.FC< + NodeEditorProps & { onClose: () => void } +> = ({ node, onUpdate, onClose }) => { + const [form] = Form.useForm(); -const TeamEditor: React.FC> = ({ - value, - onChange, - disabled, -}) => { - const handleTypeChange = (teamType: TeamTypes) => { - if (teamType === "SelectorGroupChat") { - onChange({ - ...value, - team_type: teamType, - selector_prompt: "", - model_client: { - component_type: "model", - model: "", - model_type: "OpenAIChatCompletionClient", - }, - } as SelectorGroupChatConfig); - } else { - const { selector_prompt, model_client, ...rest } = - value as SelectorGroupChatConfig; - onChange({ - ...rest, - team_type: teamType, - } as RoundRobinGroupChatConfig); + // Initialize form values when node changes + React.useEffect(() => { + if (node) { + form.setFieldsValue(node.data.component); } + }, [node, form]); + + if (!node) return null; + + const component = node.data.component; + + const handleFormSubmit = (values: any) => { + const updatedData = { + ...node.data, + component: { + ...node.data.component, + label: values.label, // These go on the component + description: values.description, // Not on NodeData + config: { + ...node.data.component.config, + ...values.config, + }, + }, + }; + onUpdate(updatedData); }; - return ( - -
- - onChange({ ...value, name: e.target.value })} - disabled={disabled} - /> -
+ const renderTeamFields = () => { + if (!component) return null; - {value.team_type === "SelectorGroupChat" && ( + if (isSelectorTeam(component)) { + return ( <> -
- -