Skip to content

Commit

Permalink
Allow closure agent to ignore unknown messages, add docs (microsoft#4836
Browse files Browse the repository at this point in the history
)

Allow closure agent to ignore unknown messages
  • Loading branch information
jackgerrits authored Dec 27, 2024
1 parent 2819515 commit a5681d7
Showing 1 changed file with 80 additions and 8 deletions.
88 changes: 80 additions & 8 deletions python/packages/autogen-core/src/autogen_core/_closure_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import inspect
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints
import warnings
from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints

from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
Expand Down Expand Up @@ -73,7 +74,11 @@ async def publish_message(

class ClosureAgent(BaseAgent, ClosureContext):
def __init__(
self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]
self,
description: str,
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
*,
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
Expand All @@ -89,6 +94,7 @@ def __init__(
handled_types = get_handled_types_from_closure(closure)
self._expected_types = handled_types
self._closure = closure
self._unknown_type_policy = unknown_type_policy
super().__init__(description)

@property
Expand All @@ -110,9 +116,17 @@ def runtime(self) -> AgentRuntime:

async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
if type(message) not in self._expected_types:
raise CantHandleException(
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
)
if self._unknown_type_policy == "warn":
warnings.warn(
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
stacklevel=1,
)
return None
elif self._unknown_type_policy == "error":
raise CantHandleException(
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
)

return await self._closure(self, message, ctx)

async def save_state(self) -> Mapping[str, Any]:
Expand All @@ -130,19 +144,77 @@ async def register_closure(
type: str,
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
*,
skip_class_subscriptions: bool = False,
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
skip_direct_message_subscription: bool = False,
description: str = "",
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
) -> AgentType:
"""The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime.
The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.
Example:
.. code-block:: python
import asyncio
from autogen_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
from dataclasses import dataclass
from autogen_core._default_subscription import DefaultSubscription
from autogen_core._default_topic import DefaultTopicId
@dataclass
class MyMessage:
content: str
async def main():
queue = asyncio.Queue[MyMessage]()
async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
await queue.put(message)
runtime = SingleThreadedAgentRuntime()
await ClosureAgent.register_closure(
runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
)
runtime.start()
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
await runtime.stop_when_idle()
result = await queue.get()
print(result)
asyncio.run(main())
Args:
runtime (AgentRuntime): Runtime to register the agent to
type (str): Agent type of registered agent
closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages
unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False.
description (str, optional): Description of what agent does. Defaults to "".
subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None.
Returns:
AgentType: Type of the agent that was registered
"""

def factory() -> ClosureAgent:
return ClosureAgent(description=description, closure=closure)
return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)

assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions"
agent_type = await cls.register(
runtime=runtime,
type=type,
factory=factory, # type: ignore
skip_class_subscriptions=skip_class_subscriptions,
# There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s
skip_class_subscriptions=True,
skip_direct_message_subscription=skip_direct_message_subscription,
)

Expand Down

0 comments on commit a5681d7

Please sign in to comment.