diff --git a/src/promptflow/promptflow/_cli/_params.py b/src/promptflow/promptflow/_cli/_params.py index 7e1677b6787..128c5da8384 100644 --- a/src/promptflow/promptflow/_cli/_params.py +++ b/src/promptflow/promptflow/_cli/_params.py @@ -11,11 +11,13 @@ def __call__(self, parser, namespace, values, option_string=None): super(AppendToDictAction, self).__call__(parser, namespace, action, option_string) def get_action(self, values, option_string): # pylint: disable=no-self-use + from promptflow._sdk._utils import strip_quotation + kwargs = {} for item in values: try: - key, value = item.split("=", 1) - kwargs[key] = value + key, value = strip_quotation(item).split("=", 1) + kwargs[key] = strip_quotation(value) except ValueError: raise Exception("Usage error: {} KEY=VALUE [KEY=VALUE ...]".format(option_string)) return kwargs diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py index c72c0ca7046..c35d2fa75a8 100644 --- a/src/promptflow/promptflow/_sdk/_utils.py +++ b/src/promptflow/promptflow/_sdk/_utils.py @@ -239,9 +239,22 @@ def load_from_dict(schema: Any, data: Dict, context: Dict, additional_message: s raise ValidationError(decorate_validation_error(schema, pretty_error, additional_message)) +def strip_quotation(value): + """ + To avoid escaping chars in command args, args will be surrounded in quotas. + Need to remove the pair of quotation first. + """ + if value.startswith('"') and value.endswith('"'): + return value[1:-1] + elif value.startswith("'") and value.endswith("'"): + return value[1:-1] + else: + return value + + def parse_variant(variant: str) -> Tuple[str, str]: variant_regex = r"\${([^.]+).([^}]+)}" - match = re.match(variant_regex, variant) + match = re.match(variant_regex, strip_quotation(variant)) if match: return match.group(1), match.group(2) else: diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py index 6851fcddded..75100ab58d0 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py @@ -86,6 +86,8 @@ def test_basic_flow_run_batch_and_eval(self) -> None: ) assert "Completed" in f.getvalue() + # Check the CLI works correctly when the parameter is surrounded by quotation, as below shown: + # --param "key=value" key="value" f = io.StringIO() with contextlib.redirect_stdout(f): run_pf_command( @@ -94,8 +96,8 @@ def test_basic_flow_run_batch_and_eval(self) -> None: "--flow", f"{FLOWS_DIR}/classification_accuracy_evaluation", "--column-mapping", - "groundtruth=${data.answer}", - "prediction=${run.outputs.category}", + "'groundtruth=${data.answer}'", + "prediction='${run.outputs.category}'", "variant_id=${data.variant_id}", "--data", f"{DATAS_DIR}/webClassification3.jsonl", @@ -309,7 +311,7 @@ def test_pf_flow_with_variant(self, capsys): "answer=Channel", "evidence=Url", "--variant", - "${summarize_text_content.variant_1}", + "'${summarize_text_content.variant_1}'", ) output_path = Path(temp_dir) / ".promptflow" / "flow-summarize_text_content-variant_1.output.json" assert output_path.exists() diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py index 352c80faa10..47654bf50e4 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import argparse import os import shutil import tempfile @@ -16,6 +17,7 @@ _calculate_column_widths, list_of_dict_to_nested_dict, ) +from promptflow._cli._params import AppendToDictAction from promptflow._sdk._errors import GenerateFlowToolsJsonError from promptflow._sdk._utils import ( decrypt_secret_value, @@ -123,6 +125,20 @@ def test_list_of_dict_to_nested_dict(self): result = list_of_dict_to_nested_dict(test_list) assert result == {"node1": {"connection": "a", "deploy_name": "b"}} + def test_append_to_dict_action(self): + parser = argparse.ArgumentParser(prog="test_dict_action") + parser.add_argument("--dict", action=AppendToDictAction, nargs="+") + args = ["--dict", "key1=val1", "\'key2=val2\'", "\"key3=val3\"", "key4=\'val4\'", "key5=\"val5'"] + args = parser.parse_args(args) + expect_dict = { + "key1": "val1", + "key2": "val2", + "key3": "val3", + "key4": "val4", + "key5": "\"val5'", + } + assert args.dict[0] == expect_dict + def test_build_sorted_column_widths_tuple_list(self) -> None: columns = ["col1", "col2", "col3"] values1 = {"col1": 1, "col2": 4, "col3": 3}