Skip to content

Commit

Permalink
Track SelectorGroup select_speaker tokens with new Message type
Browse files Browse the repository at this point in the history
  • Loading branch information
gziz committed Dec 20, 2024
1 parent a271708 commit d349011
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken
from autogen_core.models._types import RequestUsage

from ..messages import AgentEvent, ChatMessage

Expand All @@ -16,6 +17,9 @@ class TaskResult:
stop_reason: str | None = None
"""The reason the task stopped."""

usage: RequestUsage | None = None
"""The usage of the task."""


class TaskRunner(Protocol):
"""A task runner."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,25 @@ class ToolCallSummaryMessage(BaseMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class UsageEvent(BaseMessage):
"""An event signaling the usage of a model."""

content: str = ""
"""The content of the usage event."""

models_usage: RequestUsage
"""The model client usage incurred when producing this message."""

type: Literal["UsageEvent"] = "UsageEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | UsageEvent, Field(discriminator="type")]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from autogen_core.models._types import RequestUsage

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(
# Flag to track if the group chat is running.
self._is_running = False

self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

@abstractmethod
def _create_group_chat_manager_factory(
self,
Expand Down Expand Up @@ -418,8 +421,14 @@ async def stop_runtime() -> None:
yield message
output_messages.append(message)

usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
for message in output_messages:
if message.models_usage:
usage.prompt_tokens += message.models_usage.prompt_tokens
usage.completion_tokens += message.models_usage.completion_tokens

# Yield the final result.
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason, usage=usage)

finally:
# Wait for the shutdown task to finish.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import re
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core._default_topic import DefaultTopicId
from autogen_core.models import ChatCompletionClient, SystemMessage

from autogen_agentchat.teams._group_chat._events import GroupChatMessage

from ... import TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import (
Expand All @@ -16,6 +19,7 @@
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
UsageEvent,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -153,6 +157,12 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")

await self.publish_message(
GroupChatMessage(message=UsageEvent(source=self._id._type, models_usage=response.usage)),
topic_id=DefaultTopicId(type=self._output_topic_type),
)

return agent_name

def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from autogen_core.models import RequestUsage

from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UsageEvent


def _is_running_in_iterm() -> bool:
Expand Down Expand Up @@ -90,7 +90,9 @@ async def Console(
sys.stdout.flush()
# mypy ignore
last_processed = message # type: ignore

elif isinstance(message, UsageEvent):
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
Expand Down

0 comments on commit d349011

Please sign in to comment.