Skip to content

Commit 5c969d3

Browse files
authored
fix: add state management for oai assistant (#5352)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. -->

## Why are these changes needed?

<!-- Please give a short summary of the change and the problem this solves. -->

To allow serialization of the OpenAI Assistant Agent by adding `save_state` and `load_state` support.

## Related issue number

<!-- For example: "Closes #1234" -->

Closes #5130

## Checks

- [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [ ] I've made sure all auto checks have passed.
1 parent 68cc2e1 commit 5c969d3

File tree

2 files changed

+245
-16
lines changed

2 files changed

+245
-16
lines changed

python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py

+29
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Iterable,
1212
List,
1313
Literal,
14+
Mapping,
1415
Optional,
1516
Sequence,
1617
Set,
@@ -36,6 +37,7 @@
3637
from autogen_core.models._model_client import ChatCompletionClient
3738
from autogen_core.models._types import FunctionExecutionResult
3839
from autogen_core.tools import FunctionTool, Tool
40+
from pydantic import BaseModel, Field
3941

4042
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI, NotGiven
4143
from openai.pagination import AsyncCursorPage
@@ -77,6 +79,15 @@ def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
7779
return FunctionToolParam(type="function", function=function_def)
7880

7981

82+
class OpenAIAssistantAgentState(BaseModel):
    """Serializable snapshot of the identifiers an ``OpenAIAssistantAgent`` tracks.

    Every field is either a string id, a list of string ids, or ``None``, so a
    dumped instance contains only plain built-in values.
    """

    # Discriminator naming this state schema.
    type: str = "OpenAIAssistantAgentState"
    # Id of the assistant the agent is bound to, if any.
    assistant_id: Optional[str] = None
    # Id of the conversation thread the agent is bound to, if any.
    thread_id: Optional[str] = None
    # Ids of the messages recorded as the thread's initial content.
    initial_message_ids: List[str] = Field(default_factory=list)
    # Id of the vector store tracked by the agent, if any.
    vector_store_id: Optional[str] = None
    # Ids of the files the agent has uploaded.
    uploaded_file_ids: List[str] = Field(default_factory=list)
89+
90+
8091
class OpenAIAssistantAgent(BaseChatAgent):
8192
"""An agent implementation that uses the Assistant API to generate responses.
8293
@@ -666,3 +677,21 @@ async def delete_vector_store(self, cancellation_token: CancellationToken) -> No
666677
self._vector_store_id = None
667678
except Exception as e:
668679
event_logger.error(f"Failed to delete vector store: {str(e)}")
680+
681+
async def save_state(self) -> Mapping[str, Any]:
    """Capture the agent's resource identifiers as a plain mapping.

    Returns:
        The ``model_dump()`` of an ``OpenAIAssistantAgentState`` holding the
        assistant id, thread id, initial message ids, vector store id and
        uploaded file ids. The result can be passed back to ``load_state``.
    """
    # Prefer the ids of the live assistant/thread objects; fall back to the
    # configured ids when those objects have not been set yet.
    assistant_id = self._assistant.id if self._assistant else self._assistant_id
    thread_id = self._thread.id if self._thread else self._init_thread_id

    state = OpenAIAssistantAgentState(
        assistant_id=assistant_id,
        thread_id=thread_id,
        # The ids are kept in a set, which has no stable iteration order.
        # Sorting makes the saved state deterministic, so two agents with the
        # same ids always produce equal mappings. load_state() converts the
        # list back into a set, so the ordering does not affect round-trips.
        initial_message_ids=sorted(self._initial_message_ids),
        vector_store_id=self._vector_store_id,
        # Copy so the snapshot never aliases the agent's own mutable list.
        uploaded_file_ids=list(self._uploaded_file_ids),
    )
    return state.model_dump()
690+
691+
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore the agent's resource identifiers from a saved mapping.

    Args:
        state: A mapping in the shape produced by ``save_state``. It is
            validated against ``OpenAIAssistantAgentState`` before use.

    Raises:
        pydantic.ValidationError: If ``state`` does not match the schema.
    """
    # NOTE(review): only the stored ids are replaced here. Any already-cached
    # ``self._assistant`` / ``self._thread`` objects are left untouched, and
    # save_state() prefers those over the ids restored below. Confirm this
    # method is only expected to run before the agent is initialized.
    restored = OpenAIAssistantAgentState.model_validate(state)

    self._assistant_id = restored.assistant_id
    self._init_thread_id = restored.thread_id
    self._vector_store_id = restored.vector_store_id
    self._uploaded_file_ids = restored.uploaded_file_ids
    # Message ids are serialized as a list but tracked on the agent as a set.
    self._initial_message_ids = set(restored.initial_message_ids)

python/packages/autogen-ext/tests/test_openai_assistant_agent.py

+216-16
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import io
12
import os
3+
from contextlib import asynccontextmanager
24
from enum import Enum
3-
from typing import List, Literal, Optional, Union
5+
from pathlib import Path
6+
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
7+
from unittest.mock import AsyncMock, MagicMock
48

9+
import aiofiles
510
import pytest
6-
from autogen_agentchat.messages import TextMessage
11+
from autogen_agentchat.messages import ChatMessage, TextMessage
712
from autogen_core import CancellationToken
813
from autogen_core.tools._base import BaseTool, Tool
914
from autogen_ext.agents.openai import OpenAIAssistantAgent
@@ -57,14 +62,104 @@ async def run(self, args: DisplayQuizArgs, cancellation_token: CancellationToken
5762
return QuizResponses(responses=responses)
5863

5964

65+
class FakeText:
    """Test double for a text payload: exposes the string as ``.value``."""

    def __init__(self, value: str):
        # Stored unchanged; the tests only read it back.
        self.value = value
68+
69+
70+
class FakeTextContent:
    """Test double for one text content part of a message.

    Carries a ``type`` of ``"text"`` and a ``text`` attribute wrapping the
    string in a ``FakeText``.
    """

    def __init__(self, text: str):
        wrapped = FakeText(text)
        self.type = "text"
        self.text = wrapped
74+
75+
76+
class FakeMessage:
    """Test double for a thread message with an id and a single text part."""

    def __init__(self, id: str, text: str):
        # The agent expects content to be a list of objects with a "type" attribute.
        parts = [FakeTextContent(text)]
        self.id = id
        self.content = parts
81+
82+
83+
class FakeCursorPage:
    """Test double for a single, final page of listed messages."""

    def __init__(self, data: List[ChatMessage | FakeMessage]) -> None:
        # Items of this page, kept exactly as supplied.
        self.data = data

    def has_next_page(self) -> bool:
        """Always report that no further page exists."""
        return False
89+
90+
91+
def create_mock_openai_client() -> AsyncAzureOpenAI:
    """Build a mocked ``AsyncAzureOpenAI`` client with canned Assistants API replies.

    The client is an ``AsyncMock`` specced on ``AsyncAzureOpenAI``. Its ``beta``
    namespace (assistants, threads, messages, runs, vector stores) and its
    ``files`` namespace are replaced with mocks whose async methods return
    fixed objects, so tests can drive the agent without network access.

    Returns:
        The configured mock client. Individual tests may override any of the
        mocked methods, for example ``beta.threads.messages.list``.
    """
    # Default fake message - shared by every default ``messages.list`` call on
    # the mock thread; tests may override ``messages.list`` entirely.
    name_message = FakeMessage("msg-mock", "Your name is John, you are a software engineer.")

    # ``**kwargs: Any`` types each keyword *value* as Any. The previous
    # ``Dict[str, Any]`` annotation wrongly declared every value to be a dict.
    def mock_list(thread_id: str, **kwargs: Any) -> FakeCursorPage:
        # Default behavior returns the "name" message for the mock thread.
        if thread_id == "thread-mock":
            return FakeCursorPage([name_message])
        return FakeCursorPage([FakeMessage("msg-mock", "Default response")])

    # Base client: an AsyncMock restricted to the real client's attribute names.
    client = AsyncMock(spec=AsyncAzureOpenAI)

    # "beta" namespace holding the nested Assistants API structure.
    beta = MagicMock()
    client.beta = beta

    # beta.assistants: create/retrieve/update return an object with a fixed id.
    beta.assistants = MagicMock()
    beta.assistants.create = AsyncMock(return_value=MagicMock(id="assistant-mock"))
    beta.assistants.retrieve = AsyncMock(return_value=MagicMock(id="assistant-mock"))
    beta.assistants.update = AsyncMock(return_value=MagicMock(id="assistant-mock"))
    beta.assistants.delete = AsyncMock(return_value=None)

    # beta.threads: create/retrieve return a thread without tool resources.
    beta.threads = MagicMock()
    beta.threads.create = AsyncMock(return_value=MagicMock(id="thread-mock", tool_resources=None))
    beta.threads.retrieve = AsyncMock(return_value=MagicMock(id="thread-mock", tool_resources=None))

    # beta.threads.messages: create, list (via mock_list) and delete.
    beta.threads.messages = MagicMock()
    beta.threads.messages.create = AsyncMock(return_value=MagicMock(id="msg-mock", content="mock content"))
    beta.threads.messages.list = AsyncMock(side_effect=mock_list)
    beta.threads.messages.delete = AsyncMock(return_value=MagicMock(deleted=True))

    # beta.threads.runs: every call reports an already-completed run.
    beta.threads.runs = MagicMock()
    beta.threads.runs.create = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))
    beta.threads.runs.retrieve = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))
    beta.threads.runs.submit_tool_outputs = AsyncMock(return_value=MagicMock(id="run-mock", status="completed"))

    # beta.vector_stores: create, delete and file_batches.create_and_poll.
    beta.vector_stores = MagicMock()
    beta.vector_stores.create = AsyncMock(return_value=MagicMock(id="vector-mock"))
    beta.vector_stores.delete = AsyncMock(return_value=None)
    beta.vector_stores.file_batches = MagicMock()
    beta.vector_stores.file_batches.create_and_poll = AsyncMock(return_value=None)

    # client.files: create returns a fixed file id; delete returns None.
    client.files = MagicMock()
    client.files.create = AsyncMock(return_value=MagicMock(id="file-mock"))
    client.files.delete = AsyncMock(return_value=None)

    return client
