[Executor] Refine executor logic to support async generator in flow test (#3083)

# Description

Refine exec_line logic to support async generator in flow test.

This pull request primarily focuses on improving the handling of
generators and asynchronous generators in the `promptflow` package. The
changes ensure that the code can handle both synchronous and
asynchronous generators, improve the handling of generator outputs, and
refactor the way nodes are executed in the flow.

Here are the key changes:

Improved handling of generators and asynchronous generators:

*
[`src/promptflow-core/promptflow/_core/run_tracker.py`](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422R6-L9):
The `inspect` module was imported to replace the use of `GeneratorType`
for checking if a value is a generator. The method
`update_and_persist_generator_node_runs` was introduced to replace
`persist_selected_node_runs`, and it now updates the output of the node
run with the output in the trace before persisting it.
[[1]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422R6-L9)
[[2]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422L299-R299)
[[3]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422L433-R450)
*
[`src/promptflow-core/promptflow/_utils/run_tracker_utils.py`](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cR4-R7):
The `inspect` module was imported, and the method
`_deep_copy_and_extract_items_from_generator_proxy` was updated to
handle `AsyncGeneratorProxy` and to convert generators to strings to
avoid deepcopy errors.
[[1]](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cR4-R7)
[[2]](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cL21-R25)
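
The generator check described above can be illustrated with a minimal standalone sketch. It is not the promptflow code: `safe_copy`, `numbers`, and `async_numbers` are hypothetical names, and only `inspect.isgenerator`, `inspect.isasyncgen`, and `copy.deepcopy` are taken from the change itself. The idea is that a raw generator is replaced by its string form before `deepcopy` is reached.

```python
import inspect
from copy import deepcopy


def safe_copy(value):
    """Deep-copy a value, replacing raw generators with their string form.

    Hypothetical helper for illustration; it mirrors the idea of the change,
    not the actual promptflow function.
    """
    if isinstance(value, list):
        return [safe_copy(v) for v in value]
    if isinstance(value, dict):
        return {k: safe_copy(v) for k, v in value.items()}
    if inspect.isgenerator(value) or inspect.isasyncgen(value):
        # Generator objects cannot be deep-copied, so keep a string placeholder.
        return str(value)
    return deepcopy(value)


def numbers():
    yield 1


async def async_numbers():
    yield 1


copied = safe_copy({"sync": numbers(), "async": async_numbers(), "plain": [1, 2]})
print(copied)
```

Both generator kinds end up as placeholder strings, while ordinary values are copied as usual.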

Refactoring of node execution:

*
[`src/promptflow-core/promptflow/executor/flow_executor.py`](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR63):
The `ThreadPoolExecutorWithContext` was imported from
`promptflow.tracing`. In the `exec_line` method, a check was added to
use `exec_line_async` when the tools are async. The `exec_line_async`
method was updated to include a `line_timeout_sec` parameter. The
`_exec_inner_with_trace_async` method was updated to use
`_stringify_generator_output_async` to handle async generator output.
The `_exec_post_process` method was updated to use
`update_and_persist_generator_node_runs` instead of
`persist_selected_node_runs`. The `_should_use_async` method was updated
to check if any tool is async. The `_traverse_nodes_async` method was
updated to use an async scheduler. The methods `_merge_async_generator`,
`_stringify_generator_output_async`, and `_merge_generator` were added
to handle generator outputs. The `_submit_to_scheduler` method was
updated to only use the thread pool mode.
[[1]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR63)
[[2]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR693-R699)
[[3]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR731)
[[4]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR751-R752)
[[5]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL869-R880)
[[6]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL913-R926)
[[7]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR1158-R1212)
[[8]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL1186-L1190)
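
As a rough sketch of how generator outputs can be merged inside an async function: async generators are drained as tasks on the event loop, while sync generators are drained in a worker thread so they do not block the loop. This is a simplified illustration, not the executor code. The function names are hypothetical, a plain `ThreadPoolExecutor` stands in for `ThreadPoolExecutorWithContext`, and `asyncio.gather` is used where the change uses `asyncio.wait`.

```python
import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor


def merge_generator(gen, outputs, key):
    # Drain a sync generator and store the concatenated string.
    outputs[key] = "".join(str(item) for item in gen)


async def merge_async_generator(async_gen, outputs, key):
    # Drain an async generator and store the concatenated string.
    items = [item async for item in async_gen]
    outputs[key] = "".join(str(item) for item in items)


async def stringify_outputs(outputs):
    loop = asyncio.get_running_loop()
    tasks = []
    with ThreadPoolExecutor() as pool:
        # Iterate over a snapshot because the workers write back into `outputs`.
        for key, value in list(outputs.items()):
            if inspect.isasyncgen(value):
                tasks.append(asyncio.create_task(merge_async_generator(value, outputs, key)))
            elif inspect.isgenerator(value):
                # Sync generators run in a thread so the event loop stays free.
                tasks.append(loop.run_in_executor(pool, merge_generator, value, outputs, key))
        if tasks:
            await asyncio.gather(*tasks)
    return outputs


def words():
    yield from ("Hello", ", ", "world")


async def digits():
    for n in range(3):
        yield n


result = asyncio.run(stringify_outputs({"text": words(), "count": digits(), "plain": 42}))
print(result)
```

One difference worth noting: `asyncio.gather` re-raises the first exception from a task, whereas `asyncio.wait` only returns the done and pending sets, so with `asyncio.wait` failures would have to be checked on the returned tasks.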

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Heyi <[email protected]>
thy09 and Heyi authored Apr 30, 2024
1 parent bd82aa6 commit 35ed6c5
Showing 4 changed files with 82 additions and 47 deletions.
18 changes: 10 additions & 8 deletions src/promptflow-core/promptflow/_core/run_tracker.py
@@ -3,10 +3,10 @@
# ---------------------------------------------------------

import asyncio
import inspect
import json
from contextvars import ContextVar
from datetime import datetime, timezone
from types import GeneratorType
from typing import Any, Dict, List, Mapping, Optional, Union

from promptflow._constants import MessageFormatType
@@ -296,7 +296,7 @@ def end_run(
def _ensure_serializable_value(self, val, warning_msg: Optional[str] = None):
if ConnectionType.is_connection_value(val):
return ConnectionType.serialize_conn(val)
if self.allow_generator_types and isinstance(val, GeneratorType):
if inspect.isgenerator(val) or inspect.isasyncgen(val):
return str(val)
try:
json.dumps(val, default=default_json_encoder)
@@ -430,22 +430,24 @@ def get_run(self, run_id):
def persist_node_run(self, run_info: RunInfo):
self._storage.persist_node_run(run_info)

def persist_selected_node_runs(self, run_info: FlowRunInfo, node_names: List[str]):
def update_and_persist_generator_node_runs(self, run_id: str, node_names: List[str]):
"""
Persists the node runs for the specified node names.
Persists the node runs for nodes producing generators.
:param run_info: The flow run information.
:type run_info: FlowRunInfo
:param run_id: The ID of the flow run.
:type run_id: str
:param node_names: The names of the nodes to persist.
:type node_names: List[str]
:returns: None
"""
run_id = run_info.run_id

selected_node_run_info = (
run_info for run_info in self.collect_child_node_runs(run_id) if run_info.node in node_names
)
for node_run_info in selected_node_run_info:
# Update the output of the node run with the output in the trace.
# This is because the output in the trace would include the generated items.
output_in_trace = node_run_info.api_calls[0]["output"]
node_run_info.output = output_in_trace
self.persist_node_run(node_run_info)

def persist_flow_run(self, run_info: FlowRunInfo):
5 changes: 1 addition & 4 deletions src/promptflow-core/promptflow/_utils/async_utils.py
@@ -3,11 +3,8 @@
# ---------------------------------------------------------

import asyncio
import contextvars
import functools
from concurrent.futures import ThreadPoolExecutor

from promptflow._utils.utils import set_context
from promptflow.tracing import ThreadPoolExecutorWithContext


@@ -36,7 +33,7 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs):
event loop, we run _exec_batch in a new thread; otherwise, we run it in the current thread.
"""
if _has_running_loop():
with ThreadPoolExecutor(1, initializer=set_context, initargs=(contextvars.copy_context(),)) as executor:
with ThreadPoolExecutorWithContext() as executor:
return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()
else:
return asyncio.run(async_func(*args, **kwargs))
7 changes: 5 additions & 2 deletions src/promptflow-core/promptflow/_utils/run_tracker_utils.py
@@ -1,9 +1,10 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import inspect
from copy import deepcopy

from promptflow.tracing.contracts.generator_proxy import GeneratorProxy
from promptflow.tracing.contracts.generator_proxy import AsyncGeneratorProxy, GeneratorProxy


def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object:
@@ -18,6 +19,8 @@ def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object:
return [_deep_copy_and_extract_items_from_generator_proxy(v) for v in value]
elif isinstance(value, dict):
return {k: _deep_copy_and_extract_items_from_generator_proxy(v) for k, v in value.items()}
elif isinstance(value, GeneratorProxy):
elif isinstance(value, (GeneratorProxy, AsyncGeneratorProxy)):
return deepcopy(value.items)
elif inspect.isgenerator(value) or inspect.isasyncgen(value):
return str(value) # Convert generator to string to avoid deepcopy error
return deepcopy(value)
99 changes: 66 additions & 33 deletions src/promptflow-core/promptflow/executor/flow_executor.py
@@ -29,6 +29,7 @@
from promptflow._core.run_tracker import RunTracker
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR
from promptflow._core.tools_manager import ToolsManager
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.execution_utils import (
apply_default_value_for_input,
@@ -60,6 +61,7 @@
from promptflow.executor.flow_validator import FlowValidator
from promptflow.storage import AbstractRunStorage
from promptflow.storage._run_storage import DefaultRunStorage
from promptflow.tracing import ThreadPoolExecutorWithContext
from promptflow.tracing._integrations._openai_injector import inject_openai_api
from promptflow.tracing._operation_context import OperationContext
from promptflow.tracing._start_trace import setup_exporter_from_environ
@@ -693,24 +695,33 @@ def exec_line(
:return: The result of executing the line.
:rtype: ~promptflow.executor._result.LineResult
"""
if self._should_use_async():
# Use async exec_line when the tools are async
return async_run_allowing_running_loop(
self.exec_line_async,
inputs,
index,
run_id,
validate_inputs,
node_concurrency,
allow_generator_output,
line_timeout_sec,
)
# TODO: Call exec_line_async in exec_line when async is mature.
self._node_concurrency = node_concurrency
# TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
# exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when
# it is not set.
run_id = run_id or str(uuid.uuid4())
with self._update_operation_context(run_id, index):
line_result = self._exec(
inputs,
run_id=run_id,
line_number=index,
validate_inputs=validate_inputs,
allow_generator_output=allow_generator_output,
)
run_id = run_id or str(uuid.uuid4())
with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index):
line_result = self._exec(
inputs,
run_id=run_id,
line_number=index,
validate_inputs=validate_inputs,
allow_generator_output=allow_generator_output,
)
# Return line result with index
if index is not None and isinstance(line_result.output, dict):
line_result.output[LINE_NUMBER_KEY] = index
@@ -724,6 +735,7 @@ async def exec_line_async(
validate_inputs: bool = True,
node_concurrency=DEFAULT_CONCURRENCY_FLOW,
allow_generator_output: bool = False,
line_timeout_sec: Optional[int] = None,
) -> LineResult:
"""Execute a single line of the flow.
@@ -743,13 +755,12 @@ async def exec_line_async(
:rtype: ~promptflow.executor._result.LineResult
"""
self._node_concurrency = node_concurrency
# TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
# exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when
# it is not set.
operation_context = OperationContext.get_instance()
operation_context.run_mode = operation_context.get("run_mode", None) or RunMode.Test.name
run_id = run_id or str(uuid.uuid4())
with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index):
line_result = await self._exec_async(
inputs,
run_id=run_id,
@@ -869,8 +880,7 @@ async def _exec_inner_with_trace_async(
):
with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
output, nodes_outputs = await self._traverse_nodes_async(inputs, context)
# TODO: Also stringify async generator output
output = self._stringify_generator_output(output) if not stream else output
output = await self._stringify_generator_output_async(output) if not stream else output
self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
return output, extract_aggregation_inputs(self._flow, nodes_outputs)

@@ -914,9 +924,9 @@ def _exec_post_process(
for nodename, output in nodes_outputs.items()
if isinstance(output, GeneratorType) or isinstance(output, AsyncGeneratorType)
]
run_tracker.persist_selected_node_runs(run_info, generator_output_nodes)
# When stream is True, we allow generator output in the flow output
run_tracker.allow_generator_types = stream
run_tracker.update_and_persist_generator_node_runs(run_info.run_id, generator_output_nodes)
run_tracker.end_run(run_info.run_id, result=output)
enrich_span_with_trace_type(span, inputs, output, trace_type=TraceType.FLOW)
span.set_status(StatusCode.OK)
@@ -1148,33 +1158,61 @@ def _extract_outputs(self, nodes_outputs, bypassed_nodes, flow_inputs):
return outputs

def _should_use_async(self):
def is_async(f):
# Here we check the original function since currently asyncgenfunction would be converted to sync func
# TODO: Improve @trace logic to make sure wrapped asyncgen is still an asyncgen
original_func = getattr(f, "__original_function", f)
return inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func)

return (
all(inspect.iscoroutinefunction(f) for f in self._tools_manager._tools.values())
any(is_async(f) for f in self._tools_manager._tools.values())
or os.environ.get("PF_USE_ASYNC", "false").lower() == "true"
)

def _traverse_nodes(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]:
batch_nodes = [node for node in self._flow.nodes if not node.aggregation]
outputs = {}
# TODO: Use a mixed scheduler to support both async and thread pool mode.
nodes_outputs, bypassed_nodes = self._submit_to_scheduler(context, inputs, batch_nodes)
outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs)
return outputs, nodes_outputs

async def _traverse_nodes_async(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]:
batch_nodes = [node for node in self._flow.nodes if not node.aggregation]
outputs = {}
# Always use async scheduler when calling from async function.
flow_logger.info("Start executing nodes in async mode.")
scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency)
nodes_outputs, bypassed_nodes = await scheduler.execute(batch_nodes, inputs, context)
outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs)
return outputs, nodes_outputs

@staticmethod
async def _merge_async_generator(async_gen: AsyncGeneratorType, outputs: dict, key: str):
items = []
async for item in async_gen:
items.append(item)
outputs[key] = "".join(str(item) for item in items)

async def _stringify_generator_output_async(self, outputs: dict):
pool = ThreadPoolExecutorWithContext()
tasks = []
for k, v in outputs.items():
if isinstance(v, AsyncGeneratorType):
tasks.append(asyncio.create_task(self._merge_async_generator(v, outputs, k)))
elif isinstance(v, GeneratorType):
loop = asyncio.get_event_loop()
task = loop.run_in_executor(pool, self._merge_generator, v, outputs, k)
tasks.append(task)
if tasks:
await asyncio.wait(tasks)
return outputs

@staticmethod
def _merge_generator(gen: GeneratorType, outputs: dict, key: str):
outputs[key] = "".join(str(item) for item in gen)

def _stringify_generator_output(self, outputs: dict):
for k, v in outputs.items():
if isinstance(v, GeneratorType):
outputs[k] = "".join(str(chuck) for chuck in v)
self._merge_generator(v, outputs, k)

return outputs

@@ -1187,14 +1225,9 @@ def _submit_to_scheduler(self, context: FlowExecutionContext, inputs, nodes: Lis
),
current_value=self._node_concurrency,
)
if self._should_use_async():
flow_logger.info("Start executing nodes in async mode.")
scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency)
return asyncio.run(scheduler.execute(nodes, inputs, context))
else:
flow_logger.info("Start executing nodes in thread pool mode.")
scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context)
return scheduler.execute(self._line_timeout_sec)
flow_logger.info("Start executing nodes in thread pool mode.")
scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context)
return scheduler.execute(self._line_timeout_sec)

@staticmethod
def apply_inputs_mapping(