Skip to content

Commit

Permalink
Remove "promptflow.utils" and bing.py (microsoft#54)
Browse files Browse the repository at this point in the history
Co-authored-by: yalu4 <[email protected]>
  • Loading branch information
16oeahr and yalu4 authored Aug 3, 2023
1 parent 493cea8 commit 500c4c3
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 182 deletions.
3 changes: 2 additions & 1 deletion scripts/utils/generate_tool_meta_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from promptflow.core.tool import ToolProvider
from promptflow.contracts.tool import InputDefinition, Tool, ToolType, ValueType
from promptflow.exceptions import ErrorTarget, UserErrorException
from promptflow.utils.tool_utils import function_to_interface, get_inputs_for_prompt_template

from utils.tool_utils import function_to_interface, get_inputs_for_prompt_template


def asdict_without_none(obj):
Expand Down
94 changes: 94 additions & 0 deletions scripts/utils/tool_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import inspect
from enum import Enum, EnumMeta
from typing import Callable, Union, get_args, get_origin
from jinja2 import Environment, meta
from promptflow.contracts.tool import ConnectionType, InputDefinition, ValueType


def value_to_str(val):
if val is inspect.Parameter.empty:
# For empty case, default field will be skipped when dumping to json
return None
if val is None:
# Dump default: "" in json to avoid UI validation error
return ""
if isinstance(val, Enum):
return val.value
return str(val)


def resolve_annotation(anno) -> Union[str, list]:
"""Resolve the union annotation to type list."""
origin = get_origin(anno)
if origin != Union:
return anno
# Optional[Type] is Union[Type, NoneType], filter NoneType out
args = [arg for arg in get_args(anno) if arg != type(None)] # noqa: E721
return args[0] if len(args) == 1 else args


def param_to_definition(param) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
# Extract enum for enum class
if isinstance(value_type, EnumMeta):
enum = [str(option.value) for option in value_type]
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
typ = [t.__name__ for t in value_type]
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection


def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
connection_types = []
# Collect all inputs from class and func
if initialize_inputs:
if any(k for k in initialize_inputs if k in sign.parameters):
raise Exception(f'Duplicate inputs found from {f.__name__!r} and "__init__()"!')
all_inputs = {**initialize_inputs}
all_inputs.update(
{
k: v
for k, v in sign.parameters.items()
if k != "self" and v.kind != v.VAR_KEYWORD and v.kind != v.VAR_POSITIONAL # TODO: Handle these cases
}
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
outputs = {}
# Note: We don't have output definition now
# outputs = {"output": OutputDefinition("output", [ValueType.from_type(type(sign.return_annotation))], "", True)}
# if is_dataclass(sign.return_annotation):
# for f in fields(sign.return_annotation):
# outputs[f.name] = OutputDefinition(f.name, [ValueType.from_type(
# type(getattr(sign.return_annotation, f.name)))], "", False)
return input_defs, outputs, connection_types


def get_inputs_for_prompt_template(template_str):
"""Get all input variable names from a jinja2 template string."""
env = Environment()
template = env.parse(template_str)
return sorted(meta.find_undeclared_variables(template), key=lambda x: template_str.find(x))
1 change: 0 additions & 1 deletion src/promptflow-tools/promptflow/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
from .azure_language_detector import get_language # noqa: F401
from .azure_form_recognizer import AzureFormRecognizer # noqa: F401
from .azure_translator import get_translation, AzureTranslator # noqa: F401
from .bing import Bing # noqa: F401
from .openai import OpenAI # noqa: F401
from .serpapi import SerpAPI # noqa: F401
179 changes: 0 additions & 179 deletions src/promptflow-tools/promptflow/tools/bing.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/promptflow-tools/tests/test_aoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from promptflow.connections import AzureOpenAIConnection
from promptflow.tools.aoai import chat, completion
from promptflow.utils.utils import AttrDict
from tests.utils import AttrDict


@pytest.mark.usefixtures("use_secrets_config_file")
Expand Down
10 changes: 10 additions & 0 deletions src/promptflow-tools/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import json


class AttrDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __getattr__(self, item):
if item in self:
return self.__getitem__(item)
return super().__getattribute__(item)


def is_json_serializable(data, function_name):
try:
json.dumps(data)
Expand Down

0 comments on commit 500c4c3

Please sign in to comment.