From a3d2bb7bc57ecea1893e1960ae66b6d1cfcdc8ad Mon Sep 17 00:00:00 2001 From: Min Shi <39176492+Jasmin3q@users.noreply.github.com> Date: Fri, 19 Jan 2024 12:15:52 +0800 Subject: [PATCH] [Executor] Support node_concurrency when running an async flow (#1698) # Description Use asyncio.semaphore to support concurrency control. When node_concurrency is set to 2: sync_passthrough1 and async_passthrough1 can run concurrently: ![1705287434459](https://github.com/microsoft/promptflow/assets/39176492/487f1d03-34fa-4016-bf01-ff31d64d2c3f) When node_concurrency is set to 1: Only one of sync_passthrough1 and async_passthrough1 can at the same time: ![1705288060087](https://github.com/microsoft/promptflow/assets/39176492/9ce31cd9-4af4-4451-a898-06cccf5c2321) --------- Co-authored-by: Min Shi --- .../executor/_async_nodes_scheduler.py | 15 ++++++- .../tests/executor/e2etests/test_async.py | 42 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 src/promptflow/tests/executor/e2etests/test_async.py diff --git a/src/promptflow/promptflow/executor/_async_nodes_scheduler.py b/src/promptflow/promptflow/executor/_async_nodes_scheduler.py index 8c6c781011c..c308937545b 100644 --- a/src/promptflow/promptflow/executor/_async_nodes_scheduler.py +++ b/src/promptflow/promptflow/executor/_async_nodes_scheduler.py @@ -34,7 +34,6 @@ def __init__( node_concurrency: int, ) -> None: self._tools_manager = tools_manager - # TODO: Add concurrency control in execution self._node_concurrency = node_concurrency self._task_start_time = {} self._task_last_log_time = {} @@ -55,7 +54,9 @@ async def execute( "Current thread is not main thread, skip signal handler registration in AsyncNodesScheduler." ) + # Semaphore should be created in the loop, otherwise it will not work. loop = asyncio.get_running_loop() + self._semaphore = asyncio.Semaphore(self._node_concurrency, loop=loop) monitor = threading.Thread( target=monitor_long_running_coroutine, args=(loop, self._task_start_time, self._task_last_log_time, self._dag_manager_completed_event), @@ -129,6 +130,10 @@ def _execute_nodes( self._create_node_task(node, dag_manager, context, executor): node for node in dag_manager.pop_ready_nodes() } + async def run_task_with_semaphore(self, coroutine): + async with self._semaphore: + return await coroutine + def _create_node_task( self, node: Node, @@ -139,12 +144,17 @@ def _create_node_task( f = self._tools_manager.get_tool(node.name) kwargs = dag_manager.get_node_valid_inputs(node, f) if inspect.iscoroutinefunction(f): + # For async task, it will not be executed before calling create_task. task = context.invoke_tool_async(node, f, kwargs) else: + # For sync task, convert it to async task and run it in executor thread. + # Even though the task is put to the thread pool, thread.start will only be triggered after create_task. task = self._sync_function_to_async_task(executor, context, node, f, kwargs) # Set the name of the task to the node name for debugging purpose # It does not need to be unique by design. - return asyncio.create_task(task, name=node.name) + # Wrap the coroutine in a task with asyncio.create_task to schedule it for event loop execution + # The task is created and added to the event loop, but the exact execution depends on loop's scheduling + return asyncio.create_task(self.run_task_with_semaphore(task), name=node.name) @staticmethod async def _sync_function_to_async_task( @@ -154,6 +164,7 @@ async def _sync_function_to_async_task( f, kwargs, ): + # The task will not be executed before calling create_task. return await asyncio.get_running_loop().run_in_executor(executor, context.invoke_tool, node, f, kwargs) diff --git a/src/promptflow/tests/executor/e2etests/test_async.py b/src/promptflow/tests/executor/e2etests/test_async.py new file mode 100644 index 00000000000..1a2db96e8ac --- /dev/null +++ b/src/promptflow/tests/executor/e2etests/test_async.py @@ -0,0 +1,42 @@ +import os +import pytest +from promptflow.executor import FlowExecutor +from ..utils import get_flow_folder, get_yaml_file + + +@pytest.mark.e2etest +class TestAsync: + @pytest.mark.parametrize( + "folder_name, concurrency_levels, expected_concurrency", + [ + ("async_tools", [1, 2, 3], [1, 2, 2]), + ("async_tools_with_sync_tools", [1, 2, 3], [1, 2, 2]), + ], + ) + def test_executor_node_concurrency(self, folder_name, concurrency_levels, expected_concurrency): + os.chdir(get_flow_folder(folder_name)) + executor = FlowExecutor.create(get_yaml_file(folder_name), {}) + + def calculate_max_concurrency(flow_result): + timeline = [] + api_calls = flow_result.run_info.api_calls[0]["children"] + for api_call in api_calls: + timeline.append(("start", api_call["start_time"])) + timeline.append(("end", api_call["end_time"])) + timeline.sort(key=lambda x: x[1]) + current_concurrency = 0 + max_concurrency = 0 + for event, _ in timeline: + if event == "start": + current_concurrency += 1 + max_concurrency = max(max_concurrency, current_concurrency) + elif event == "end": + current_concurrency -= 1 + return max_concurrency + + for i in range(len(concurrency_levels)): + concurrency = concurrency_levels[i] + flow_result = executor.exec_line({"input_str": "Hello"}, node_concurrency=concurrency) + max_concurrency = calculate_max_concurrency(flow_result) + assert max_concurrency == expected_concurrency[i] + assert max_concurrency <= concurrency