146+
147+
148+
# Fixture for the mock client.
149+
@pytest.fixture
def mock_openai_client() -> AsyncAzureOpenAI:
    """Provide a fresh mocked client built by ``create_mock_openai_client``."""
    mocked_client = create_mock_openai_client()
    return mocked_client
152+
153+
60154
@pytest.fixture
61155
def client() -> AsyncAzureOpenAI:
62156
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
63157
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")
64158
api_key = os.getenv("AZURE_OPENAI_API_KEY")
65159

66-
if not azure_endpoint:
67-
pytest.skip("Azure OpenAI endpoint not found in environment variables")
160+
# Return mock client if credentials not available
161+
if not azure_endpoint or not api_key:
162+
return create_mock_openai_client()
68163

69164
# Try Azure CLI credentials if API key not provided
70165
if not api_key:
@@ -76,7 +171,7 @@ def client() -> AsyncAzureOpenAI:
76171
azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider
77172
)
78173
except Exception:
79-
pytest.skip("Failed to get Azure CLI credentials and no API key provided")
174+
return create_mock_openai_client()
80175

81176
# Fall back to API key auth if provided
82177
return AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key)
@@ -105,10 +200,38 @@ def cancellation_token() -> CancellationToken:
105200
return CancellationToken()
106201

107202

