Commit f022d7f
Merge branch 'main' into update_agbench
afourney authored Feb 7, 2025
2 parents a710b7a + 4c1c12d
Showing 1 changed file with 9 additions and 8 deletions.
@@ -75,8 +75,8 @@ def notify_event_received(self, request_id: str) -> None:
             self.input_events[request_id] = event
 
 
-def aprint(output: str, end: str = "\n") -> Awaitable[None]:
-    return asyncio.to_thread(print, output, end=end)
+def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]:
+    return asyncio.to_thread(print, output, end=end, flush=flush)
 
 
 async def Console(
@@ -126,7 +126,7 @@ async def Console(
                     f"Total completion tokens: {total_usage.completion_tokens}\n"
                     f"Duration: {duration:.2f} seconds\n"
                 )
-                await aprint(output, end="")
+                await aprint(output, end="", flush=True)
 
             # mypy ignore
             last_processed = message  # type: ignore
@@ -141,7 +141,7 @@
                     output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
                 total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens
                 total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            await aprint(output, end="", flush=True)
 
             # Print summary.
             if output_stats:
@@ -156,7 +156,7 @@
                     f"Total completion tokens: {total_usage.completion_tokens}\n"
                     f"Duration: {duration:.2f} seconds\n"
                 )
-                await aprint(output, end="")
+                await aprint(output, end="", flush=True)
 
             # mypy ignore
             last_processed = message  # type: ignore
@@ -169,23 +169,24 @@
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore
             if not streaming_chunks:
                 # Print message sender.
-                await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
+                await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True)
             if isinstance(message, ModelClientStreamingChunkEvent):
                 await aprint(message.content, end="")
                 streaming_chunks.append(message.content)
             else:
                 if streaming_chunks:
                     streaming_chunks.clear()
                     # Chunked messages are already printed, so we just print a newline.
-                    await aprint("", end="\n")
+                    await aprint("", end="\n", flush=True)
                 else:
                     # Print message content.
-                    await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
+                    await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n", flush=True)
                 if message.models_usage:
                     if output_stats:
                         await aprint(
                             f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
                             end="\n",
+                            flush=True,
                         )
                     total_usage.completion_tokens += message.models_usage.completion_tokens
                     total_usage.prompt_tokens += message.models_usage.prompt_tokens
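For readers skimming the diff: the change threads a flush argument through aprint and passes flush=True at most call sites. Below is a minimal, runnable sketch (a plain script written for this note, not code from the repository; main() and the sleep are illustrative) showing why that matters: print() only auto-flushes line-buffered stdout when a newline is written, so output emitted with end="" can sit in the buffer until later.

# Minimal sketch (not from the commit): why flush matters for aprint.
import asyncio
from typing import Awaitable


def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]:
    # Run the blocking print() in a worker thread so the event loop stays free.
    return asyncio.to_thread(print, output, end=end, flush=flush)


async def main() -> None:
    # With end="" no newline is written, so line-buffered stdout would hold
    # the text back; flush=True makes it appear immediately.
    await aprint("streaming token ", end="", flush=True)
    await asyncio.sleep(1)  # the token is already visible during this pause
    await aprint("done")  # default end="\n" terminates the line


asyncio.run(main())

Because asyncio.to_thread returns an awaitable, each aprint call must still be awaited; the flush keyword is simply forwarded to the underlying print().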
