From 4c1c12d3506c5e6d63ca6bc773283fec4c2be466 Mon Sep 17 00:00:00 2001 From: afourney Date: Thu, 6 Feb 2025 22:20:06 -0800 Subject: [PATCH] Flush console output after every message. (#5415) --- .../src/autogen_agentchat/ui/_console.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py index 1f80166d32a1..0a95c842ea08 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py @@ -75,8 +75,8 @@ def notify_event_received(self, request_id: str) -> None: self.input_events[request_id] = event -def aprint(output: str, end: str = "\n") -> Awaitable[None]: - return asyncio.to_thread(print, output, end=end) +def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]: + return asyncio.to_thread(print, output, end=end, flush=flush) async def Console( @@ -126,7 +126,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", flush=True) # mypy ignore last_processed = message # type: ignore @@ -141,7 +141,7 @@ async def Console( output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n" total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens - await aprint(output, end="") + await aprint(output, end="", flush=True) # Print summary. if output_stats: @@ -156,7 +156,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", flush=True) # mypy ignore last_processed = message # type: ignore @@ -169,7 +169,7 @@ async def Console( message = cast(AgentEvent | ChatMessage, message) # type: ignore if not streaming_chunks: # Print message sender. - await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n") + await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True) if isinstance(message, ModelClientStreamingChunkEvent): await aprint(message.content, end="") streaming_chunks.append(message.content) @@ -177,15 +177,16 @@ async def Console( if streaming_chunks: streaming_chunks.clear() # Chunked messages are already printed, so we just print a newline. - await aprint("", end="\n") + await aprint("", end="\n", flush=True) else: # Print message content. - await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n") + await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n", flush=True) if message.models_usage: if output_stats: await aprint( f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]", end="\n", + flush=True, ) total_usage.completion_tokens += message.models_usage.completion_tokens total_usage.prompt_tokens += message.models_usage.prompt_tokens