Allow users to update the final answer prompt of MagenticOne orchestrator (#4476)

* Allow users to update the final answer prompt of MagenticOne orchestrator.
afourney authored Dec 3, 2024
1 parent 79c5aaa commit 934ae03
Showing 2 changed files with 11 additions and 1 deletion.
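
For context, a minimal usage sketch of the new parameter (not part of the commit; the agent, client, and prompt text are illustrative, and the OpenAIChatCompletionClient import path is an assumption based on AutoGen 0.4 packaging):

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.teams import MagenticOneGroupChat
    from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path

    client = OpenAIChatCompletionClient(model="gpt-4o")
    helper = AssistantAgent("helper", model_client=client)

    # final_answer_prompt is optional; when omitted, the packaged
    # ORCHESTRATOR_FINAL_ANSWER_PROMPT default is used.
    team = MagenticOneGroupChat(
        [helper],
        model_client=client,
        final_answer_prompt="Answer the user's original request in three bullet points, citing the transcript.",
    )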
File 1 of 2: the MagenticOneGroupChat team module
@@ -7,6 +7,7 @@
 from ....base import ChatAgent, TerminationCondition
 from .._base_group_chat import BaseGroupChat
 from ._magentic_one_orchestrator import MagenticOneOrchestrator
+from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT
 
 trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -25,6 +26,7 @@ class MagenticOneGroupChat(BaseGroupChat):
             Without a termination condition, the group chat will run based on the orchestrator logic or until the maximum number of turns is reached.
         max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to 20.
         max_stalls (int, optional): The maximum number of stalls allowed before re-planning. Defaults to 3.
+        final_answer_prompt (str, optional): The LLM prompt used to generate the final answer or response from the team's transcript. A default (sensible for GPT-4o class models) is provided.
 
     Raises:
         ValueError: In orchestration logic if progress ledger does not have required keys or if next speaker is not valid.
@@ -64,6 +66,7 @@ def __init__(
         termination_condition: TerminationCondition | None = None,
         max_turns: int | None = 20,
         max_stalls: int = 3,
+        final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
     ):
         super().__init__(
             participants,
@@ -77,6 +80,7 @@
             raise ValueError("At least one participant is required for MagenticOneGroupChat.")
         self._model_client = model_client
         self._max_stalls = max_stalls
+        self._final_answer_prompt = final_answer_prompt
 
     def _create_group_chat_manager_factory(
         self,
@@ -95,5 +99,6 @@ def _create_group_chat_manager_factory(
             max_turns,
             self._model_client,
             self._max_stalls,
+            self._final_answer_prompt,
             termination_condition,
         )
File 2 of 2: _magentic_one_orchestrator.py (the MagenticOneOrchestrator)
@@ -48,6 +48,7 @@ def __init__(
         max_turns: int | None,
         model_client: ChatCompletionClient,
         max_stalls: int,
+        final_answer_prompt: str,
         termination_condition: TerminationCondition | None,
     ):
         super().__init__(
@@ -60,6 +61,7 @@
         )
         self._model_client = model_client
         self._max_stalls = max_stalls
+        self._final_answer_prompt = final_answer_prompt
         self._name = "MagenticOneOrchestrator"
         self._max_json_retries = 10
         self._task = ""
@@ -95,7 +97,10 @@ def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
         return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)
 
     def _get_final_answer_prompt(self, task: str) -> str:
-        return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
+        if self._final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT:
+            return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
+        else:
+            return self._final_answer_prompt
 
     async def _log_message(self, log_message: str) -> None:
         trace_logger.debug(log_message)
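
Note the guard in _get_final_answer_prompt: only the packaged default is treated as a template, so a user-supplied prompt is returned verbatim and any "{task}" placeholder in it is not substituted. A standalone sketch of that dispatch (the default prompt text below is a stand-in, not the real package constant):

    # Stand-in for the package's ORCHESTRATOR_FINAL_ANSWER_PROMPT template.
    ORCHESTRATOR_FINAL_ANSWER_PROMPT = "We have completed the task: {task}. Synthesize a final answer."

    def get_final_answer_prompt(final_answer_prompt: str, task: str) -> str:
        if final_answer_prompt == ORCHESTRATOR_FINAL_ANSWER_PROMPT:
            # Default prompt: substitute the task into the template.
            return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
        # Custom prompt: returned unchanged; no .format() is applied.
        return final_answer_prompt

    assert "fix the bug" in get_final_answer_prompt(ORCHESTRATOR_FINAL_ANSWER_PROMPT, "fix the bug")
    assert get_final_answer_prompt("Keep {task} literal", "fix the bug") == "Keep {task} literal"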
