diff --git a/src/promptflow-core/promptflow/_core/run_tracker.py b/src/promptflow-core/promptflow/_core/run_tracker.py index c10a20dd2cb..ac2abd26f7d 100644 --- a/src/promptflow-core/promptflow/_core/run_tracker.py +++ b/src/promptflow-core/promptflow/_core/run_tracker.py @@ -3,10 +3,10 @@ # --------------------------------------------------------- import asyncio +import inspect import json from contextvars import ContextVar from datetime import datetime, timezone -from types import GeneratorType from typing import Any, Dict, List, Mapping, Optional, Union from promptflow._constants import MessageFormatType @@ -296,7 +296,7 @@ def end_run( def _ensure_serializable_value(self, val, warning_msg: Optional[str] = None): if ConnectionType.is_connection_value(val): return ConnectionType.serialize_conn(val) - if self.allow_generator_types and isinstance(val, GeneratorType): + if inspect.isgenerator(val) or inspect.isasyncgen(val): return str(val) try: json.dumps(val, default=default_json_encoder) @@ -430,22 +430,24 @@ def get_run(self, run_id): def persist_node_run(self, run_info: RunInfo): self._storage.persist_node_run(run_info) - def persist_selected_node_runs(self, run_info: FlowRunInfo, node_names: List[str]): + def update_and_persist_generator_node_runs(self, run_id: str, node_names: List[str]): """ - Persists the node runs for the specified node names. + Persists the node runs for nodes producing generators. - :param run_info: The flow run information. - :type run_info: FlowRunInfo + :param run_id: The ID of the flow run. + :type run_id: str :param node_names: The names of the nodes to persist. :type node_names: List[str] :returns: None """ - run_id = run_info.run_id - selected_node_run_info = ( run_info for run_info in self.collect_child_node_runs(run_id) if run_info.node in node_names ) for node_run_info in selected_node_run_info: + # Update the output of the node run with the output in the trace. + # This is because the output in the trace would includes the generated items. + output_in_trace = node_run_info.api_calls[0]["output"] + node_run_info.output = output_in_trace self.persist_node_run(node_run_info) def persist_flow_run(self, run_info: FlowRunInfo): diff --git a/src/promptflow-core/promptflow/_utils/async_utils.py b/src/promptflow-core/promptflow/_utils/async_utils.py index dd019233bf8..41454e0d0f2 100644 --- a/src/promptflow-core/promptflow/_utils/async_utils.py +++ b/src/promptflow-core/promptflow/_utils/async_utils.py @@ -3,11 +3,8 @@ # --------------------------------------------------------- import asyncio -import contextvars import functools -from concurrent.futures import ThreadPoolExecutor -from promptflow._utils.utils import set_context from promptflow.tracing import ThreadPoolExecutorWithContext @@ -36,7 +33,7 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs): event loop, we run _exec_batch in a new thread; otherwise, we run it in the current thread. """ if _has_running_loop(): - with ThreadPoolExecutor(1, initializer=set_context, initargs=(contextvars.copy_context(),)) as executor: + with ThreadPoolExecutorWithContext() as executor: return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result() else: return asyncio.run(async_func(*args, **kwargs)) diff --git a/src/promptflow-core/promptflow/_utils/run_tracker_utils.py b/src/promptflow-core/promptflow/_utils/run_tracker_utils.py index 1f44535db36..fca2be0da1d 100644 --- a/src/promptflow-core/promptflow/_utils/run_tracker_utils.py +++ b/src/promptflow-core/promptflow/_utils/run_tracker_utils.py @@ -1,9 +1,10 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import inspect from copy import deepcopy -from promptflow.tracing.contracts.generator_proxy import GeneratorProxy +from promptflow.tracing.contracts.generator_proxy import AsyncGeneratorProxy, GeneratorProxy def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object: @@ -18,6 +19,8 @@ def _deep_copy_and_extract_items_from_generator_proxy(value: object) -> object: return [_deep_copy_and_extract_items_from_generator_proxy(v) for v in value] elif isinstance(value, dict): return {k: _deep_copy_and_extract_items_from_generator_proxy(v) for k, v in value.items()} - elif isinstance(value, GeneratorProxy): + elif isinstance(value, (GeneratorProxy, AsyncGeneratorProxy)): return deepcopy(value.items) + elif inspect.isgenerator(value) or inspect.isasyncgen(value): + return str(value) # Convert generator to string to avoid deepcopy error return deepcopy(value) diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index a5a1374b081..90cb95bb0e1 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -29,6 +29,7 @@ from promptflow._core.run_tracker import RunTracker from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR from promptflow._core.tools_manager import ToolsManager +from promptflow._utils.async_utils import async_run_allowing_running_loop from promptflow._utils.context_utils import _change_working_dir from promptflow._utils.execution_utils import ( apply_default_value_for_input, @@ -60,6 +61,7 @@ from promptflow.executor.flow_validator import FlowValidator from promptflow.storage import AbstractRunStorage from promptflow.storage._run_storage import DefaultRunStorage +from promptflow.tracing import ThreadPoolExecutorWithContext from promptflow.tracing._integrations._openai_injector import inject_openai_api from promptflow.tracing._operation_context import OperationContext from promptflow.tracing._start_trace import setup_exporter_from_environ @@ -693,24 +695,33 @@ def exec_line( :return: The result of executing the line. :rtype: ~promptflow.executor._result.LineResult """ + if self._should_use_async(): + # Use async exec_line when the tools are async + return async_run_allowing_running_loop( + self.exec_line_async, + inputs, + index, + run_id, + validate_inputs, + node_concurrency, + allow_generator_output, + line_timeout_sec, + ) # TODO: Call exec_line_async in exec_line when async is mature. self._node_concurrency = node_concurrency # TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec inputs = apply_default_value_for_input(self._flow.inputs, inputs) # For flow run, validate inputs as default - with self._run_tracker.node_log_manager: - # exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when - # it is not set. - run_id = run_id or str(uuid.uuid4()) - with self._update_operation_context(run_id, index): - line_result = self._exec( - inputs, - run_id=run_id, - line_number=index, - validate_inputs=validate_inputs, - allow_generator_output=allow_generator_output, - ) + run_id = run_id or str(uuid.uuid4()) + with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index): + line_result = self._exec( + inputs, + run_id=run_id, + line_number=index, + validate_inputs=validate_inputs, + allow_generator_output=allow_generator_output, + ) # Return line result with index if index is not None and isinstance(line_result.output, dict): line_result.output[LINE_NUMBER_KEY] = index @@ -724,6 +735,7 @@ async def exec_line_async( validate_inputs: bool = True, node_concurrency=DEFAULT_CONCURRENCY_FLOW, allow_generator_output: bool = False, + line_timeout_sec: Optional[int] = None, ) -> LineResult: """Execute a single line of the flow. @@ -743,13 +755,12 @@ async def exec_line_async( :rtype: ~promptflow.executor._result.LineResult """ self._node_concurrency = node_concurrency + # TODO: Pass line_timeout_sec to flow node scheduler instead of updating self._line_timeout_sec + self._line_timeout_sec = line_timeout_sec or self._line_timeout_sec inputs = apply_default_value_for_input(self._flow.inputs, inputs) # For flow run, validate inputs as default - with self._run_tracker.node_log_manager: - # exec_line interface may be called when executing a batch run, so we only set run_mode as flow run when - # it is not set. - operation_context = OperationContext.get_instance() - operation_context.run_mode = operation_context.get("run_mode", None) or RunMode.Test.name + run_id = run_id or str(uuid.uuid4()) + with self._run_tracker.node_log_manager, self._update_operation_context(run_id, index): line_result = await self._exec_async( inputs, run_id=run_id, @@ -869,8 +880,7 @@ async def _exec_inner_with_trace_async( ): with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span): output, nodes_outputs = await self._traverse_nodes_async(inputs, context) - # TODO: Also stringify async generator output - output = self._stringify_generator_output(output) if not stream else output + output = await self._stringify_generator_output_async(output) if not stream else output self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream) return output, extract_aggregation_inputs(self._flow, nodes_outputs) @@ -914,9 +924,9 @@ def _exec_post_process( for nodename, output in nodes_outputs.items() if isinstance(output, GeneratorType) or isinstance(output, AsyncGeneratorType) ] - run_tracker.persist_selected_node_runs(run_info, generator_output_nodes) # When stream is True, we allow generator output in the flow output run_tracker.allow_generator_types = stream + run_tracker.update_and_persist_generator_node_runs(run_info.run_id, generator_output_nodes) run_tracker.end_run(run_info.run_id, result=output) enrich_span_with_trace_type(span, inputs, output, trace_type=TraceType.FLOW) span.set_status(StatusCode.OK) @@ -1148,33 +1158,61 @@ def _extract_outputs(self, nodes_outputs, bypassed_nodes, flow_inputs): return outputs def _should_use_async(self): + def is_async(f): + # Here we check the original function since currently asyncgenfunction would be converted to sync func + # TODO: Improve @trace logic to make sure wrapped asyncgen is still an asyncgen + original_func = getattr(f, "__original_function", f) + return inspect.iscoroutinefunction(original_func) or inspect.isasyncgenfunction(original_func) + return ( - all(inspect.iscoroutinefunction(f) for f in self._tools_manager._tools.values()) + any(is_async(f) for f in self._tools_manager._tools.values()) or os.environ.get("PF_USE_ASYNC", "false").lower() == "true" ) def _traverse_nodes(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]: batch_nodes = [node for node in self._flow.nodes if not node.aggregation] outputs = {} - # TODO: Use a mixed scheduler to support both async and thread pool mode. nodes_outputs, bypassed_nodes = self._submit_to_scheduler(context, inputs, batch_nodes) outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs) return outputs, nodes_outputs async def _traverse_nodes_async(self, inputs, context: FlowExecutionContext) -> Tuple[dict, dict]: batch_nodes = [node for node in self._flow.nodes if not node.aggregation] - outputs = {} - # Always use async scheduler when calling from async function. flow_logger.info("Start executing nodes in async mode.") scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency) nodes_outputs, bypassed_nodes = await scheduler.execute(batch_nodes, inputs, context) outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs) return outputs, nodes_outputs + @staticmethod + async def _merge_async_generator(async_gen: AsyncGeneratorType, outputs: dict, key: str): + items = [] + async for item in async_gen: + items.append(item) + outputs[key] = "".join(str(item) for item in items) + + async def _stringify_generator_output_async(self, outputs: dict): + pool = ThreadPoolExecutorWithContext() + tasks = [] + for k, v in outputs.items(): + if isinstance(v, AsyncGeneratorType): + tasks.append(asyncio.create_task(self._merge_async_generator(v, outputs, k))) + elif isinstance(v, GeneratorType): + loop = asyncio.get_event_loop() + task = loop.run_in_executor(pool, self._merge_generator, v, outputs, k) + tasks.append(task) + if tasks: + await asyncio.wait(tasks) + return outputs + + @staticmethod + def _merge_generator(gen: GeneratorType, outputs: dict, key: str): + outputs[key] = "".join(str(item) for item in gen) + def _stringify_generator_output(self, outputs: dict): for k, v in outputs.items(): if isinstance(v, GeneratorType): - outputs[k] = "".join(str(chuck) for chuck in v) + self._merge_generator(v, outputs, k) return outputs @@ -1187,14 +1225,9 @@ def _submit_to_scheduler(self, context: FlowExecutionContext, inputs, nodes: Lis ), current_value=self._node_concurrency, ) - if self._should_use_async(): - flow_logger.info("Start executing nodes in async mode.") - scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency) - return asyncio.run(scheduler.execute(nodes, inputs, context)) - else: - flow_logger.info("Start executing nodes in thread pool mode.") - scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context) - return scheduler.execute(self._line_timeout_sec) + flow_logger.info("Start executing nodes in thread pool mode.") + scheduler = FlowNodesScheduler(self._tools_manager, inputs, nodes, self._node_concurrency, context) + return scheduler.execute(self._line_timeout_sec) @staticmethod def apply_inputs_mapping(