Allow users to update the final answer prompt of MagenticOne orc. #4476

Merged (4 commits, Dec 3, 2024)
Changes from all commits. File 1 of 2 (the MagenticOneGroupChat team class):
@@ -7,6 +7,7 @@
 from ....base import ChatAgent, TerminationCondition
 from .._base_group_chat import BaseGroupChat
 from ._magentic_one_orchestrator import MagenticOneOrchestrator
+from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT

 trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -25,6 +26,7 @@ class MagenticOneGroupChat(BaseGroupChat):
             Without a termination condition, the group chat will run based on the orchestrator logic or until the maximum number of turns is reached.
         max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
         max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
+        final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.

     Raises:
         ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
@@ -64,6 +66,7 @@ def __init__(
         termination_condition: TerminationCondition | None = None,
         max_turns: int | None = 20,
         max_stalls: int = 3,
+        final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
     ):
         super().__init__(
             participants,
@@ -77,6 +80,7 @@ def __init__(
             raise ValueError("At least one participant is required for MagenticOneGroupChat.")
         self._model_client = model_client
         self._max_stalls = max_stalls
+        self._final_answer_prompt = final_answer_prompt

     def _create_group_chat_manager_factory(
         self,
@@ -95,5 +99,6 @@ def _create_group_chat_manager_factory(
                 max_turns,
                 self._model_client,
                 self._max_stalls,
+                self._final_answer_prompt,
                 termination_condition,
             )
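Taken together, these changes let callers supply a custom prompt at team construction time. A minimal usage sketch follows; the import paths, model name, single AssistantAgent participant, and task string are illustrative assumptions about a recent AgentChat 0.4 layout, not part of this PR:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path

async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    writer = AssistantAgent("writer", model_client=model_client)

    # The parameter introduced by this PR: override the default
    # ORCHESTRATOR_FINAL_ANSWER_PROMPT with a caller-supplied prompt.
    team = MagenticOneGroupChat(
        participants=[writer],
        model_client=model_client,
        final_answer_prompt=(
            "We are done. Using the transcript above, write the final "
            "answer as three short bullet points."
        ),
    )

    result = await team.run(task="Suggest a name for a code-review bot.")
    print(result.messages[-1].content)

asyncio.run(main())
```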
File 2 of 2 (the MagenticOneOrchestrator, imported in file 1 from ._magentic_one_orchestrator):
@@ -48,6 +48,7 @@ def __init__(
         max_turns: int | None,
         model_client: ChatCompletionClient,
         max_stalls: int,
+        final_answer_prompt: str,
         termination_condition: TerminationCondition | None,
     ):
         super().__init__(
@@ -60,6 +61,7 @@
         )
         self._model_client = model_client
         self._max_stalls = max_stalls
+        self._final_answer_prompt = final_answer_prompt
         self._name = "MagenticOneOrchestrator"
         self._max_json_retries = 10
         self._task = ""
@@ -95,7 +97,10 @@ def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
         return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)

     def _get_final_answer_prompt(self, task: str) -> str:
-        return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
+        if self._final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT:
+            return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
+        else:
+            return self._final_answer_prompt

     async def _log_message(self, log_message: str) -> None:
         trace_logger.debug(log_message)
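Note the asymmetry this hunk introduces: only the unmodified default is treated as a template and gets .format(task=task) applied, while a user-supplied prompt is returned verbatim, so a {task} placeholder inside a custom prompt is not substituted. A self-contained sketch of that selection rule; the DEFAULT_PROMPT text is a stand-in, not the real ORCHESTRATOR_FINAL_ANSWER_PROMPT:

```python
# Stand-in for ORCHESTRATOR_FINAL_ANSWER_PROMPT (the real text differs).
DEFAULT_PROMPT = "We have completed the following task:\n{task}\nWrite the final answer."

def get_final_answer_prompt(configured: str, task: str) -> str:
    # Mirrors the diff: the untouched default is formatted as a template;
    # anything else is passed through verbatim.
    if configured == DEFAULT_PROMPT:
        return DEFAULT_PROMPT.format(task=task)
    return configured

# A custom prompt keeps its literal "{task}" placeholder...
assert "{task}" in get_final_answer_prompt("Summarize. Task was: {task}", "demo")
# ...while the default has the task substituted in.
assert "demo" in get_final_answer_prompt(DEFAULT_PROMPT, "demo")
```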