diff --git a/scripts/utils/generate_tool_meta_utils.py b/scripts/utils/generate_tool_meta_utils.py index e519c4e8b3c..dec516f4889 100644 --- a/scripts/utils/generate_tool_meta_utils.py +++ b/scripts/utils/generate_tool_meta_utils.py @@ -12,7 +12,8 @@ from promptflow.core.tool import ToolProvider from promptflow.contracts.tool import InputDefinition, Tool, ToolType, ValueType from promptflow.exceptions import ErrorTarget, UserErrorException -from promptflow.utils.tool_utils import function_to_interface, get_inputs_for_prompt_template + +from utils.tool_utils import function_to_interface, get_inputs_for_prompt_template def asdict_without_none(obj): diff --git a/scripts/utils/tool_utils.py b/scripts/utils/tool_utils.py new file mode 100644 index 00000000000..5c19e1643ee --- /dev/null +++ b/scripts/utils/tool_utils.py @@ -0,0 +1,94 @@ +import inspect +from enum import Enum, EnumMeta +from typing import Callable, Union, get_args, get_origin +from jinja2 import Environment, meta +from promptflow.contracts.tool import ConnectionType, InputDefinition, ValueType + + +def value_to_str(val): + if val is inspect.Parameter.empty: + # For empty case, default field will be skipped when dumping to json + return None + if val is None: + # Dump default: "" in json to avoid UI validation error + return "" + if isinstance(val, Enum): + return val.value + return str(val) + + +def resolve_annotation(anno) -> Union[str, list]: + """Resolve the union annotation to type list.""" + origin = get_origin(anno) + if origin != Union: + return anno + # Optional[Type] is Union[Type, NoneType], filter NoneType out + args = [arg for arg in get_args(anno) if arg != type(None)] # noqa: E721 + return args[0] if len(args) == 1 else args + + +def param_to_definition(param) -> (InputDefinition, bool): + default_value = param.default + # Get value type and enum from annotation + value_type = resolve_annotation(param.annotation) + enum = None + # Get value type and enum from default if no annotation + if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty: + value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value) + # Extract enum for enum class + if isinstance(value_type, EnumMeta): + enum = [str(option.value) for option in value_type] + value_type = str + is_connection = False + if ConnectionType.is_connection_value(value_type): + typ = [value_type.__name__] + is_connection = True + elif isinstance(value_type, list): + if not all(ConnectionType.is_connection_value(t) for t in value_type): + typ = [ValueType.OBJECT] + else: + typ = [t.__name__ for t in value_type] + is_connection = True + else: + typ = [ValueType.from_type(value_type)] + return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection + + +def function_to_interface(f: Callable, initialize_inputs=None) -> tuple: + sign = inspect.signature(f) + all_inputs = {} + input_defs = {} + connection_types = [] + # Collect all inputs from class and func + if initialize_inputs: + if any(k for k in initialize_inputs if k in sign.parameters): + raise Exception(f'Duplicate inputs found from {f.__name__!r} and "__init__()"!') + all_inputs = {**initialize_inputs} + all_inputs.update( + { + k: v + for k, v in sign.parameters.items() + if k != "self" and v.kind != v.VAR_KEYWORD and v.kind != v.VAR_POSITIONAL # TODO: Handle these cases + } + ) + # Resolve inputs to definitions. + for k, v in all_inputs.items(): + input_def, is_connection = param_to_definition(v) + input_defs[k] = input_def + if is_connection: + connection_types.append(input_def.type) + outputs = {} + # Note: We don't have output definition now + # outputs = {"output": OutputDefinition("output", [ValueType.from_type(type(sign.return_annotation))], "", True)} + # if is_dataclass(sign.return_annotation): + # for f in fields(sign.return_annotation): + # outputs[f.name] = OutputDefinition(f.name, [ValueType.from_type( + # type(getattr(sign.return_annotation, f.name)))], "", False) + return input_defs, outputs, connection_types + + +def get_inputs_for_prompt_template(template_str): + """Get all input variable names from a jinja2 template string.""" + env = Environment() + template = env.parse(template_str) + return sorted(meta.find_undeclared_variables(template), key=lambda x: template_str.find(x)) diff --git a/src/promptflow-tools/promptflow/tools/__init__.py b/src/promptflow-tools/promptflow/tools/__init__.py index fe831864618..9c0ba52066b 100644 --- a/src/promptflow-tools/promptflow/tools/__init__.py +++ b/src/promptflow-tools/promptflow/tools/__init__.py @@ -4,6 +4,5 @@ from .azure_language_detector import get_language # noqa: F401 from .azure_form_recognizer import AzureFormRecognizer # noqa: F401 from .azure_translator import get_translation, AzureTranslator # noqa: F401 -from .bing import Bing # noqa: F401 from .openai import OpenAI # noqa: F401 from .serpapi import SerpAPI # noqa: F401 diff --git a/src/promptflow-tools/promptflow/tools/bing.py b/src/promptflow-tools/promptflow/tools/bing.py deleted file mode 100644 index d3ba1ee22a3..00000000000 --- a/src/promptflow-tools/promptflow/tools/bing.py +++ /dev/null @@ -1,179 +0,0 @@ -import json -import sys - -from promptflow.core.tool import ToolProvider, tool -from promptflow.connections import BingConnection -from promptflow.core.cache_manager import enable_cache -from promptflow.core.tools_manager import register_builtin_method, register_builtins -from promptflow.exceptions import ErrorTarget, PromptflowException, SystemErrorException, UserErrorException -from promptflow.tools.common import to_bool -from promptflow.utils.utils import deprecated - - -class Bing(ToolProvider): - """ - API Reference: - https://learn.microsoft.com/en-us/rest/api/cognitiveservices-bingsearch/bing-web-api-v7-reference - Parameter: - https://learn.microsoft.com/en-us/rest/api/cognitiveservices-bingsearch/bing-web-api-v7-reference#query-parameters - """ - - def __init__(self, connection: BingConnection): - super().__init__() - self.connection = connection - - @staticmethod - @deprecated(replace="Bing()") - def from_config(config: BingConnection): - return Bing(config) - - def calculate_cache_string_for_search(self, **kwargs): - return json.dumps(kwargs) - - def extract_error_message_and_code(self, error_json): - error_message = "" - if not error_json: - return error_message - error_code_reference = ( - "For more info, please refer to " - "https://learn.microsoft.com/en-us/bing/search-apis/" - "bing-web-search/reference/error-codes" - ) - if "message" in error_json: - # populate error_message with the top error items - error_message += (" " if len(error_message) > 0 else "") + error_json["message"] - if "code" in error_json: - # Append error code if existing - code_message = f"code: {error_json['code']}. {error_code_reference}" - error_message += (" " if len(error_message) > 0 else "") + code_message - return error_message - - def extract_error_message(self, error_data): - error_message = "" - # For request was rejected. For example, the api_key is not valid - if "error" in error_data: - error_message = self.extract_error_message_and_code(error_data["error"]) - - # For request accepted but bing responded with non-success code. - # For example, the parameters value is not valid - if "errors" in error_data and error_data["errors"] and len(error_data["errors"]) > 0: - # populate error_message with the top error items - error_message = self.extract_error_message_and_code(error_data["errors"][0]) - return error_message - - def safe_extract_error_message(self, response): - default_error_message = "Bing search request failed. Please check logs for details." - try: - error_data = response.json() - print(f"Response json: {json.dumps(error_data)}", file=sys.stderr) - error_message = self.extract_error_message(error_data) - error_message = error_message if len(error_message) > 0 else default_error_message - print(f"Extracted error message: {error_message}", file=sys.stderr) - return error_message - except Exception as e: - # Swallow any exception when extract detailed error message - print( - f"Unexpected exception occurs while extract error message " - f"from response: {type(e).__name__}: {str(e)}", - file=sys.stderr, - ) - return default_error_message - - @tool - @enable_cache(calculate_cache_string_for_search) - def search( - self, - query: str, - answerCount: int = None, - cc: str = None, # country code - count: int = 10, - freshness: str = None, - mkt: str = None, # market defined by - - offset: int = 0, - promote: list = [], - responseFilter: list = [], - safesearch: str = "Moderate", - setLang: str = "en", - textDecorations: bool = False, - textFormat: str = "Raw", - ): - import requests - - # there are others header parameters as well - headers = {"Ocp-Apim-Subscription-Key": str(self.connection.api_key)} - params = { - "q": query, - "answerCount": int(answerCount) if answerCount else None, - "cc": cc, - "count": int(count), - "freshness": freshness, - "mkt": mkt, - "offset": int(offset), - "promote": list(json.loads(promote)) if promote else [], - "responseFilter": list(json.loads(responseFilter)) if responseFilter else [], - "safesearch": safesearch, - "setLang": setLang, - "textDecorations": to_bool(textDecorations), - "textFormat": textFormat, - } - - try: - response = requests.get(self.connection.url, headers=headers, params=params) - if response.status_code == requests.codes.ok: - return response.json() - else: - # Handle status_code is not ok - # Step I: Try to get accurate error message at best - error_message = self.safe_extract_error_message(response) - - # Step II: Construct PromptflowException - if response.status_code >= 500: - raise SystemErrorException(message=error_message, target=ErrorTarget.TOOL) - else: - raise UserErrorException(message=error_message, target=ErrorTarget.TOOL) - except Exception as e: - if not isinstance(e, PromptflowException): - error_message = "Unexpected exception occurs. Please check logs for details." - print(f"Unexpected exception occurs: {type(e).__name__}: {str(e)}", file=sys.stderr) - raise SystemErrorException(message=error_message, target=ErrorTarget.TOOL) - raise - - -register_builtins(Bing) - - -@tool -def search( - connection: BingConnection, - query: str, - answerCount: int = None, - cc: str = None, # country code - count: int = 10, - freshness: str = None, - mkt: str = None, # market defined by - - offset: int = 0, - promote: list = [], - responseFilter: list = [], - safesearch: str = "Moderate", - setLang: str = "en", - textDecorations: bool = False, - textFormat: str = "Raw", -): - return Bing(connection).search( - query=query, - answerCount=answerCount, - cc=cc, - count=count, - freshness=freshness, - mkt=mkt, - offset=offset, - promote=promote, - responseFilter=responseFilter, - safesearch=safesearch, - setLang=setLang, - textDecorations=textDecorations, - textFormat=textFormat, - ) - - -register_builtin_method(search) diff --git a/src/promptflow-tools/tests/test_aoai.py b/src/promptflow-tools/tests/test_aoai.py index 28ac3e7935d..7bc82e21399 100644 --- a/src/promptflow-tools/tests/test_aoai.py +++ b/src/promptflow-tools/tests/test_aoai.py @@ -4,7 +4,7 @@ from promptflow.connections import AzureOpenAIConnection from promptflow.tools.aoai import chat, completion -from promptflow.utils.utils import AttrDict +from tests.utils import AttrDict @pytest.mark.usefixtures("use_secrets_config_file") diff --git a/src/promptflow-tools/tests/utils.py b/src/promptflow-tools/tests/utils.py index 9c66707cf8d..411cc4b8739 100644 --- a/src/promptflow-tools/tests/utils.py +++ b/src/promptflow-tools/tests/utils.py @@ -1,6 +1,16 @@ import json +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __getattr__(self, item): + if item in self: + return self.__getitem__(item) + return super().__getattribute__(item) + + def is_json_serializable(data, function_name): try: json.dumps(data)