Skip to content

Commit a6612e6

Browse files
authored
Added tests for FileSurfer. (microsoft#4913)
1 parent e11fd83 commit a6612e6

File tree

2 files changed

+442
-316
lines changed

2 files changed

+442
-316
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import asyncio
2+
import json
3+
import logging
4+
import os
5+
from datetime import datetime
6+
from typing import Any, AsyncGenerator, List
7+
8+
import aiofiles
9+
import pytest
10+
from autogen_agentchat import EVENT_LOGGER_NAME
11+
from autogen_ext.agents.file_surfer import FileSurfer
12+
from autogen_ext.models.openai import OpenAIChatCompletionClient
13+
from openai.resources.chat.completions import AsyncCompletions
14+
from openai.types.chat.chat_completion import ChatCompletion, Choice
15+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
16+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
17+
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
18+
from openai.types.completion_usage import CompletionUsage
19+
from pydantic import BaseModel
20+
21+
22+
class FileLogHandler(logging.Handler):
23+
def __init__(self, filename: str) -> None:
24+
super().__init__()
25+
self.filename = filename
26+
self.file_handler = logging.FileHandler(filename)
27+
28+
def emit(self, record: logging.LogRecord) -> None:
29+
ts = datetime.fromtimestamp(record.created).isoformat()
30+
if isinstance(record.msg, BaseModel):
31+
record.msg = json.dumps(
32+
{
33+
"timestamp": ts,
34+
"message": record.msg.model_dump(),
35+
"type": record.msg.__class__.__name__,
36+
},
37+
)
38+
self.file_handler.emit(record)
39+
40+
41+
class _MockChatCompletion:
42+
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
43+
self._saved_chat_completions = chat_completions
44+
self._curr_index = 0
45+
46+
async def mock_create(
47+
self, *args: Any, **kwargs: Any
48+
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
49+
await asyncio.sleep(0.1)
50+
completion = self._saved_chat_completions[self._curr_index]
51+
self._curr_index += 1
52+
return completion
53+
54+
55+
logger = logging.getLogger(EVENT_LOGGER_NAME)
56+
logger.setLevel(logging.DEBUG)
57+
logger.addHandler(FileLogHandler("test_filesurfer_agent.log"))
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_run_filesurfer(monkeypatch: pytest.MonkeyPatch) -> None:
62+
# Create a test file
63+
test_file = os.path.abspath("test_filesurfer_agent.html")
64+
async with aiofiles.open(test_file, "wt") as file:
65+
await file.write("""<html>
66+
<head>
67+
<title>FileSurfer test file</title>
68+
</head>
69+
<body>
70+
<h1>FileSurfer test H1</h1>
71+
<p>FileSurfer test body</p>
72+
</body>
73+
</html>""")
74+
75+
# Mock the API calls
76+
model = "gpt-4o-2024-05-13"
77+
chat_completions = [
78+
ChatCompletion(
79+
id="id1",
80+
choices=[
81+
Choice(
82+
finish_reason="tool_calls",
83+
index=0,
84+
message=ChatCompletionMessage(
85+
content=None,
86+
tool_calls=[
87+
ChatCompletionMessageToolCall(
88+
id="1",
89+
type="function",
90+
function=Function(
91+
name="open_path",
92+
arguments=json.dumps({"path": test_file}),
93+
),
94+
)
95+
],
96+
role="assistant",
97+
),
98+
)
99+
],
100+
created=0,
101+
model=model,
102+
object="chat.completion",
103+
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
104+
),
105+
ChatCompletion(
106+
id="id2",
107+
choices=[
108+
Choice(
109+
finish_reason="tool_calls",
110+
index=0,
111+
message=ChatCompletionMessage(
112+
content=None,
113+
tool_calls=[
114+
ChatCompletionMessageToolCall(
115+
id="1",
116+
type="function",
117+
function=Function(
118+
name="open_path",
119+
arguments=json.dumps({"path": os.path.dirname(test_file)}),
120+
),
121+
)
122+
],
123+
role="assistant",
124+
),
125+
)
126+
],
127+
created=0,
128+
model=model,
129+
object="chat.completion",
130+
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
131+
),
132+
]
133+
mock = _MockChatCompletion(chat_completions)
134+
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
135+
agent = FileSurfer(
136+
"FileSurfer",
137+
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
138+
)
139+
140+
# Get the FileSurfer to read the file, and the directory
141+
assert agent._name == "FileSurfer" # pyright: ignore[reportPrivateUsage]
142+
result = await agent.run(task="Please read the test file")
143+
assert "# FileSurfer test H1" in result.messages[1].content
144+
145+
result = await agent.run(task="Please read the test directory")
146+
assert "# Index of " in result.messages[1].content
147+
assert "test_filesurfer_agent.html" in result.messages[1].content

0 commit comments

Comments
 (0)