diff --git a/src/promptflow/promptflow/executor/_errors.py b/src/promptflow/promptflow/executor/_errors.py index ac397086ab6..e70f19a6dbe 100644 --- a/src/promptflow/promptflow/executor/_errors.py +++ b/src/promptflow/promptflow/executor/_errors.py @@ -11,12 +11,6 @@ class InvalidCustomLLMTool(ValidationException): pass -class FlowExecutionError(SystemErrorException): - """Base System Exceptions for flow execution""" - - pass - - class ValueTypeUnresolved(ValidationException): pass @@ -101,11 +95,15 @@ class InputNotFound(InvalidFlowRequest): pass -class InputNotFoundFromAncestorNodeOutput(FlowExecutionError): +class InvalidAggregationInput(SystemErrorException): + pass + + +class InputNotFoundFromAncestorNodeOutput(SystemErrorException): pass -class NoNodeExecutedError(FlowExecutionError): +class NoNodeExecutedError(SystemErrorException): pass diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py index 99ab07e0c1f..6d2d992e389 100644 --- a/src/promptflow/promptflow/executor/flow_executor.py +++ b/src/promptflow/promptflow/executor/flow_executor.py @@ -25,12 +25,13 @@ from promptflow._utils.context_utils import _change_working_dir from promptflow._utils.logger_utils import logger from promptflow._utils.utils import transpose -from promptflow.contracts.flow import Flow, InputAssignment, InputValueType, Node +from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType, Node from promptflow.contracts.run_info import FlowRunInfo, Status from promptflow.contracts.run_mode import RunMode from promptflow.exceptions import ErrorTarget, PromptflowException, SystemErrorException, ValidationException from promptflow.executor import _input_assignment_parser from promptflow.executor._errors import ( + InvalidAggregationInput, NodeConcurrencyNotFound, NodeOutputNotFound, OutputReferenceBypassed, @@ -195,6 +196,7 @@ def load_and_exec_node( if not node.source or not node.type: 
raise ValueError(f"Node {node_name} is not a valid node in flow {flow_file}.") + flow_inputs = FlowExecutor._apply_default_value_for_input(flow.inputs, flow_inputs) converted_flow_inputs_for_node = FlowValidator.convert_flow_inputs_for_node(flow, node, flow_inputs) package_tool_keys = [node.source.tool] if node.source and node.source.tool else [] tool_resolver = ToolResolver(working_dir, connections, package_tool_keys) @@ -400,8 +402,63 @@ def exec_aggregation( node_concurrency=DEFAULT_CONCURRENCY_FLOW, ) -> AggregationResult: self._node_concurrency = node_concurrency + aggregated_flow_inputs = dict(inputs or {}) + aggregation_inputs = dict(aggregation_inputs or {}) + self._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs) + aggregated_flow_inputs = self._apply_default_value_for_aggregation_input( + self._flow.inputs, aggregated_flow_inputs, aggregation_inputs + ) + with self._run_tracker.node_log_manager: - return self._exec_aggregation(inputs, aggregation_inputs, run_id) + return self._exec_aggregation(aggregated_flow_inputs, aggregation_inputs, run_id) + + @staticmethod + def _validate_aggregation_inputs(aggregated_flow_inputs: Mapping[str, Any], aggregation_inputs: Mapping[str, Any]): + """Validate the aggregation inputs according to the flow inputs.""" + for key, value in aggregated_flow_inputs.items(): + if key in aggregation_inputs: + raise InvalidAggregationInput( + message_format="Input '{input_key}' appear in both flow aggregation input and aggregation input.", + input_key=key, + ) + if not isinstance(value, list): + raise InvalidAggregationInput( + message_format="Flow aggregation input {input_key} should be one list.", input_key=key + ) + + for key, value in aggregation_inputs.items(): + if not isinstance(value, list): + raise InvalidAggregationInput( + message_format="Aggregation input {input_key} should be one list.", input_key=key + ) + + inputs_len = {key: len(value) for key, value in aggregated_flow_inputs.items()} + 
inputs_len.update({key: len(value) for key, value in aggregation_inputs.items()}) + if len(set(inputs_len.values())) > 1: + raise InvalidAggregationInput( + message_format="Whole aggregation inputs should have the same length. " + "Current key length mapping are: {key_len}", + key_len=inputs_len, + ) + + @staticmethod + def _apply_default_value_for_aggregation_input( + inputs: Dict[str, FlowInputDefinition], + aggregated_flow_inputs: Mapping[str, Any], + aggregation_inputs: Mapping[str, Any], + ): + aggregation_lines = 1 + if aggregated_flow_inputs.values(): + one_input_value = list(aggregated_flow_inputs.values())[0] + aggregation_lines = len(one_input_value) + # If aggregated_flow_inputs is empty, we should use aggregation_inputs to get the length. + elif aggregation_inputs.values(): + one_input_value = list(aggregation_inputs.values())[0] + aggregation_lines = len(one_input_value) + for key, value in inputs.items(): + if key not in aggregated_flow_inputs and (value and value.default): + aggregated_flow_inputs[key] = [value.default] * aggregation_lines + return aggregated_flow_inputs def _exec_aggregation( self, @@ -444,6 +501,7 @@ def _log_metric(key, value): def exec(self, inputs: dict, node_concurency=DEFAULT_CONCURRENCY_FLOW) -> dict: self._node_concurrency = node_concurency + inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs) result = self._exec(inputs) # TODO: remove this line once serving directly calling self.exec_line self._add_line_results([result]) @@ -482,6 +540,7 @@ def exec_line( allow_generator_output: bool = False, ) -> LineResult: self._node_concurrency = node_concurrency + inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs) # For flow run, validate inputs as default with self._run_tracker.node_log_manager: # exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when @@ -539,6 +598,11 @@ def exec_bulk( BulkResults including flow results and metrics """ 
self._node_concurrency = node_concurrency + # Apply default value in early stage, so we can use it both in line execution and aggregation nodes execution. + inputs = [ + FlowExecutor._apply_default_value_for_input(self._flow.inputs, each_line_input) + for each_line_input in inputs + ] run_id = run_id or str(uuid.uuid4()) with self._run_tracker.node_log_manager: OperationContext.get_instance().run_mode = RunMode.Batch.name @@ -558,7 +622,15 @@ def exec_bulk( aggr_results=aggr_results, ) - def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping): + @staticmethod + def _apply_default_value_for_input(inputs: Dict[str, FlowInputDefinition], line_inputs: Mapping) -> Dict[str, Any]: + updated_inputs = dict(line_inputs or {}) + for key, value in inputs.items(): + if key not in updated_inputs and (value and value.default): + updated_inputs[key] = value.default + return updated_inputs + + def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping) -> List[Dict[str, Any]]: inputs_mapping = self._complete_inputs_mapping_by_default_value(inputs_mapping) resolved_inputs = self._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping) return resolved_inputs @@ -566,6 +638,11 @@ def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping): def _complete_inputs_mapping_by_default_value(self, inputs_mapping): inputs_mapping = inputs_mapping or {} result_mapping = self._default_inputs_mapping + # For input has default value, we don't try to read data from default mapping. + # Default value is in higher priority than default mapping. 
+ for key, value in self._flow.inputs.items(): + if value and value.default: + result_mapping.pop(key, None) result_mapping.update(inputs_mapping) return result_mapping diff --git a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py index 8a79298117b..5c2874bf73a 100644 --- a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py +++ b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py @@ -342,3 +342,33 @@ def test_executor_creation_with_default_variants(self, flow_folder, dev_connecti executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections) flow_result = executor.exec_line(self.get_line_inputs()) assert flow_result.run_info.status == Status.Completed + + def test_executor_creation_with_default_input(self): + # Assert for single node run. + default_input_value = "input value from default" + yaml_file = get_yaml_file("default_input") + executor = FlowExecutor.create(yaml_file, {}) + node_result = executor.load_and_exec_node(yaml_file, "test_print_input") + assert node_result.status == Status.Completed + assert node_result.output == default_input_value + + # Assert for flow run. + flow_result = executor.exec_line({}) + assert flow_result.run_info.status == Status.Completed + assert flow_result.output["output"] == default_input_value + aggr_results = executor.exec_aggregation({}, aggregation_inputs={}) + flow_aggregate_node = aggr_results.node_run_infos["aggregate_node"] + assert flow_aggregate_node.status == Status.Completed + assert flow_aggregate_node.output == [default_input_value] + + # Assert for bulk run. 
+ bulk_result = executor.exec_bulk([{}]) + assert bulk_result.line_results[0].run_info.status == Status.Completed + assert bulk_result.line_results[0].output["output"] == default_input_value + bulk_aggregate_node = bulk_result.aggr_results.node_run_infos["aggregate_node"] + assert bulk_aggregate_node.status == Status.Completed + assert bulk_aggregate_node.output == [default_input_value] + + # Assert for exec + exec_result = executor.exec({}) + assert exec_result["output"] == default_input_value diff --git a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py index e22644707b8..90b3581c399 100644 --- a/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py +++ b/src/promptflow/tests/executor/unittests/executor/test_flow_executor.py @@ -4,7 +4,10 @@ import pytest from promptflow import tool -from promptflow.contracts.flow import Flow +from promptflow.contracts.flow import Flow, FlowInputDefinition +from promptflow.contracts.tool import ValueType +from promptflow.executor._errors import InvalidAggregationInput +from promptflow.executor._line_execution_process_pool import get_available_max_worker_count from promptflow.executor.flow_executor import ( EmptyInputAfterMapping, EmptyInputListError, @@ -12,12 +15,11 @@ LineNumberNotAlign, MappingSourceNotFound, NoneInputsMappingIsNotSupported, - enable_streaming_for_llm_tool, _ensure_node_result_is_serializable, _inject_stream_options, + enable_streaming_for_llm_tool, ) from promptflow.tools.aoai import AzureOpenAI, chat, completion -from promptflow.executor._line_execution_process_pool import get_available_max_worker_count @pytest.mark.unittest @@ -148,9 +150,12 @@ def test_merge_input_dicts_by_line_error(self, inputs, error_code, error_message @pytest.mark.parametrize("inputs_mapping", [{"question": "${data.question}"}, {}]) def test_complete_inputs_mapping_by_default_value(self, inputs_mapping): - flow = Flow( - 
id="fakeId", name=None, nodes=[], inputs={"question": None, "groundtruth": None}, outputs=None, tools=[] - ) + inputs = { + "question": None, + "groundtruth": None, + "input_with_default_value": FlowInputDefinition(type=ValueType.INT, default="default_value"), + } + flow = Flow(id="fakeId", name=None, nodes=[], inputs=inputs, outputs=None, tools=[]) flow_executor = FlowExecutor( flow=flow, connections=None, @@ -159,6 +164,7 @@ def test_complete_inputs_mapping_by_default_value(self, inputs_mapping): loaded_tools=None, ) updated_inputs_mapping = flow_executor._complete_inputs_mapping_by_default_value(inputs_mapping) + assert "input_with_default_value" not in updated_inputs_mapping assert updated_inputs_mapping == {"question": "${data.question}", "groundtruth": "${data.groundtruth}"} @pytest.mark.parametrize( @@ -306,6 +312,140 @@ def test_inputs_mapping_for_all_lines_error(self, inputs, inputs_mapping, error_ FlowExecutor._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping) assert error_message == str(e.value), "Expected: {}, Actual: {}".format(error_message, str(e.value)) + @pytest.mark.parametrize( + "flow_inputs, inputs, expected_inputs", + [ + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + None, # Could handle None input + {"input_from_default": "default_value"}, + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {}, + {"input_from_default": "default_value"}, + ), + ( + { + "input_no_default": FlowInputDefinition(type=ValueType.STRING), + }, + {}, + {}, # No default value for input. 
+ ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {"input_from_default": "input_value", "another_key": "input_value"}, + {"input_from_default": "input_value", "another_key": "input_value"}, + ), + ], + ) + def test_apply_default_value_for_input(self, flow_inputs, inputs, expected_inputs): + result = FlowExecutor._apply_default_value_for_input(flow_inputs, inputs) + assert result == expected_inputs + + @pytest.mark.parametrize( + "flow_inputs, aggregated_flow_inputs, aggregation_inputs, expected_inputs", + [ + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {}, + {}, + {"input_from_default": ["default_value"]}, + ), + ( + { + "input_no_default": FlowInputDefinition(type=ValueType.STRING), + }, + {}, + {}, + {}, # No default value for input. + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {"input_from_default": "input_value", "another_key": "input_value"}, + {}, + {"input_from_default": "input_value", "another_key": "input_value"}, + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {"another_key": ["input_value", "input_value"]}, + {}, + { + "input_from_default": ["default_value", "default_value"], + "another_key": ["input_value", "input_value"], + }, + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {}, + {"another_key_in_aggregation_inputs": ["input_value", "input_value"]}, + { + "input_from_default": ["default_value", "default_value"], + }, + ), + ], + ) + def test_apply_default_value_for_aggregation_input( + self, flow_inputs, aggregated_flow_inputs, aggregation_inputs, expected_inputs + ): + result = FlowExecutor._apply_default_value_for_aggregation_input( + flow_inputs, aggregated_flow_inputs, aggregation_inputs + ) + assert result == expected_inputs + + 
@pytest.mark.parametrize( + "aggregated_flow_inputs, aggregation_inputs, error_message", + [ + ( + {}, + { + "input1": "value1", + }, + "Aggregation input input1 should be one list.", + ), + ( + { + "input1": "value1", + }, + {}, + "Flow aggregation input input1 should be one list.", + ), + ( + {"input1": ["value1_1", "value1_2"]}, + {"input_2": ["value2_1"]}, + "Whole aggregation inputs should have the same length. " + "Current key length mapping are: {'input1': 2, 'input_2': 1}", + ), + ( + { + "input1": "value1", + }, + { + "input1": "value1", + }, + "Input 'input1' appear in both flow aggregation input and aggregation input.", + ), + ], + ) + def test_validate_aggregation_inputs_error(self, aggregated_flow_inputs, aggregation_inputs, error_message): + with pytest.raises(InvalidAggregationInput) as e: + FlowExecutor._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs) + assert str(e.value) == error_message + def func_with_stream_parameter(a: int, b: int, stream=False): return a + b, stream @@ -403,18 +543,13 @@ class TestGetAvailableMaxWorkerCount: @pytest.mark.parametrize( "total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count", [ - (1024.0, 128.0, 64.0, 1, -3), # available_memory - 0.3 * total_memory < 0 + (1024.0, 128.0, 64.0, 1, -3), # available_memory - 0.3 * total_memory < 0 (1024.0, 307.20, 64.0, 1, 0), # available_memory - 0.3 * total_memory = 0 - (1024.0, 768.0, 64.0, 7, 7), # available_memory - 0.3 * total_memory > 0 + (1024.0, 768.0, 64.0, 7, 7), # available_memory - 0.3 * total_memory > 0 ], ) def test_get_available_max_worker_count( - self, - total_memory, - available_memory, - process_memory, - expected_max_worker_count, - actual_calculate_worker_count + self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count ): with patch("psutil.virtual_memory") as mock_mem: mock_mem.return_value.total = total_memory * 1024 * 1024 @@ 
-425,7 +560,7 @@ def test_get_available_max_worker_count( mock_logger.warning.return_value = None max_worker_count = get_available_max_worker_count() assert max_worker_count == expected_max_worker_count - if (actual_calculate_worker_count < 1): + if actual_calculate_worker_count < 1: mock_logger.warning.assert_called_with( f"Available max worker count {actual_calculate_worker_count} is less than 1, " "set it to 1." diff --git a/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml new file mode 100644 index 00000000000..090d64306bf --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml @@ -0,0 +1,25 @@ +inputs: + question: + type: string + default: input value from default +outputs: + output: + type: string + reference: ${test_print_input.output} +nodes: +- name: test_print_input + type: python + source: + type: code + path: test_print_input.py + inputs: + input: ${inputs.question} +- name: aggregate_node + type: python + source: + type: code + path: test_print_aggregation.py + inputs: + inputs: ${inputs.question} + aggregation: true + use_variants: false diff --git a/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py b/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py new file mode 100644 index 00000000000..1d4ea25f2af --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py @@ -0,0 +1,7 @@ +from typing import List +from promptflow import tool + +@tool +def test_print_input(inputs: List[str]): + print(inputs) + return inputs \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py b/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py new file mode 100644 index 00000000000..51ad6e7d91c --- /dev/null +++ 
b/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py @@ -0,0 +1,7 @@ +from promptflow import tool + + +@tool +def test_print_input(input: str): + print(input) + return input \ No newline at end of file