From 35ed6c5072e47e05febe75dab92317f64fd1d5da Mon Sep 17 00:00:00 2001
From: Heyi Tang
Date: Tue, 30 Apr 2024 17:47:08 +0800
Subject: [PATCH] [Executor] Refine executor logic to support async generator in flow test (#3083)

# Description

Refine the `exec_line` logic to support async generators in flow test.

This pull request primarily focuses on improving the handling of generators and asynchronous generators in the `promptflow` package. The changes ensure that the code can handle both synchronous and asynchronous generators, improve the handling of generator outputs, and refactor the way nodes are executed in the flow. Here are the key changes:

Improved handling of generators and asynchronous generators:

* [`src/promptflow-core/promptflow/_core/run_tracker.py`](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422R6-L9): The `inspect` module is now used instead of `GeneratorType` to check whether a value is a generator. The method `update_and_persist_generator_node_runs` replaces `persist_selected_node_runs`; it updates the output of each node run with the output recorded in the trace before persisting it. [[1]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422R6-L9) [[2]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422L299-R299) [[3]](diffhunk://#diff-a5027d19a24cb28a68ead16dfe6c54492c78d6e0e7640e80533928808cdb3422L433-R450)
* [`src/promptflow-core/promptflow/_utils/run_tracker_utils.py`](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cR4-R7): The `inspect` module was imported, and `_deep_copy_and_extract_items_from_generator_proxy` now handles `AsyncGeneratorProxy` and converts bare generators to strings to avoid deepcopy errors (see the sketch after this list). [[1]](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cR4-R7) [[2]](diffhunk://#diff-cc2845177424c6393b16b91ff5a7753eaf73aa52c4c52c53c8f83eb68746815cL21-R25)
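To make the serialization rule above concrete, below is a minimal, self-contained sketch, not the library's exact code: `to_serializable` is an illustrative stand-in for the `_ensure_serializable_value` check in the diff below. `inspect.isgenerator` and `inspect.isasyncgen` catch both generator flavors, and the fallback to `str(val)` keeps values JSON-serializable and deepcopy-safe:

```python
import inspect
import json


def to_serializable(val):
    # Both sync and async generators can be neither JSON-serialized nor
    # deep-copied, so fall back to their repr string instead.
    if inspect.isgenerator(val) or inspect.isasyncgen(val):
        return str(val)
    json.dumps(val)  # raises TypeError if val is not JSON-serializable
    return val


def sync_gen():
    yield "chunk"


async def async_gen():
    yield "chunk"


print(to_serializable(sync_gen()))      # e.g. <generator object sync_gen at 0x...>
print(to_serializable(async_gen()))     # e.g. <async_generator object async_gen at 0x...>
print(to_serializable({"answer": 42}))  # returned unchanged
```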
Refactoring of node execution:

* [`src/promptflow-core/promptflow/executor/flow_executor.py`](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR63): `ThreadPoolExecutorWithContext` is imported from `promptflow.tracing`. `exec_line` now dispatches to `exec_line_async` when the tools are async, and `exec_line_async` accepts a `line_timeout_sec` parameter. `_exec_inner_with_trace_async` uses the new `_stringify_generator_output_async` to handle async generator output, and `_exec_post_process` calls `update_and_persist_generator_node_runs` instead of `persist_selected_node_runs`. `_should_use_async` now checks whether any tool is async, and `_traverse_nodes_async` always uses the async scheduler. The helpers `_merge_async_generator`, `_stringify_generator_output_async`, and `_merge_generator` were added to handle generator outputs (see the sketch after this list), and `_submit_to_scheduler` now only uses the thread pool mode. [[1]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR63) [[2]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR693-R699) [[3]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR731) [[4]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR751-R752) [[5]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL869-R880) [[6]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL913-R926) [[7]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fR1158-R1212) [[8]](diffhunk://#diff-bec06607cb28fd791b8ed11bb488979344ca342be5f1c67ba6dd663d5e12240fL1186-L1190)
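Below is a self-contained sketch of the merging technique, simplified from the diff that follows; `stringify_generator_outputs` and the two `merge_*` helpers are illustrative stand-ins for the private `FlowExecutor` methods. Async generators are drained as tasks on the event loop, while sync generators are drained in worker threads so they cannot block it; each output value is then replaced by the joined string:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from types import AsyncGeneratorType, GeneratorType


async def merge_async_generator(async_gen, outputs, key):
    # Drain the async generator on the event loop and join its items.
    outputs[key] = "".join([str(item) async for item in async_gen])


def merge_generator(gen, outputs, key):
    # Drain a sync generator; this runs inside a worker thread.
    outputs[key] = "".join(str(item) for item in gen)


async def stringify_generator_outputs(outputs):
    loop = asyncio.get_running_loop()
    pool = ThreadPoolExecutor()
    tasks = []
    for key, value in outputs.items():
        if isinstance(value, AsyncGeneratorType):
            tasks.append(asyncio.create_task(merge_async_generator(value, outputs, key)))
        elif isinstance(value, GeneratorType):
            tasks.append(loop.run_in_executor(pool, merge_generator, value, outputs, key))
    if tasks:
        await asyncio.wait(tasks)
    return outputs


async def demo():
    async def async_chunks():
        for c in ("he", "llo"):
            yield c

    result = await stringify_generator_outputs(
        {"a": async_chunks(), "b": (c for c in "world"), "c": "plain"}
    )
    print(result)  # {'a': 'hello', 'b': 'world', 'c': 'plain'}


asyncio.run(demo())
```

The actual change uses `ThreadPoolExecutorWithContext` from `promptflow.tracing` instead of a plain `ThreadPoolExecutor`; as the `async_utils.py` hunk below suggests, that class encapsulates the `contextvars` propagation the old code did by hand.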
# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Heyi
---
 .../promptflow/_core/run_tracker.py           | 18 ++--
 .../promptflow/_utils/async_utils.py          |  5 +-
 .../promptflow/_utils/run_tracker_utils.py    |  7 +-
 .../promptflow/executor/flow_executor.py      | 99 ++++++++++++-------
 4 files changed, 82 insertions(+), 47 deletions(-)

diff --git a/src/promptflow-core/promptflow/_core/run_tracker.py b/src/promptflow-core/promptflow/_core/run_tracker.py
index c10a20dd2cb..ac2abd26f7d 100644
--- a/src/promptflow-core/promptflow/_core/run_tracker.py
+++ b/src/promptflow-core/promptflow/_core/run_tracker.py
@@ -3,10 +3,10 @@
 # ---------------------------------------------------------
 
 import asyncio
+import inspect
 import json
 from contextvars import ContextVar
 from datetime import datetime, timezone
-from types import GeneratorType
 from typing import Any, Dict, List, Mapping, Optional, Union
 
 from promptflow._constants import MessageFormatType
@@ -296,7 +296,7 @@ def end_run(
     def _ensure_serializable_value(self, val, warning_msg: Optional[str] = None):
         if ConnectionType.is_connection_value(val):
             return ConnectionType.serialize_conn(val)
-        if self.allow_generator_types and isinstance(val, GeneratorType):
+        if inspect.isgenerator(val) or inspect.isasyncgen(val):
             return str(val)
         try:
             json.dumps(val, default=default_json_encoder)
@@ -430,22 +430,24 @@ def get_run(self, run_id):
     def persist_node_run(self, run_info: RunInfo):
         self._storage.persist_node_run(run_info)
 
-    def persist_selected_node_runs(self, run_info: FlowRunInfo, node_names: List[str]):
+    def update_and_persist_generator_node_runs(self, run_id: str, node_names: List[str]):
         """
-        Persists the node runs for the specified node names.
+        Persists the node runs for nodes producing generators.
 
-        :param run_info: The flow run information.
-        :type run_info: FlowRunInfo
+        :param run_id: The ID of the flow run.
+        :type run_id: str
         :param node_names: The names of the nodes to persist.
         :type node_names: List[str]
        :returns: None
         """
-        run_id = run_info.run_id
-
         selected_node_run_info = (
             run_info for run_info in self.collect_child_node_runs(run_id) if run_info.node in node_names
         )
         for node_run_info in selected_node_run_info:
+            # Update the output of the node run with the output in the trace.
+            # This is because the output in the trace would include the generated items.
+            output_in_trace = node_run_info.api_calls[0]["output"]
+            node_run_info.output = output_in_trace
             self.persist_node_run(node_run_info)
 
     def persist_flow_run(self, run_info: FlowRunInfo):
diff --git a/src/promptflow-core/promptflow/_utils/async_utils.py b/src/promptflow-core/promptflow/_utils/async_utils.py
index dd019233bf8..41454e0d0f2 100644
--- a/src/promptflow-core/promptflow/_utils/async_utils.py
+++ b/src/promptflow-core/promptflow/_utils/async_utils.py
@@ -3,11 +3,8 @@
 # ---------------------------------------------------------
 
 import asyncio
-import contextvars
 import functools
-from concurrent.futures import ThreadPoolExecutor
 
-from promptflow._utils.utils import set_context
 from promptflow.tracing import ThreadPoolExecutorWithContext
 
 
@@ -36,7 +33,7 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs):
     event loop, we run _exec_batch in a new thread; otherwise, we run it in the current thread.
     """
     if _has_running_loop():
-        with ThreadPoolExecutor(1, initializer=set_context, initargs=(contextvars.copy_context(),)) as executor:
+        with ThreadPoolExecutorWithContext() as executor:
             return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result()
     else:
         return asyncio.run(async_func(*args, **kwargs))
diff --git a/src/promptflow-core/promptflow/_utils/run_tracker_utils.py b/src/promptflow-core/promptflow/_utils/run_tracker_utils.py
index 1f44535db36..fca2be0da1d 100644
--- a/src/promptflow-core/promptflow/_utils/run_tracker_utils.py
+++ b/src/promptflow-core/promptflow/_utils/run_tracker_utils.py
@@ -1,9 +1,10 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
+import inspect
 from copy import deepcopy
 
-from promptflow.tracing.contracts.generator_proxy import GeneratorProxy
+from promptflow.tracing.contracts.generator_proxy import AsyncGeneratorProxy, GeneratorProxy
 
 
 def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object:
@@ -18,6 +19,8 @@ def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object:
         return [_deep_copy_and_extract_items_from_generator_proxy(v) for v in value]
     elif isinstance(value, dict):
         return {k: _deep_copy_and_extract_items_from_generator_proxy(v) for k, v in value.items()}
-    elif isinstance(value, GeneratorProxy):
+    elif isinstance(value, (GeneratorProxy, AsyncGeneratorProxy)):
         return deepcopy(value.items)
+    elif inspect.isgenerator(value) or inspect.isasyncgen(value):
+        return str(value)  # Convert generator to string to avoid deepcopy error
     return deepcopy(value)
diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py
index a5a1374b081..90cb95bb0e1 100644
--- a/src/promptflow-core/promptflow/executor/flow_executor.py
+++ b/src/promptflow-core/promptflow/executor/flow_executor.py
@@ -29,6 +29,7 @@
 from promptflow._core.run_tracker import RunTracker
 from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR
 from promptflow._core.tools_manager import ToolsManager
+from promptflow._utils.async_utils import async_run_allowing_running_loop
 from promptflow._utils.context_utils import _change_working_dir
 from promptflow._utils.execution_utils import (
     apply_default_value_for_input,
@@ -60,6 +61,7 @@
 from promptflow.executor.flow_validator import FlowValidator
 from promptflow.storage import AbstractRunStorage
 from promptflow.storage._run_storage import DefaultRunStorage
+from promptflow.tracing import ThreadPoolExecutorWithContext
 from promptflow.tracing._integrations._openai_injector import inject_openai_api
 from promptflow.tracing._operation_context import OperationContext
 from promptflow.tracing._start_trace import setup_exporter_from_environ
@@ -693,24 +695,33 @@ def exec_line(
         :return: The result of executing the line.
         :rtype: ~promptflow.executor._result.LineResult
         """
+        if self._should_use_async():
+            # Use async exec_line when the tools are async
+            return async_run_allowing_running_loop(
+                self.exec_line_async,
+                inputs,
+                index,
+                run_id,
+                validate_inputs,
+                node_concurrency,
+                allow_generator_output,
+                line_timeout_sec,
+            )
         # TODO: Call exec_line_async in exec_line when async is mature.
         self._node_concurrency = node_concurrency
         # TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
         self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
         inputs = apply_default_value_for_input(self._flow.inputs, inputs)
         # For flow run, validate inputs as default
-        with self._run_tracker.node_log_manager:
-            # exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when
-            # it is not set.
-            run_id = run_id or str(uuid.uuid4())
-            with self._update_operation_context(run_id, index):
-                line_result = self._exec(
-                    inputs,
-                    run_id=run_id,
-                    line_number=index,
-                    validate_inputs=validate_inputs,
-                    allow_generator_output=allow_generator_output,
-                )
+        run_id = run_id or str(uuid.uuid4())
+        with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index):
+            line_result = self._exec(
+                inputs,
+                run_id=run_id,
+                line_number=index,
+                validate_inputs=validate_inputs,
+                allow_generator_output=allow_generator_output,
+            )
         # Return line result with index
         if index is not None and isinstance(line_result.output, dict):
             line_result.output[LINE_NUMBER_KEY] = index
@@ -724,6 +735,7 @@ async def exec_line_async(
         validate_inputs: bool = True,
         node_concurrency=DEFAULT_CONCURRENCY_FLOW,
         allow_generator_output: bool = False,
+        line_timeout_sec: Optional[int] = None,
     ) -> LineResult:
         """Execute a single line of the flow.
 
@@ -743,13 +755,12 @@ async def exec_line_async(
         :rtype: ~promptflow.executor._result.LineResult
         """
         self._node_concurrency = node_concurrency
+        # TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec
+        self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec
         inputs = apply_default_value_for_input(self._flow.inputs, inputs)
         # For flow run, validate inputs as default
-        with self._run_tracker.node_log_manager:
-            # exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when
-            # it is not set.
-            operation_context = OperationContext.get_instance()
-            operation_context.run_mode = operation_context.get("run_mode", None) or RunMode.Test.name
+        run_id = run_id or str(uuid.uuid4())
+        with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index):
             line_result = await self._exec_async(
                 inputs,
                 run_id=run_id,
@@ -869,8 +880,7 @@ async def _exec_inner_with_trace_async(
     ):
         with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
             output, nodes_outputs = await self._traverse_nodes_async(inputs, context)
-            # TODO: Also stringify async generator output
-            output = self._stringify_generator_output(output) if not stream else output
+            output = await self._stringify_generator_output_async(output) if not stream else output
             self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
         return output, extract_aggregation_inputs(self._flow, nodes_outputs)
 
@@ -914,9 +924,9 @@ def _exec_post_process(
             for nodename, output in nodes_outputs.items()
             if isinstance(output, GeneratorType) or isinstance(output, AsyncGeneratorType)
         ]
-        run_tracker.persist_selected_node_runs(run_info, generator_output_nodes)
         # When stream is True, we allow generator output in the flow output
         run_tracker.allow_generator_types = stream
+        run_tracker.update_and_persist_generator_node_runs(run_info.run_id, generator_output_nodes)
         run_tracker.end_run(run_info.run_id, result=output)
         enrich_span_with_trace_type(span, inputs, output, trace_type=TraceType.FLOW)
         span.set_status(StatusCode.OK)
@@ -1148,33 +1158,61 @@ def _extract_outputs(self, nodes_outputs, bypassed_nodes, flow_inputs):
         return outputs
 
     def _should_use_async(self):
+        def is_async(f):
+            # Here we check the original function since currently asyncgenfunction would be converted to sync func
+            # TODO: Improve @trace logic to make sure wrapped asyncgen is still an asyncgen
+            original_func = getattr(f, "__original_function", f)
+            return inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func)
+
         return (
-            all(inspect.iscoroutinefunction(f) for f in self._tools_manager._tools.values())
+            any(is_async(f) for f in self._tools_manager._tools.values())
             or os.environ.get("PF_USE_ASYNC", "false").lower() == "true"
         )
 
     def _traverse_nodes(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]:
         batch_nodes = [node for node in self._flow.nodes if not node.aggregation]
         outputs = {}
-        # TODO: Use a mixed scheduler to support both async and thread pool mode.
         nodes_outputs, bypassed_nodes = self._submit_to_scheduler(context, inputs, batch_nodes)
         outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs)
         return outputs, nodes_outputs
 
     async def _traverse_nodes_async(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]:
         batch_nodes = [node for node in self._flow.nodes if not node.aggregation]
-        outputs = {}
-        # Always use async scheduler when calling from async function.
         flow_logger.info("Start executing nodes in async mode.")
         scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency)
         nodes_outputs, bypassed_nodes = await scheduler.execute(batch_nodes, inputs, context)
         outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs)
         return outputs, nodes_outputs
 
+    @staticmethod
+    async def _merge_async_generator(async_gen: AsyncGeneratorType, outputs: dict, key: str):
+        items = []
+        async for item in async_gen:
+            items.append(item)
+        outputs[key] = "".join(str(item) for item in items)
+
+    async def _stringify_generator_output_async(self, outputs: dict):
+        pool = ThreadPoolExecutorWithContext()
+        tasks = []
+        for k, v in outputs.items():
+            if isinstance(v, AsyncGeneratorType):
+                tasks.append(asyncio.create_task(self._merge_async_generator(v, outputs, k)))
+            elif isinstance(v, GeneratorType):
+                loop = asyncio.get_event_loop()
+                task = loop.run_in_executor(pool, self._merge_generator, v, outputs, k)
+                tasks.append(task)
+        if tasks:
+            await asyncio.wait(tasks)
+        return outputs
+
+    @staticmethod
+    def _merge_generator(gen: GeneratorType, outputs: dict, key: str):
+        outputs[key] = "".join(str(item) for item in gen)
+
     def _stringify_generator_output(self, outputs: dict):
         for k, v in outputs.items():
             if isinstance(v, GeneratorType):
-                outputs[k] = "".join(str(chuck) for chuck in v)
+                self._merge_generator(v, outputs, k)
         return outputs
 
@@ -1187,14 +1225,9 @@ def _submit_to_scheduler(self, context: FlowExecutionContext, inputs, nodes: Lis
             ),
             current_value=self._node_concurrency,
         )
-        if self._should_use_async():
-            flow_logger.info("Start executing nodes in async mode.")
-            scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency)
-            return asyncio.run(scheduler.execute(nodes, inputs, context))
-        else:
-            flow_logger.info("Start executing nodes in thread pool mode.")
-            scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context)
-            return scheduler.execute(self._line_timeout_sec)
+        flow_logger.info("Start executing nodes in thread pool mode.")
+        scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context)
+        return scheduler.execute(self._line_timeout_sec)
 
     @staticmethod
     def apply_inputs_mapping(