203+
# A fake aiofiles.open to bypass filesystem access.
204+
@asynccontextmanager
async def fake_aiofiles_open(*args: Any, **kwargs: Any) -> AsyncGenerator[io.BytesIO, None]:
    """Async context manager that yields an in-memory file instead of opening one.

    All positional and keyword arguments are accepted and ignored, so this can
    be called with any ``open``-style arguments.

    Yields:
        A new ``io.BytesIO`` containing ``b"dummy file content"``.

    NOTE(review): the yielded ``BytesIO`` has a synchronous ``read()``. Code
    that awaits ``read()`` needs a fake with an async ``read`` method instead.
    """
    # ``**kwargs: Any`` types each keyword value as Any. The previous
    # ``Dict[str, Any]`` annotation claimed every keyword value was a dict.
    yield io.BytesIO(b"dummy file content")
207+
208+
108209
@pytest.mark.asyncio
109-
async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
110-
file_path = r"C:\Users\lpinheiro\Github\autogen-test\data\SampleBooks\jungle_book.txt"
111-
await agent.on_upload_for_file_search(file_path, cancellation_token)
210+
async def test_file_retrieval(
211+
agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
212+
) -> None:
213+
# Arrange: Define a fake async file opener that returns a file-like object with an async read() method.
214+
class FakeAiofilesFile:
215+
async def read(self) -> bytes:
216+
return b"dummy file content"
217+
218+
@asynccontextmanager
219+
async def fake_async_aiofiles_open(*args: Any, **kwargs: Dict[str, Any]) -> AsyncGenerator[FakeAiofilesFile, None]:
220+
yield FakeAiofilesFile()
221+
222+
monkeypatch.setattr(aiofiles, "open", fake_async_aiofiles_open)
223+
224+
# We also override the messages.list to return a fake file search result.
225+
fake_file_message = FakeMessage(
226+
"msg-mock", "The first sentence of the jungle book is 'Mowgli was raised by wolves.'"
227+
)
228+
agent._client.beta.threads.messages.list = AsyncMock(return_value=FakeCursorPage([fake_file_message])) # type: ignore
229+
230+
# Create a temporary file.
231+
file_path = tmp_path / "jungle_book.txt"
232+
file_path.write_text("dummy content")
233+
234+
await agent.on_upload_for_file_search(str(file_path), cancellation_token)
112235

