Skip to content

Commit a3d2bb7

Browse files
Jasmin3qMin Shi
and
Min Shi
authored
[Executor] Support node_concurrency when running an async flow (#1698)
# Description Use asyncio.semaphore to support concurrency control. When node_concurrency is set to 2: sync_passthrough1 and async_passthrough1 can run concurrently: ![1705287434459](https://github.com/microsoft/promptflow/assets/39176492/487f1d03-34fa-4016-bf01-ff31d64d2c3f) When node_concurrency is set to 1: Only one of sync_passthrough1 and async_passthrough1 can at the same time: ![1705288060087](https://github.com/microsoft/promptflow/assets/39176492/9ce31cd9-4af4-4451-a898-06cccf5c2321) --------- Co-authored-by: Min Shi <[email protected]>
1 parent 5053232 commit a3d2bb7

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

src/promptflow/promptflow/executor/_async_nodes_scheduler.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def __init__(
3434
node_concurrency: int,
3535
) -> None:
3636
self._tools_manager = tools_manager
37-
# TODO: Add concurrency control in execution
3837
self._node_concurrency = node_concurrency
3938
self._task_start_time = {}
4039
self._task_last_log_time = {}
@@ -55,7 +54,9 @@ async def execute(
5554
"Current thread is not main thread, skip signal handler registration in AsyncNodesScheduler."
5655
)
5756

57+
# Semaphore should be created in the loop, otherwise it will not work.
5858
loop = asyncio.get_running_loop()
59+
self._semaphore = asyncio.Semaphore(self._node_concurrency, loop=loop)
5960
monitor = threading.Thread(
6061
target=monitor_long_running_coroutine,
6162
args=(loop, self._task_start_time, self._task_last_log_time, self._dag_manager_completed_event),
@@ -129,6 +130,10 @@ def _execute_nodes(
129130
self._create_node_task(node, dag_manager, context, executor): node for node in dag_manager.pop_ready_nodes()
130131
}
131132

133+
async def run_task_with_semaphore(self, coroutine):
134+
async with self._semaphore:
135+
return await coroutine
136+
132137
def _create_node_task(
133138
self,
134139
node: Node,
@@ -139,12 +144,17 @@ def _create_node_task(
139144
f = self._tools_manager.get_tool(node.name)
140145
kwargs = dag_manager.get_node_valid_inputs(node, f)
141146
if inspect.iscoroutinefunction(f):
147+
# For async task, it will not be executed before calling create_task.
142148
task = context.invoke_tool_async(node, f, kwargs)
143149
else:
150+
# For sync task, convert it to async task and run it in executor thread.
151+
# Even though the task is put to the thread pool, thread.start will only be triggered after create_task.
144152
task = self._sync_function_to_async_task(executor, context, node, f, kwargs)
145153
# Set the name of the task to the node name for debugging purpose
146154
# It does not need to be unique by design.
147-
return asyncio.create_task(task, name=node.name)
155+
# Wrap the coroutine in a task with asyncio.create_task to schedule it for event loop execution
156+
# The task is created and added to the event loop, but the exact execution depends on loop's scheduling
157+
return asyncio.create_task(self.run_task_with_semaphore(task), name=node.name)
148158

149159
@staticmethod
150160
async def _sync_function_to_async_task(
@@ -154,6 +164,7 @@ async def _sync_function_to_async_task(
154164
f,
155165
kwargs,
156166
):
167+
# The task will not be executed before calling create_task.
157168
return await asyncio.get_running_loop().run_in_executor(executor, context.invoke_tool, node, f, kwargs)
158169

159170

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import os
2+
import pytest
3+
from promptflow.executor import FlowExecutor
4+
from ..utils import get_flow_folder, get_yaml_file
5+
6+
7+
@pytest.mark.e2etest
8+
class TestAsync:
9+
@pytest.mark.parametrize(
10+
"folder_name, concurrency_levels, expected_concurrency",
11+
[
12+
("async_tools", [1, 2, 3], [1, 2, 2]),
13+
("async_tools_with_sync_tools", [1, 2, 3], [1, 2, 2]),
14+
],
15+
)
16+
def test_executor_node_concurrency(self, folder_name, concurrency_levels, expected_concurrency):
17+
os.chdir(get_flow_folder(folder_name))
18+
executor = FlowExecutor.create(get_yaml_file(folder_name), {})
19+
20+
def calculate_max_concurrency(flow_result):
21+
timeline = []
22+
api_calls = flow_result.run_info.api_calls[0]["children"]
23+
for api_call in api_calls:
24+
timeline.append(("start", api_call["start_time"]))
25+
timeline.append(("end", api_call["end_time"]))
26+
timeline.sort(key=lambda x: x[1])
27+
current_concurrency = 0
28+
max_concurrency = 0
29+
for event, _ in timeline:
30+
if event == "start":
31+
current_concurrency += 1
32+
max_concurrency = max(max_concurrency, current_concurrency)
33+
elif event == "end":
34+
current_concurrency -= 1
35+
return max_concurrency
36+
37+
for i in range(len(concurrency_levels)):
38+
concurrency = concurrency_levels[i]
39+
flow_result = executor.exec_line({"input_str": "Hello"}, node_concurrency=concurrency)
40+
max_concurrency = calculate_max_concurrency(flow_result)
41+
assert max_concurrency == expected_concurrency[i]
42+
assert max_concurrency <= concurrency

0 commit comments

Comments
 (0)