diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py index d12af6051641..ebbc00e1097c 100644 --- a/autogen/agentchat/contrib/capabilities/context_handling.py +++ b/autogen/agentchat/contrib/capabilities/context_handling.py @@ -46,7 +46,7 @@ def add_to_agent(self, agent: ConversableAgent): """ Adds TransformChatHistory capability to the given agent. """ - agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages) + agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages) def _transform_messages(self, messages: List[Dict]) -> List[Dict]: """ diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py index c5c959da8d8f..e90612fa53b2 100644 --- a/autogen/agentchat/contrib/capabilities/teachability.py +++ b/autogen/agentchat/contrib/capabilities/teachability.py @@ -61,7 +61,7 @@ def add_to_agent(self, agent: ConversableAgent): self.teachable_agent = agent # Register a hook for processing the last message. - agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message) + agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message) # Was an llm_config passed to the constructor? if self.llm_config is None: @@ -82,7 +82,7 @@ def prepopulate_db(self): """Adds a few arbitrary memos to the DB.""" self.memo_store.prepopulate() - def process_last_message(self, text): + def process_last_received_message(self, text): """ Appends any relevant memos to the message text, and stores any apparent teachings in new memos. Uses TextAnalyzerAgent to make decisions about memo storage and retrieval. diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index a3726678a452..b31c8ce786d3 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -223,7 +223,11 @@ def __init__( # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration. # New hookable methods should be added to this list as required to support new agent capabilities. - self.hook_lists = {"process_last_message": [], "process_all_messages": []} + self.hook_lists = { + "process_last_received_message": [], + "process_all_messages_before_reply": [], + "process_message_before_send": [], + } @property def name(self) -> str: @@ -467,6 +471,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: self._oai_messages[conversation_id].append(oai_message) return True + def _process_message_before_send( + self, message: Union[Dict, str], recipient: Agent, silent: bool + ) -> Union[Dict, str]: + """Process the message before sending it to the recipient.""" + hook_list = self.hook_lists["process_message_before_send"] + for hook in hook_list: + message = hook(message, recipient, silent) + return message + def send( self, message: Union[Dict, str], @@ -509,6 +522,7 @@ def send( Returns: ChatResult: a ChatResult object. """ + message = self._process_message_before_send(message, recipient, silent) # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". valid = self._append_oai_message(message, "assistant", recipient) @@ -561,6 +575,7 @@ async def a_send( Returns: ChatResult: an ChatResult object. """ + message = self._process_message_before_send(message, recipient, silent) # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". valid = self._append_oai_message(message, "assistant", recipient) @@ -1634,11 +1649,11 @@ def generate_reply( # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. - messages = self.process_all_messages(messages) + messages = self.process_all_messages_before_reply(messages) # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. - messages = self.process_last_message(messages) + messages = self.process_last_received_message(messages) for reply_func_tuple in self._reply_func_list: reply_func = reply_func_tuple["reply_func"] @@ -1695,11 +1710,11 @@ async def a_generate_reply( # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. - messages = self.process_all_messages(messages) + messages = self.process_all_messages_before_reply(messages) # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. - messages = self.process_last_message(messages) + messages = self.process_last_received_message(messages) for reply_func_tuple in self._reply_func_list: reply_func = reply_func_tuple["reply_func"] @@ -2333,11 +2348,11 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) - def process_all_messages(self, messages: List[Dict]) -> List[Dict]: + def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. """ - hook_list = self.hook_lists["process_all_messages"] + hook_list = self.hook_lists["process_all_messages_before_reply"] # If no hooks are registered, or if there are no messages to process, return the original message list. if len(hook_list) == 0 or messages is None: return messages @@ -2348,14 +2363,14 @@ def process_all_messages(self, messages: List[Dict]) -> List[Dict]: processed_messages = hook(processed_messages) return processed_messages - def process_last_message(self, messages): + def process_last_received_message(self, messages): """ Calls any registered capability hooks to use and potentially modify the text of the last message, as long as the last message is not a function call or exit command. """ # If any required condition is not met, return the original message list. - hook_list = self.hook_lists["process_last_message"] + hook_list = self.hook_lists["process_last_received_message"] if len(hook_list) == 0: return messages # No hooks registered. if messages is None: diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index b71a6341c877..2a5eaf5f5bb2 100644 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1074,6 +1074,24 @@ def test_max_turn(): assert len(res.chat_history) <= 6 +def test_process_before_send(): + print_mock = unittest.mock.MagicMock() + + def send_to_frontend(message, recipient, silent): + if not silent: + print(f"Message sent to {recipient.name}: {message}") + print_mock(message=message) + return message + + dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER") + dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER") + dummy_agent_1.register_hook("process_message_before_send", send_to_frontend) + dummy_agent_1.send("hello", dummy_agent_2) + print_mock.assert_called_once_with(message="hello") + dummy_agent_1.send("silent hello", dummy_agent_2, silent=True) + print_mock.assert_called_once_with(message="hello") + + if __name__ == "__main__": # test_trigger() # test_context() @@ -1081,4 +1099,5 @@ def test_max_turn(): # test_generate_code_execution_reply() # test_conversable_agent() # test_no_llm_config() - test_max_turn() + # test_max_turn() + test_process_before_send()