[Executor][Internal] Apply default input value in flow executor (microsoft#253)

# Description

- Honor the flow input's default value from YAML for single-node, flow, and bulk runs.
- For bulk runs, the default value takes priority over the default mapping, since the default value is assigned directly by the customer. So, when an input has a default value, its default mapping is not applied (see the sketch below).
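
A minimal sketch of the intended priority rule (illustrative only; `FlowInputDefinition` here is a simplified stand-in for the real contract class, and `resolve_bulk_input_mapping` is a hypothetical helper, not the executor's actual code):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FlowInputDefinition:
    default: Optional[Any] = None


def resolve_bulk_input_mapping(
    flow_inputs: Dict[str, FlowInputDefinition],
    default_mapping: Dict[str, str],
) -> Dict[str, str]:
    # Start from the default mapping, e.g. {"text": "${data.text}"}.
    mapping = dict(default_mapping)
    # An input with a YAML default wins over the default mapping,
    # so drop the mapping entry and let the default value be used.
    for name, definition in flow_inputs.items():
        if definition.default is not None:
            mapping.pop(name, None)
    return mapping


flow_inputs = {
    "text": FlowInputDefinition(default="hello"),
    "question": FlowInputDefinition(),
}
default_mapping = {"text": "${data.text}", "question": "${data.question}"}
# "text" keeps its YAML default; only "question" still reads from the data.
print(resolve_bulk_input_mapping(flow_inputs, default_mapping))
# -> {'question': '${data.question}'}
```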


# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Robben Wang <[email protected]>
huaiyan and Robben Wang authored Sep 7, 2023
1 parent 6d31792 commit 5bcae19
Showing 7 changed files with 305 additions and 26 deletions.
14 changes: 6 additions & 8 deletions src/promptflow/promptflow/executor/_errors.py
@@ -11,12 +11,6 @@ class InvalidCustomLLMTool(ValidationException):
pass


class FlowExecutionError(SystemErrorException):
"""Base System Exceptions for flow execution"""

pass


class ValueTypeUnresolved(ValidationException):
pass

@@ -101,11 +95,15 @@ class InputNotFound(InvalidFlowRequest):
pass


class InputNotFoundFromAncestorNodeOutput(FlowExecutionError):
class InvalidAggregationInput(SystemErrorException):
pass


class InputNotFoundFromAncestorNodeOutput(SystemErrorException):
pass


class NoNodeExecutedError(FlowExecutionError):
class NoNodeExecutedError(SystemErrorException):
pass


83 changes: 80 additions & 3 deletions src/promptflow/promptflow/executor/flow_executor.py
@@ -25,12 +25,13 @@
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.logger_utils import logger
from promptflow._utils.utils import transpose
from promptflow.contracts.flow import Flow, InputAssignment, InputValueType, Node
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType, Node
from promptflow.contracts.run_info import FlowRunInfo, Status
from promptflow.contracts.run_mode import RunMode
from promptflow.exceptions import ErrorTarget, PromptflowException, SystemErrorException, ValidationException
from promptflow.executor import _input_assignment_parser
from promptflow.executor._errors import (
InvalidAggregationInput,
NodeConcurrencyNotFound,
NodeOutputNotFound,
OutputReferenceBypassed,
@@ -195,6 +196,7 @@ def load_and_exec_node(
if not node.source or not node.type:
raise ValueError(f"Node {node_name} is not a valid node in flow {flow_file}.")

flow_inputs = FlowExecutor._apply_default_value_for_input(flow.inputs, flow_inputs)
converted_flow_inputs_for_node = FlowValidator.convert_flow_inputs_for_node(flow, node, flow_inputs)
package_tool_keys = [node.source.tool] if node.source and node.source.tool else []
tool_resolver = ToolResolver(working_dir, connections, package_tool_keys)
@@ -400,8 +402,63 @@ def exec_aggregation(
node_concurrency=DEFAULT_CONCURRENCY_FLOW,
) -> AggregationResult:
self._node_concurrency = node_concurrency
aggregated_flow_inputs = dict(inputs or {})
aggregation_inputs = dict(aggregation_inputs or {})
self._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
aggregated_flow_inputs = self._apply_default_value_for_aggregation_input(
self._flow.inputs, aggregated_flow_inputs, aggregation_inputs
)

with self._run_tracker.node_log_manager:
return self._exec_aggregation(inputs, aggregation_inputs, run_id)
return self._exec_aggregation(aggregated_flow_inputs, aggregation_inputs, run_id)

@staticmethod
def _validate_aggregation_inputs(aggregated_flow_inputs: Mapping[str, Any], aggregation_inputs: Mapping[str, Any]):
"""Validate the aggregation inputs according to the flow inputs."""
for key, value in aggregated_flow_inputs.items():
if key in aggregation_inputs:
raise InvalidAggregationInput(
message_format="Input '{input_key}' appear in both flow aggregation input and aggregation input.",
input_key=key,
)
if not isinstance(value, list):
raise InvalidAggregationInput(
message_format="Flow aggregation input {input_key} should be one list.", input_key=key
)

for key, value in aggregation_inputs.items():
if not isinstance(value, list):
raise InvalidAggregationInput(
message_format="Aggregation input {input_key} should be one list.", input_key=key
)

inputs_len = {key: len(value) for key, value in aggregated_flow_inputs.items()}
inputs_len.update({key: len(value) for key, value in aggregation_inputs.items()})
if len(set(inputs_len.values())) > 1:
raise InvalidAggregationInput(
message_format="Whole aggregation inputs should have the same length. "
"Current key length mapping are: {key_len}",
key_len=inputs_len,
)

@staticmethod
def _apply_default_value_for_aggregation_input(
inputs: Dict[str, FlowInputDefinition],
aggregated_flow_inputs: Mapping[str, Any],
aggregation_inputs: Mapping[str, Any],
):
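# Fill in inputs that are missing from aggregated_flow_inputs but have a
# default value defined in the flow, repeating the default once per line.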
aggregation_lines = 1
if aggregated_flow_inputs.values():
one_input_value = list(aggregated_flow_inputs.values())[0]
aggregation_lines = len(one_input_value)
# If aggregated_flow_inputs is empty, we should use aggregation_inputs to get the length.
elif aggregation_inputs.values():
one_input_value = list(aggregation_inputs.values())[0]
aggregation_lines = len(one_input_value)
for key, value in inputs.items():
if key not in aggregated_flow_inputs and (value and value.default):
aggregated_flow_inputs[key] = [value.default] * aggregation_lines
return aggregated_flow_inputs

def _exec_aggregation(
self,
@@ -444,6 +501,7 @@ def _log_metric(key, value):

def exec(self, inputs: dict, node_concurrency=DEFAULT_CONCURRENCY_FLOW) -> dict:
self._node_concurrency = node_concurrency
inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
result = self._exec(inputs)
# TODO: remove this line once serving directly calling self.exec_line
self._add_line_results([result])
@@ -482,6 +540,7 @@ def exec_line(
allow_generator_output: bool = False,
) -> LineResult:
self._node_concurrency = node_concurrency
inputs = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
# exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when
@@ -539,6 +598,11 @@ def exec_bulk(
BulkResults including flow results and metrics
"""
self._node_concurrency = node_concurrency
# Apply default values at an early stage so they can be used both in line execution and in aggregation node execution.
inputs = [
FlowExecutor._apply_default_value_for_input(self._flow.inputs, each_line_input)
for each_line_input in inputs
]
run_id = run_id or str(uuid.uuid4())
with self._run_tracker.node_log_manager:
OperationContext.get_instance().run_mode = RunMode.Batch.name
@@ -558,14 +622,27 @@
aggr_results=aggr_results,
)

def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping):
@staticmethod
def _apply_default_value_for_input(inputs: Dict[str, FlowInputDefinition], line_inputs: Mapping) -> Dict[str, Any]:
updated_inputs = dict(line_inputs or {})
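# Fill in a default only for inputs that the caller did not supply.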
for key, value in inputs.items():
if key not in updated_inputs and (value and value.default):
updated_inputs[key] = value.default
return updated_inputs

def validate_and_apply_inputs_mapping(self, inputs, inputs_mapping) -> List[Dict[str, Any]]:
inputs_mapping = self._complete_inputs_mapping_by_default_value(inputs_mapping)
resolved_inputs = self._apply_inputs_mapping_for_all_lines(inputs, inputs_mapping)
return resolved_inputs

def _complete_inputs_mapping_by_default_value(self, inputs_mapping):
inputs_mapping = inputs_mapping or {}
result_mapping = self._default_inputs_mapping
# For an input that has a default value, don't read data from the default mapping.
# The default value takes priority over the default mapping.
for key, value in self._flow.inputs.items():
if value and value.default:
del result_mapping[key]
result_mapping.update(inputs_mapping)
return result_mapping

30 changes: 30 additions & 0 deletions src/promptflow/tests/executor/e2etests/test_executor_happypath.py
@@ -342,3 +342,33 @@ def test_executor_creation_with_default_variants(self, flow_folder, dev_connections):
executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
flow_result = executor.exec_line(self.get_line_inputs())
assert flow_result.run_info.status == Status.Completed

def test_executor_creation_with_default_input(self):
# Assert for single node run.
default_input_value = "input value from default"
yaml_file = get_yaml_file("default_input")
executor = FlowExecutor.create(yaml_file, {})
node_result = executor.load_and_exec_node(yaml_file, "test_print_input")
assert node_result.status == Status.Completed
assert node_result.output == default_input_value

# Assert for flow run.
flow_result = executor.exec_line({})
assert flow_result.run_info.status == Status.Completed
assert flow_result.output["output"] == default_input_value
aggr_results = executor.exec_aggregation({}, aggregation_inputs={})
flow_aggregate_node = aggr_results.node_run_infos["aggregate_node"]
assert flow_aggregate_node.status == Status.Completed
assert flow_aggregate_node.output == [default_input_value]

# Assert for bulk run.
bulk_result = executor.exec_bulk([{}])
assert bulk_result.line_results[0].run_info.status == Status.Completed
assert bulk_result.line_results[0].output["output"] == default_input_value
bulk_aggregate_node = bulk_result.aggr_results.node_run_infos["aggregate_node"]
assert bulk_aggregate_node.status == Status.Completed
assert bulk_aggregate_node.output == [default_input_value]

# Assert for exec.
exec_result = executor.exec({})
assert exec_result["output"] == default_input_value