From 5bcae19702f5a15949ca53b179714b56f2516022 Mon Sep 17 00:00:00 2001
From: Robben Wang <350053002@qq.com>
Date: Thu, 7 Sep 2023 17:02:15 +0800
Subject: [PATCH] [Executor][Internal]Apply default input value in flow executor (#253)

# Description
- Honor the flow input's default value declared in YAML for single node, flow, and bulk runs.
- For bulk runs, the default value takes priority over the default input mapping, since the default value is assigned directly by the customer; therefore, the default mapping is not applied when a default value exists.
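
A minimal usage sketch of the new behavior (it mirrors the added `test_executor_creation_with_default_input` e2e test; the flow path assumes the repo root as working directory and an empty connections dict):

```python
# The "question" input of the default_input test flow declares
# `default: input value from default` in its flow.dag.yaml.
from promptflow.executor.flow_executor import FlowExecutor

executor = FlowExecutor.create(
    "src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml", {}
)

# Flow run with no inputs: the executor now falls back to the YAML default.
line_result = executor.exec_line({})
assert line_result.output["output"] == "input value from default"

# Bulk run: the default value also wins over the default input mapping,
# so the aggregation node receives one default value per line.
bulk_result = executor.exec_bulk([{}])
bulk_aggregate_node = bulk_result.aggr_results.node_run_infos["aggregate_node"]
assert bulk_aggregate_node.output == ["input value from default"]
```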

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Robben Wang
---
 src/promptflow/promptflow/executor/_errors.py |  14 +-
 .../promptflow/executor/flow_executor.py      |  83 ++++++++-
 .../e2etests/test_executor_happypath.py       |  30 ++++
 .../unittests/executor/test_flow_executor.py  | 165 ++++++++++++++++--
 .../flows/default_input/flow.dag.yaml         |  25 +++
 .../default_input/test_print_aggregation.py   |   7 +
 .../flows/default_input/test_print_input.py   |   7 +
 7 files changed, 305 insertions(+), 26 deletions(-)
 create mode 100644 src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml
 create mode 100644 src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py
 create mode 100644 src/promptflow/tests/test_configs/flows/default_input/test_print_input.py

diff --git a/src/promptflow/promptflow/executor/_errors.py b/src/promptflow/promptflow/executor/_errors.py
index ac397086ab6..e70f19a6dbe 100644
--- a/src/promptflow/promptflow/executor/_errors.py
+++ b/src/promptflow/promptflow/executor/_errors.py
@@ -11,12 +11,6 @@ class InvalidCustomLLMTool(ValidationException):
     pass
 
 
-class FlowExecutionError(SystemErrorException):
-    """Base System Exceptions for flow execution"""
-
-    pass
-
-
 class ValueTypeUnresolved(ValidationException):
     pass
 
@@ -101,11 +95,15 @@ class InputNotFound(InvalidFlowRequest):
     pass
 
 
-class InputNotFoundFromAncestorNodeOutput(FlowExecutionError):
+class InvalidAggregationInput(SystemErrorException):
+    pass
+
+
+class InputNotFoundFromAncestorNodeOutput(SystemErrorException):
     pass
 
 
-class NoNodeExecutedError(FlowExecutionError):
+class NoNodeExecutedError(SystemErrorException):
     pass
 
 
diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py
index 99ab07e0c1f..6d2d992e389 100644
--- a/src/promptflow/promptflow/executor/flow_executor.py
+++ b/src/promptflow/promptflow/executor/flow_executor.py
@@ -25,12 +25,13 @@
 from promptflow._utils.context_utils import _change_working_dir
 from promptflow._utils.logger_utils import logger
 from promptflow._utils.utils import transpose
-from promptflow.contracts.flow import Flow, InputAssignment, InputValueType, Node
+from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType, Node
 from promptflow.contracts.run_info import FlowRunInfo, Status
 from promptflow.contracts.run_mode import RunMode
 from promptflow.exceptions import ErrorTarget, PromptflowException, SystemErrorException, ValidationException
 from promptflow.executor import _input_assignment_parser
 from promptflow.executor._errors import (
+    InvalidAggregationInput,
     NodeConcurrencyNotFound,
     NodeOutputNotFound,
     OutputReferenceBypassed,
@@ -195,6 +196,7 @@ def load_and_exec_node(
         if not node.source or not node.type:
             raise ValueError(f"Node {node_name} is not a valid node in flow {flow_file}.")
 
+        flow_inputs = FlowExecutor._apply_default_value_for_input(flow.inputs, flow_inputs)
         converted_flow_inputs_for_node = FlowValidator.convert_flow_inputs_for_node(flow, node, flow_inputs)
         package_tool_keys = [node.source.tool] if node.source and node.source.tool else []
         tool_resolver = ToolResolver(working_dir, connections, package_tool_keys)
@@ -400,8 +402,63 @@ def exec_aggregation(
         node_concurrency=DEFAULT_CONCURRENCY_FLOW,
     ) -> AggregationResult:
         self._node_concurrency = node_concurrency
+        aggregated_flow_inputs = dict(inputs or {})
+        aggregation_inputs = dict(aggregation_inputs or {})
+        self._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
+        aggregated_flow_inputs = self._apply_default_value_for_aggregation_input(
+            self._flow.inputs, aggregated_flow_inputs, aggregation_inputs
+        )
+
         with self._run_tracker.node_log_manager:
-            return self._exec_aggregation(inputs, aggregation_inputs, run_id)
+            return self._exec_aggregation(aggregated_flow_inputs, aggregation_inputs, run_id)
+
+    @staticmethod
+    def _validate_aggregation_inputs(aggregated_flow_inputs: Mapping[str, Any], aggregation_inputs: Mapping[str, Any]):
+        """Validate the aggregation inputs according to the flow inputs."""
+        for key, value in aggregated_flow_inputs.items():
+            if key in aggregation_inputs:
+                raise InvalidAggregationInput(
+                    message_format="Input '{input_key}' appear in both flow aggregation input and aggregation input.",
+                    input_key=key,
+                )
+            if not isinstance(value, list):
+                raise InvalidAggregationInput(
+                    message_format="Flow aggregation input {input_key} should be one list.", input_key=key
+                )
+
+        for key, value in aggregation_inputs.items():
+            if not isinstance(value, list):
+                raise InvalidAggregationInput(
+                    message_format="Aggregation input {input_key} should be one list.", input_key=key
+                )
+
+        inputs_len = {key: len(value) for key, value in aggregated_flow_inputs.items()}
+        inputs_len.update({key: len(value) for key, value in aggregation_inputs.items()})
+        if len(set(inputs_len.values())) > 1:
+            raise InvalidAggregationInput(
+                message_format="Whole aggregation inputs should have the same length. "
+                "Current key length mapping are: {key_len}",
+                key_len=inputs_len,
+            )
+
+    @staticmethod
+    def _apply_default_value_for_aggregation_input(
+        inputs: Dict[str, FlowInputDefinition],
+        aggregated_flow_inputs: Mapping[str, Any],
+        aggregation_inputs: Mapping[str, Any],
+    ):
+        aggregation_lines = 1
+        if aggregated_flow_inputs.values():
+            one_input_value = list(aggregated_flow_inputs.values())[0]
+            aggregation_lines = len(one_input_value)
+        # If aggregated_flow_inputs is empty, we should use aggregation_inputs to get the length.
+        elif aggregation_inputs.values():
+            one_input_value = list(aggregation_inputs.values())[0]
+            aggregation_lines = len(one_input_value)
+        for key, value in inputs.items():
+            if key not in aggregated_flow_inputs and (value and value.default):
+                aggregated_flow_inputs[key] = [value.default] * aggregation_lines
+        return aggregated_flow_inputs
 
     def _exec_aggregation(
         self,
@@ -444,6 +501,7 @@ def _log_metric(key, value):
 
     def exec(self, inputs: dict, node_concurency=DEFAULT_CONCURRENCY_FLOW) -> dict:
         self._node_concurrency = node_concurency
+        inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
         result = self._exec(inputs)
         # TODO: remove this line once serving directly calling self.exec_line
         self._add_line_results([result])
@@ -482,6 +540,7 @@ def exec_line(
         allow_generator_output: bool = False,
     ) -> LineResult:
         self._node_concurrency = node_concurrency
+        inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
         # For flow run, validate inputs as default
         with self._run_tracker.node_log_manager:
             # exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when
@@ -539,6 +598,11 @@ def exec_bulk(
             BulkResults including flow results and metrics
         """
         self._node_concurrency = node_concurrency
+        # Apply default values early so they can be used in both line execution and aggregation node execution.
+        inputs = [
+            FlowExecutor._apply_default_value_for_input(self._flow.inputs, each_line_input)
+            for each_line_input in inputs
+        ]
         run_id = run_id or str(uuid.uuid4())
         with self._run_tracker.node_log_manager:
             OperationContext.get_instance().run_mode = RunMode.Batch.name
@@ -558,7 +622,15 @@ def exec_bulk(
             aggr_results=aggr_results,
         )
 
-    def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping):
+    @staticmethod
+    def _apply_default_value_for_input(inputs: Dict[str, FlowInputDefinition], line_inputs: Mapping) -> Dict[str, Any]:
+        updated_inputs = dict(line_inputs or {})
+        for key, value in inputs.items():
+            if key not in updated_inputs and (value and value.default):
+                updated_inputs[key] = value.default
+        return updated_inputs
+
+    def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping) -> List[Dict[str, Any]]:
         inputs_mapping = self._complete_inputs_mapping_by_default_value(inputs_mapping)
         resolved_inputs = self._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping)
         return resolved_inputs
@@ -566,6 +638,11 @@ def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping):
     def _complete_inputs_mapping_by_default_value(self, inputs_mapping):
         inputs_mapping = inputs_mapping or {}
         result_mapping = self._default_inputs_mapping
+        # For inputs that have a default value, we don't try to read data from the default mapping.
+        # The default value takes priority over the default mapping.
+        for key, value in self._flow.inputs.items():
+            if value and value.default:
+                del result_mapping[key]
         result_mapping.update(inputs_mapping)
         return result_mapping
 
diff --git a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
index 8a79298117b..5c2874bf73a 100644
--- a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
+++ b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
@@ -342,3 +342,33 @@ def test_executor_creation_with_default_variants(self, flow_folder, dev_connecti
         executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
         flow_result = executor.exec_line(self.get_line_inputs())
         assert flow_result.run_info.status == Status.Completed
+
+    def test_executor_creation_with_default_input(self):
+        # Assert for single node run.
+        default_input_value = "input value from default"
+        yaml_file = get_yaml_file("default_input")
+        executor = FlowExecutor.create(yaml_file, {})
+        node_result = executor.load_and_exec_node(yaml_file, "test_print_input")
+        assert node_result.status == Status.Completed
+        assert node_result.output == default_input_value
+
+        # Assert for flow run.
+        flow_result = executor.exec_line({})
+        assert flow_result.run_info.status == Status.Completed
+        assert flow_result.output["output"] == default_input_value
+        aggr_results = executor.exec_aggregation({}, aggregation_inputs={})
+        flow_aggregate_node = aggr_results.node_run_infos["aggregate_node"]
+        assert flow_aggregate_node.status == Status.Completed
+        assert flow_aggregate_node.output == [default_input_value]
+
+        # Assert for bulk run.
+        bulk_result = executor.exec_bulk([{}])
+        assert bulk_result.line_results[0].run_info.status == Status.Completed
+        assert bulk_result.line_results[0].output["output"] == default_input_value
+        bulk_aggregate_node = bulk_result.aggr_results.node_run_infos["aggregate_node"]
+        assert bulk_aggregate_node.status == Status.Completed
+        assert bulk_aggregate_node.output == [default_input_value]
+
+        # Assert for exec
+        exec_result = executor.exec({})
+        assert exec_result["output"] == default_input_value
"${data.question}"}, {}]) def test_complete_inputs_mapping_by_default_value(self, inputs_mapping): - flow = Flow( - id="fakeId", name=None, nodes=[], inputs={"question": None, "groundtruth": None}, outputs=None, tools=[] - ) + inputs = { + "question": None, + "groundtruth": None, + "input_with_default_value": FlowInputDefinition(type=ValueType.INT, default="default_value"), + } + flow = Flow(id="fakeId", name=None, nodes=[], inputs=inputs, outputs=None, tools=[]) flow_executor = FlowExecutor( flow=flow, connections=None, @@ -159,6 +164,7 @@ def test_complete_inputs_mapping_by_default_value(self, inputs_mapping): loaded_tools=None, ) updated_inputs_mapping = flow_executor._complete_inputs_mapping_by_default_value(inputs_mapping) + assert "input_with_default_value" not in updated_inputs_mapping assert updated_inputs_mapping == {"question": "${data.question}", "groundtruth": "${data.groundtruth}"} @pytest.mark.parametrize( @@ -306,6 +312,140 @@ def test_inputs_mapping_for_all_lines_error(self, inputs, inputs_mapping, error_ FlowExecutor._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping) assert error_message == str(e.value), "Expected: {}, Actual: {}".format(error_message, str(e.value)) + @pytest.mark.parametrize( + "flow_inputs, inputs, expected_inputs", + [ + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + None, # Could handle None input + {"input_from_default": "default_value"}, + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {}, + {"input_from_default": "default_value"}, + ), + ( + { + "input_no_default": FlowInputDefinition(type=ValueType.STRING), + }, + {}, + {}, # No default value for input. + ), + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {"input_from_default": "input_value", "another_key": "input_value"}, + {"input_from_default": "input_value", "another_key": "input_value"}, + ), + ], + ) + def test_apply_default_value_for_input(self, flow_inputs, inputs, expected_inputs): + result = FlowExecutor._apply_default_value_for_input(flow_inputs, inputs) + assert result == expected_inputs + + @pytest.mark.parametrize( + "flow_inputs, aggregated_flow_inputs, aggregation_inputs, expected_inputs", + [ + ( + { + "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"), + }, + {}, + {}, + {"input_from_default": ["default_value"]}, + ), + ( + { + "input_no_default": FlowInputDefinition(type=ValueType.STRING), + }, + {}, + {}, + {}, # No default value for input. 
@@ -306,6 +312,140 @@ def test_inputs_mapping_for_all_lines_error(self, inputs, inputs_mapping, error_
         FlowExecutor._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping)
         assert error_message == str(e.value), "Expected: {}, Actual: {}".format(error_message, str(e.value))
 
+    @pytest.mark.parametrize(
+        "flow_inputs, inputs, expected_inputs",
+        [
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                None,  # Could handle None input
+                {"input_from_default": "default_value"},
+            ),
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {},
+                {"input_from_default": "default_value"},
+            ),
+            (
+                {
+                    "input_no_default": FlowInputDefinition(type=ValueType.STRING),
+                },
+                {},
+                {},  # No default value for input.
+            ),
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {"input_from_default": "input_value", "another_key": "input_value"},
+                {"input_from_default": "input_value", "another_key": "input_value"},
+            ),
+        ],
+    )
+    def test_apply_default_value_for_input(self, flow_inputs, inputs, expected_inputs):
+        result = FlowExecutor._apply_default_value_for_input(flow_inputs, inputs)
+        assert result == expected_inputs
+
+    @pytest.mark.parametrize(
+        "flow_inputs, aggregated_flow_inputs, aggregation_inputs, expected_inputs",
+        [
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {},
+                {},
+                {"input_from_default": ["default_value"]},
+            ),
+            (
+                {
+                    "input_no_default": FlowInputDefinition(type=ValueType.STRING),
+                },
+                {},
+                {},
+                {},  # No default value for input.
+            ),
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {"input_from_default": "input_value", "another_key": "input_value"},
+                {},
+                {"input_from_default": "input_value", "another_key": "input_value"},
+            ),
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {"another_key": ["input_value", "input_value"]},
+                {},
+                {
+                    "input_from_default": ["default_value", "default_value"],
+                    "another_key": ["input_value", "input_value"],
+                },
+            ),
+            (
+                {
+                    "input_from_default": FlowInputDefinition(type=ValueType.STRING, default="default_value"),
+                },
+                {},
+                {"another_key_in_aggregation_inputs": ["input_value", "input_value"]},
+                {
+                    "input_from_default": ["default_value", "default_value"],
+                },
+            ),
+        ],
+    )
+    def test_apply_default_value_for_aggregation_input(
+        self, flow_inputs, aggregated_flow_inputs, aggregation_inputs, expected_inputs
+    ):
+        result = FlowExecutor._apply_default_value_for_aggregation_input(
+            flow_inputs, aggregated_flow_inputs, aggregation_inputs
+        )
+        assert result == expected_inputs
+
+    @pytest.mark.parametrize(
+        "aggregated_flow_inputs, aggregation_inputs, error_message",
+        [
+            (
+                {},
+                {
+                    "input1": "value1",
+                },
+                "Aggregation input input1 should be one list.",
+            ),
+            (
+                {
+                    "input1": "value1",
+                },
+                {},
+                "Flow aggregation input input1 should be one list.",
+            ),
+            (
+                {"input1": ["value1_1", "value1_2"]},
+                {"input_2": ["value2_1"]},
+                "Whole aggregation inputs should have the same length. "
+                "Current key length mapping are: {'input1': 2, 'input_2': 1}",
+            ),
+            (
+                {
+                    "input1": "value1",
+                },
+                {
+                    "input1": "value1",
+                },
+                "Input 'input1' appear in both flow aggregation input and aggregation input.",
+            ),
+        ],
+    )
+    def test_validate_aggregation_inputs_error(self, aggregated_flow_inputs, aggregation_inputs, error_message):
+        with pytest.raises(InvalidAggregationInput) as e:
+            FlowExecutor._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
+        assert str(e.value) == error_message
+
 
 def func_with_stream_parameter(a: int, b: int, stream=False):
     return a + b, stream
@@ -403,18 +543,13 @@ class TestGetAvailableMaxWorkerCount:
     @pytest.mark.parametrize(
         "total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count",
         [
-            (1024.0, 128.0, 64.0, 1, -3), # available_memory - 0.3 * total_memory < 0
+            (1024.0, 128.0, 64.0, 1, -3),  # available_memory - 0.3 * total_memory < 0
             (1024.0, 307.20, 64.0, 1, 0),  # available_memory - 0.3 * total_memory = 0
-            (1024.0, 768.0, 64.0, 7, 7), # available_memory - 0.3 * total_memory > 0
+            (1024.0, 768.0, 64.0, 7, 7),  # available_memory - 0.3 * total_memory > 0
         ],
     )
     def test_get_available_max_worker_count(
-        self,
-        total_memory,
-        available_memory,
-        process_memory,
-        expected_max_worker_count,
-        actual_calculate_worker_count
+        self, total_memory, available_memory, process_memory, expected_max_worker_count, actual_calculate_worker_count
     ):
         with patch("psutil.virtual_memory") as mock_mem:
             mock_mem.return_value.total = total_memory * 1024 * 1024
@@ -425,7 +560,7 @@ def test_get_available_max_worker_count(
                 mock_logger.warning.return_value = None
                 max_worker_count = get_available_max_worker_count()
                 assert max_worker_count == expected_max_worker_count
-                if (actual_calculate_worker_count < 1):
+                if actual_calculate_worker_count < 1:
                     mock_logger.warning.assert_called_with(
                         f"Available max worker count {actual_calculate_worker_count} is less than 1, "
                         "set it to 1."
diff --git a/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml
new file mode 100644
index 00000000000..090d64306bf
--- /dev/null
+++ b/src/promptflow/tests/test_configs/flows/default_input/flow.dag.yaml
@@ -0,0 +1,25 @@
+inputs:
+  question:
+    type: string
+    default: input value from default
+outputs:
+  output:
+    type: string
+    reference: ${test_print_input.output}
+nodes:
+- name: test_print_input
+  type: python
+  source:
+    type: code
+    path: test_print_input.py
+  inputs:
+    input: ${inputs.question}
+- name: aggregate_node
+  type: python
+  source:
+    type: code
+    path: test_print_aggregation.py
+  inputs:
+    inputs: ${inputs.question}
+  aggregation: true
+  use_variants: false
diff --git a/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py b/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py
new file mode 100644
index 00000000000..1d4ea25f2af
--- /dev/null
+++ b/src/promptflow/tests/test_configs/flows/default_input/test_print_aggregation.py
@@ -0,0 +1,7 @@
+from typing import List
+from promptflow import tool
+
+@tool
+def test_print_input(inputs: List[str]):
+    print(inputs)
+    return inputs
\ No newline at end of file
diff --git a/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py b/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py
new file mode 100644
index 00000000000..51ad6e7d91c
--- /dev/null
+++ b/src/promptflow/tests/test_configs/flows/default_input/test_print_input.py
@@ -0,0 +1,7 @@
+from promptflow import tool
+
+
+@tool
+def test_print_input(input: str):
+    print(input)
+    return input
\ No newline at end of file