113236
message = TextMessage(source="user", content="What is the first sentence of the jungle scout book?")
114237
response = await agent.on_messages([message], cancellation_token)
@@ -123,7 +246,14 @@ async def test_file_retrieval(agent: OpenAIAssistantAgent, cancellation_token: C
123246

124247

125248
@pytest.mark.asyncio
126-
async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token: CancellationToken) -> None:
249+
async def test_code_interpreter(
250+
agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch
251+
) -> None:
252+
# Arrange: For code interpreter, have the messages.list return a result with "x = 1".
253+
agent._client.beta.threads.messages.list = AsyncMock( # type: ignore
254+
return_value=FakeCursorPage([FakeMessage("msg-mock", "x = 1")])
255+
)
256+
127257
message = TextMessage(source="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?")
128258
response = await agent.on_messages([message], cancellation_token)
129259

@@ -136,33 +266,71 @@ async def test_code_interpreter(agent: OpenAIAssistantAgent, cancellation_token:
136266

137267

138268
@pytest.mark.asyncio
async def test_quiz_creation(
    agent: OpenAIAssistantAgent, cancellation_token: CancellationToken, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Exercise the tool-call path: a run that requires action, then completes."""
    monkeypatch.setattr(DisplayQuizTool, "run_json", DisplayQuizTool.run)

    # Fake tool call asking for the display_quiz function.
    # ``name`` must be assigned as an attribute: passing ``name=`` to the
    # MagicMock constructor would set the mock's own name instead.
    fake_tool_call = MagicMock()
    fake_tool_call.type = "function"
    fake_tool_call.id = "tool-call-1"
    fake_tool_call.function = MagicMock()
    fake_tool_call.function.name = "display_quiz"
    fake_tool_call.function.arguments = (
        '{"title": "Quiz Title", "questions": [{"question_text": "What is 2+2?", '
        '"question_type": "MULTIPLE_CHOICE", "choices": ["3", "4", "5"]}]}'
    )

    # First retrieval: the run is waiting on the tool call above.
    run_requires_action = MagicMock(id="run-mock", status="requires_action")
    run_requires_action.required_action = MagicMock()
    run_requires_action.required_action.submit_tool_outputs = MagicMock()
    run_requires_action.required_action.submit_tool_outputs.tool_calls = [fake_tool_call]

    # Second retrieval: the run has completed and needs nothing further.
    run_completed = MagicMock(id="run-mock", status="completed", required_action=None)

    # Hand out the two runs in order on successive retrieve() calls.
    agent._client.beta.threads.runs.retrieve.side_effect = [run_requires_action, run_completed]  # type: ignore

    # After completion, listing messages yields the quiz text.
    quiz_tool_message = FakeMessage("msg-mock", "Quiz created: Q1) 2+2=? Answer: b) 4; Q2) Free: Sample free response")
    quiz_page = FakeCursorPage([quiz_tool_message])
    agent._client.beta.threads.messages.list = AsyncMock(return_value=quiz_page)  # type: ignore

    # User request that triggers the tool invocation.
    message = TextMessage(
        source="user",
        content="Create a short quiz about basic math with one multiple choice question and one free response question.",
    )
    response = await agent.on_messages([message], cancellation_token)

    # The final chat message must be a non-empty string.
    assert response.chat_message.content is not None
    assert isinstance(response.chat_message.content, str)
    assert len(response.chat_message.content) > 0

    # Tool-call events are surfaced as inner messages; at least one has content.
    assert isinstance(response.inner_messages, list)
    assert any(hasattr(tool_msg, "content") and tool_msg.content for tool_msg in response.inner_messages)

    await agent.delete_assistant(cancellation_token)
153322

154323

155324
@pytest.mark.asyncio
156325
async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: CancellationToken) -> None:
157-
# Create thread with initial message
326+
# Arrange: Use the default behavior for reset.
158327
thread = await client.beta.threads.create()
159328
await client.beta.threads.messages.create(
160329
thread_id=thread.id,
161330
content="Hi, my name is John and I'm a software engineer. Use this information to help me.",
162331
role="user",
163332
)
164333

165-
# Create agent with existing thread
166334
agent = OpenAIAssistantAgent(
167335
name="assistant",
168336
instructions="Help the user with their task.",
@@ -172,19 +340,51 @@ async def test_on_reset_behavior(client: AsyncAzureOpenAI, cancellation_token: C
172340
thread_id=thread.id,
173341
)
174342

175-
# Test before reset
176343
message1 = TextMessage(source="user", content="What is my name?")
177344
response1 = await agent.on_messages([message1], cancellation_token)
178345
assert isinstance(response1.chat_message.content, str)
179346
assert "john" in response1.chat_message.content.lower()
180347

181-
# Reset agent state
182348
await agent.on_reset(cancellation_token)
183349

184-
# Test after reset
185350
message2 = TextMessage(source="user", content="What is my name?")
186351
response2 = await agent.on_messages([message2], cancellation_token)
187352
assert isinstance(response2.chat_message.content, str)
188353
assert "john" in response2.chat_message.content.lower()
189354

190355
await agent.delete_assistant(cancellation_token)
356+
357+
358+
@pytest.mark.asyncio
async def test_save_and_load_state(mock_openai_client: AsyncAzureOpenAI) -> None:
    """State saved from one agent must be restored verbatim on a second agent."""

    def build_agent() -> OpenAIAssistantAgent:
        # Both agents are configured identically; only their state differs.
        return OpenAIAssistantAgent(
            name="assistant",
            description="Dummy assistant for state testing",
            client=mock_openai_client,
            model="dummy-model",
            instructions="dummy instructions",
            tools=[],
        )

    # Seed the source agent's private fields directly with known ids.
    agent = build_agent()
    agent._assistant_id = "assistant-123"  # type: ignore
    agent._init_thread_id = "thread-456"  # type: ignore
    agent._initial_message_ids = {"msg1", "msg2"}  # type: ignore
    agent._vector_store_id = "vector-789"  # type: ignore
    agent._uploaded_file_ids = ["file-abc", "file-def"]  # type: ignore

    saved_state = await agent.save_state()

    # Round-trip the state into a freshly constructed agent.
    new_agent = build_agent()
    await new_agent.load_state(saved_state)

    assert new_agent._assistant_id == "assistant-123"  # type: ignore
    assert new_agent._init_thread_id == "thread-456"  # type: ignore
    assert new_agent._initial_message_ids == {"msg1", "msg2"}  # type: ignore
    assert new_agent._vector_store_id == "vector-789"  # type: ignore
    assert new_agent._uploaded_file_ids == ["file-abc", "file-def"]  # type: ignore

0 commit comments

Comments
 (0)