Skip to content

Commit

Permalink
Prompting changes to better support smaller models. (#5386)
Browse files Browse the repository at this point in the history
A series of changes to the
`python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py`
file have been made to better support smaller models.

This includes changes to the prompts, state descriptions, and ordering
of messages.

Regression tasks with OpenAI models shows no change in GAIA scores,
while scores for Llama are significantly improved.
  • Loading branch information
afourney authored Feb 7, 2025
1 parent 3b2bf82 commit 3c30d89
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
from typing import (
Any,
AsyncGenerator,
BinaryIO,
Dict,
List,
Optional,
Sequence,
cast,
)
from urllib.parse import quote_plus

Expand All @@ -31,6 +29,7 @@
AssistantMessage,
ChatCompletionClient,
LLMMessage,
ModelFamily,
RequestUsage,
SystemMessage,
UserMessage,
Expand All @@ -42,7 +41,6 @@

from ._events import WebSurferEvent
from ._prompts import (
WEB_SURFER_OCR_PROMPT,
WEB_SURFER_QA_PROMPT,
WEB_SURFER_QA_SYSTEM_MESSAGE,
WEB_SURFER_TOOL_PROMPT_MM,
Expand Down Expand Up @@ -444,6 +442,22 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo
# Clone the messages, removing old screenshots
history: List[LLMMessage] = remove_images(self._chat_history)

# Split the history, removing the last message
if len(history):
user_request = history.pop()
else:
user_request = UserMessage(content="Empty request.", source="user")

# Truncate the history for smaller models
if self._model_client.model_info["family"] not in [
ModelFamily.GPT_4O,
ModelFamily.O1,
ModelFamily.O3,
ModelFamily.GPT_4,
ModelFamily.GPT_35,
]:
history = []

# Ask the page for interactive elements, then prepare the state-of-mark screenshot
rects = await self._playwright_controller.get_interactive_rects(self._page)
viewport = await self._playwright_controller.get_visual_viewport(self._page)
Expand Down Expand Up @@ -499,21 +513,31 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo
other_targets.extend(self._format_target_list(rects_below, rects))

if len(other_targets) > 0:
if len(other_targets) > 30:
other_targets = other_targets[0:30]
other_targets.append("...")
other_targets_str = (
"Additional valid interaction targets (not shown) include:\n" + "\n".join(other_targets) + "\n\n"
"Additional valid interaction targets include (but are not limited to):\n"
+ "\n".join(other_targets)
+ "\n\n"
)
else:
other_targets_str = ""

state_description = "Your " + await self._get_state_description()
tool_names = "\n".join([t["name"] for t in tools])
page_title = await self._page.title()

prompt_message = None
if self._model_client.model_info["vision"]:
text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
url=self._page.url,
state_description=state_description,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
title=page_title,
url=self._page.url,
).strip()

# Scale the screenshot for the MLM, and close the original
Expand All @@ -522,26 +546,42 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore

# Add the message
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
# Create the message
prompt_message = UserMessage(
content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)],
source=self.name,
)
else:
visible_text = await self._playwright_controller.get_visible_text(self._page)

text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
url=self._page.url,
state_description=state_description,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
visible_text=visible_text.strip(),
title=page_title,
url=self._page.url,
).strip()

# Add the message
history.append(UserMessage(content=text_prompt, source=self.name))
# Create the message
prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name)

history.append(prompt_message)
history.append(user_request)

# {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]}
# print(f"""
# ================={len(history)}=================
# {history[-2].content}
# =====
# {history[-1].content}
# ===================================================
# """)

# Make the request
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False})

self.model_usage.append(response.usage)
message = response.content
self._last_download = None
Expand Down Expand Up @@ -716,23 +756,12 @@ async def _execute_tool(
metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest()
if metadata_hash != self._prior_metadata_hash:
page_metadata = (
"\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
"\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n"
)
else:
page_metadata = ""
self._prior_metadata_hash = metadata_hash

# Describe the viewport of the new page in words
viewport = await self._playwright_controller.get_visual_viewport(self._page)
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
if percent_scrolled < 1: # Allow some rounding error
position_text = "at the top of the page"
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
position_text = "at the bottom of the page"
else:
position_text = str(percent_scrolled) + "% down from the top of the page"

new_screenshot = await self._page.screenshot()
if self.to_save_screenshots:
current_timestamp = "_" + int(time.time()).__str__()
Expand All @@ -748,25 +777,40 @@ async def _execute_tool(
)
)

