Skip to content

Commit

Permalink
doc & fix: Enhance AgentInstantiationContext with detailed documentat…
Browse files Browse the repository at this point in the history
…ion and examples for agent instantiation; Fix a but that caused value error when the expected class is not provided in register_factory (#5555)

Resolves #5519

Also spotted and fixed a bug that caused value error from `register_factory`, when the `expected_class` was not provided.
  • Loading branch information
ekzhu authored Feb 15, 2025
1 parent 69c0b2b commit 80891b4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,81 @@


class AgentInstantiationContext:
"""A static class that provides context for agent instantiation.
This static class can be used to access the current runtime and agent ID
during agent instantiation -- inside the factory function or the agent's
class constructor.
Example:
Get the current runtime and agent ID inside the factory function and
the agent's constructor:
.. code-block:: python
import asyncio
from dataclasses import dataclass
from autogen_core import (
AgentId,
AgentInstantiationContext,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
message_handler,
)
@dataclass
class TestMessage:
content: str
class TestAgent(RoutedAgent):
def __init__(self, description: str):
super().__init__(description)
# Get the current runtime -- we don't use it here, but it's available.
_ = AgentInstantiationContext.current_runtime()
# Get the current agent ID.
agent_id = AgentInstantiationContext.current_agent_id()
print(f"Current AgentID from constructor: {agent_id}")
@message_handler
async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None:
print(f"Received message: {message.content}")
def test_agent_factory() -> TestAgent:
# Get the current runtime -- we don't use it here, but it's available.
_ = AgentInstantiationContext.current_runtime()
# Get the current agent ID.
agent_id = AgentInstantiationContext.current_agent_id()
print(f"Current AgentID from factory: {agent_id}")
return TestAgent(description="Test agent")
async def main() -> None:
# Create a SingleThreadedAgentRuntime instance.
runtime = SingleThreadedAgentRuntime()
# Start the runtime.
runtime.start()
# Register the agent type with a factory function.
await runtime.register_factory("test_agent", test_agent_factory)
# Send a message to the agent. The runtime will instantiate the agent and call the message handler.
await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default"))
# Stop the runtime.
await runtime.stop()
asyncio.run(main())
"""

def __init__(self) -> None:
raise RuntimeError(
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def main() -> None:
Args:
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None.
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed.
"""
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,20 @@ def start(self) -> None:
.. code-block:: python
import asyncio
from autogen_core import SingleThreadedAgentRuntime
runtime = SingleThreadedAgentRuntime()
runtime.start()
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()
# ... do other things ...
await runtime.stop()
asyncio.run(main())
"""
if self._run_context is not None:
Expand Down Expand Up @@ -765,7 +775,7 @@ async def factory_wrapper() -> T:
else:
agent_instance = maybe_agent_instance

if type_func_alias(agent_instance) != expected_class:
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")

return agent_instance
Expand Down
25 changes: 25 additions & 0 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@ def tracer_provider() -> TracerProvider:
return get_test_tracer_provider(test_exporter)


@pytest.mark.asyncio
async def test_agent_type_register_factory() -> None:
runtime = SingleThreadedAgentRuntime()

def agent_factory() -> NoopAgent:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId("name1", "default")
agent = NoopAgent()
assert agent.id == id
return agent

await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)

with pytest.raises(ValueError):
# This should fail because the expected class does not match the actual class.
await runtime.register_factory(
type=AgentType("name1"),
agent_factory=agent_factory, # type: ignore
expected_class=CascadingAgent,
)

# Without expected_class, no error.
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory)


@pytest.mark.asyncio
async def test_agent_type_must_be_unique() -> None:
runtime = SingleThreadedAgentRuntime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ async def factory_wrapper() -> T:
else:
agent_instance = maybe_agent_instance

if type_func_alias(agent_instance) != expected_class:
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")

return agent_instance
Expand Down
1 change: 1 addition & 0 deletions python/packages/autogen-ext/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
)

await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
await worker.register_factory(type=AgentType("name5"), agent_factory=lambda: NoopAgent())

await worker.stop()
await host.stop()
Expand Down

0 comments on commit 80891b4

Please sign in to comment.