Skip to content

Commit a72be1f

Browse files
authored
Python: Add thread_id to assistant message to avoid duplication (#11221)
### Motivation and Context The current OpenAI assistant get_response and invoke methods don't add the current thread_id to the CMC's metadata, therefore during a `thread.on_new_message(...)` call, the assistant's message will get added again. <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> ### Description Add the `thread_id` to the metadata so we don't add the assistant's response again. **Todo**: add agent integration tests and check messages so we have coverage for this. <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone 😄
1 parent a9b99d4 commit a72be1f

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

python/semantic_kernel/agents/open_ai/assistant_thread_actions.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,10 @@ async def _handle_streaming_requires_action(
554554

555555
@classmethod
556556
async def get_messages(
557-
cls: type[_T], client: AsyncOpenAI, thread_id: str, sort_order: Literal["asc", "desc"] = "desc"
557+
cls: type[_T],
558+
client: AsyncOpenAI,
559+
thread_id: str,
560+
sort_order: Literal["asc", "desc"] | None = None,
558561
) -> AsyncIterable["ChatMessageContent"]:
559562
"""Get messages from the thread.
560563
@@ -572,7 +575,7 @@ async def get_messages(
572575
while True:
573576
messages = await client.beta.threads.messages.list(
574577
thread_id=thread_id,
575-
order=sort_order,
578+
order=sort_order, # type: ignore
576579
after=last_id,
577580
)
578581

python/semantic_kernel/agents/open_ai/open_ai_assistant_agent.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ async def _on_new_message(self, new_message: str | ChatMessageContent) -> None:
140140
assert self._id is not None # nosec
141141
await AssistantThreadActions.create_message(self._client, self._id, new_message)
142142

143-
async def get_messages(self, sort_order: Literal["asc", "desc"] = "desc") -> AsyncIterable[ChatMessageContent]:
143+
async def get_messages(self, sort_order: Literal["asc", "desc"] | None = None) -> AsyncIterable[ChatMessageContent]:
144144
"""Get the messages in the thread.
145145
146146
Args:
@@ -556,6 +556,7 @@ async def get_response(
556556
**run_level_params, # type: ignore
557557
):
558558
if is_visible and response.metadata.get("code") is not True:
559+
response.metadata["thread_id"] = thread.id
559560
response_messages.append(response)
560561

561562
if not response_messages:
@@ -650,16 +651,17 @@ async def invoke(
650651
}
651652
run_level_params = {k: v for k, v in run_level_params.items() if v is not None}
652653

653-
async for is_visible, message in AssistantThreadActions.invoke(
654+
async for is_visible, response in AssistantThreadActions.invoke(
654655
agent=self,
655656
thread_id=thread.id,
656657
kernel=kernel,
657658
arguments=arguments,
658659
**run_level_params, # type: ignore
659660
):
660661
if is_visible:
661-
await thread.on_new_message(message)
662-
yield AgentResponseItem(message=message, thread=thread)
662+
response.metadata["thread_id"] = thread.id
663+
await thread.on_new_message(response)
664+
yield AgentResponseItem(message=response, thread=thread)
663665

664666
@trace_agent_invocation
665667
@override

0 commit comments

Comments
 (0)