ocr_text = (
await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token)
if self.use_ocr is True
else await self._playwright_controller.get_visible_text(self._page)
)

# Return the complete observation
page_title = await self._page.title()
message_content = f"{action_description}\n\n Here is a screenshot of the webpage: [{page_title}]({self._page.url}).\n The viewport shows {percent_visible}% of the webpage, and is positioned {position_text} {page_metadata}\n"
if self.use_ocr:
message_content += f"Automatic OCR of the page screenshot has detected the following text:\n\n{ocr_text}"
else:
message_content += f"The following text is visible in the viewport:\n\n{ocr_text}"
state_description = "The " + await self._get_state_description()
message_content = (
f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page."
)

return [
message_content,
re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines
AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))),
]

async def _get_state_description(self) -> str:
assert self._playwright_controller is not None
assert self._page is not None

# Describe the viewport of the new page in words
viewport = await self._playwright_controller.get_visual_viewport(self._page)
percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"])
percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"])
if percent_scrolled < 1: # Allow some rounding error
position_text = "at the top of the page"
elif percent_scrolled + percent_visible >= 99: # Allow some rounding error
position_text = "at the bottom of the page"
else:
position_text = str(percent_scrolled) + "% down from the top of the page"

visible_text = await self._playwright_controller.get_visible_text(self._page)

# Return the complete observation
page_title = await self._page.title()
message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n"
message_content += f"The following text is visible in the viewport:\n\n{visible_text}"
return message_content

def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
try:
return rects[target]["aria_name"].strip()
Expand Down Expand Up @@ -798,38 +842,6 @@ def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion

return targets

async def _get_ocr_text(
self, image: bytes | io.BufferedIOBase | PIL.Image.Image, cancellation_token: Optional[CancellationToken] = None
) -> str:
scaled_screenshot = None
if isinstance(image, PIL.Image.Image):
scaled_screenshot = image.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
else:
pil_image = None
if not isinstance(image, io.BufferedIOBase):
pil_image = PIL.Image.open(io.BytesIO(image))
else:
pil_image = PIL.Image.open(cast(BinaryIO, image))
scaled_screenshot = pil_image.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
pil_image.close()

# Add the multimodal message and make the request
messages: List[LLMMessage] = []
messages.append(
UserMessage(
content=[
WEB_SURFER_OCR_PROMPT,
AGImage.from_pil(scaled_screenshot),
],
source=self.name,
)
)
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
self.model_usage.append(response.usage)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content

async def _summarize_page(
self,
question: str | None = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
WEB_SURFER_TOOL_PROMPT_MM = """
Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{state_description}
Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools:
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate)
- contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate)
- on some other website entirely (in which case actions like performing a new web search might be the best option)
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
My request follows:
"""

WEB_SURFER_TOOL_PROMPT_TEXT = """
Your web browser is open to the page '{url}'. The following text is visible in the viewport:
```
{visible_text}
```
{state_description}
You have also identified the following interactive components:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools:
You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate)
- contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate)
- on some other website entirely (in which case actions like performing a new web search might be the best option)
"""
- the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate)
- contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate
- on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option)
WEB_SURFER_OCR_PROMPT = """
Please transcribe all visible text on this page, including both main content and the labels of UI elements.
My request follows:
"""


WEB_SURFER_QA_SYSTEM_MESSAGE = """
You are a helpful assistant that can summarize long documents to answer question.
"""
Expand Down
2 changes: 1 addition & 1 deletion python/packages/autogen-ext/tests/test_websurfer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async def test_run_websurfer(monkeypatch: pytest.MonkeyPatch) -> None:
result.messages[2] # type: ignore
.content[0] # type: ignore
.startswith( # type: ignore
"I am waiting a short period of time before taking further action.\n\n Here is a screenshot of the webpage:"
"I am waiting a short period of time before taking further action."
)
) # type: ignore
url_after_sleep = agent._page.url # type: ignore
Expand Down

0 comments on commit 3c30d89

Please sign in to comment.