Skip to content

Commit

Permalink
[tool] refine tool unit tests and remove useless codes (microsoft#49)
Browse files Browse the repository at this point in the history
refine tool unit tests and remove useless codes.
- instead of simply `skip`, which makes user need to comment the code
line to run tests, change to `skip_if_no_key` by checking if api key
provided.
- remove duplicated tests, provide better practice by adding assertation
and removing bad print.
- remove useless tool codes.
  • Loading branch information
chjinche authored Aug 2, 2023
1 parent dfd5402 commit c787b11
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 669 deletions.
17 changes: 2 additions & 15 deletions src/promptflow-tools/promptflow/tools/aoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from promptflow.core.tools_manager import register_api_method, register_apis
from promptflow.tools.common import render_jinja_template, handle_openai_error, parse_chat, to_bool, \
validate_functions, process_function_call, post_process_chat_api_response
from promptflow.utils.utils import deprecated


class AzureOpenAI(ToolProvider):
Expand All @@ -19,11 +18,6 @@ def __init__(self, connection: AzureOpenAIConnection):
self.connection = connection
self._connection_dict = asdict(self.connection)

@staticmethod
@deprecated(replace="AzureOpenAI()")
def from_config(config: AzureOpenAIConnection):
return AzureOpenAI(config)

def calculate_cache_string_for_completion(
self,
**kwargs,
Expand Down Expand Up @@ -62,7 +56,6 @@ def completion(
# TODO: remove below type conversion after client can pass json rather than string.
echo = to_bool(echo)
stream = to_bool(stream)

response = openai.Completion.create(
prompt=prompt,
engine=deployment_name,
Expand All @@ -89,7 +82,6 @@ def completion(
headers={"ms-azure-ai-promptflow-called-from": "aoai-tool"},
**self._connection_dict,
)

if stream:
def generator():
for chunk in response:
Expand Down Expand Up @@ -123,7 +115,7 @@ def chat(
function_call: str = None,
functions: list = None,
**kwargs,
) -> str:
) -> [str, dict]:
# keep_trailing_newline=True is to keep the last \n in the prompt to avoid converting "user:\t\n" to "user:".
chat_str = render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True, **kwargs)
messages = parse_chat(chat_str)
Expand Down Expand Up @@ -152,6 +144,7 @@ def chat(
completion = openai.ChatCompletion.create(**{**self._connection_dict, **params})
return post_process_chat_api_response(completion, stream, functions)

# TODO: embedding is a separate builtin tool, will remove it from llm.
@tool
@handle_openai_error()
def embedding(self, input, deployment_name: str, user: str = ""):
Expand Down Expand Up @@ -249,11 +242,5 @@ def chat(
)


@tool
def embedding(connection: AzureOpenAIConnection, input, deployment_name: str, user: str = ""):
return AzureOpenAI(connection).embedding(input=input, deployment_name=deployment_name, user=user)


register_api_method(completion)
register_api_method(chat)
register_api_method(embedding)
13 changes: 1 addition & 12 deletions src/promptflow-tools/promptflow/tools/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from promptflow.core.tools_manager import register_api_method, register_apis
from promptflow.tools.common import render_jinja_template, handle_openai_error, \
parse_chat, to_bool, validate_functions, process_function_call, post_process_chat_api_response
from promptflow.utils.utils import deprecated


class Engine(str, Enum):
Expand All @@ -29,11 +28,6 @@ def __init__(self, connection: OpenAIConnection):
self.connection = connection
self._connection_dict = asdict(self.connection)

@staticmethod
@deprecated(replace="OpenAI()")
def from_config(config: OpenAIConnection):
return OpenAI(config)

@tool
@handle_openai_error()
def completion(
Expand Down Expand Up @@ -143,6 +137,7 @@ def chat(
completion = openai.ChatCompletion.create(**{**self._connection_dict, **params})
return post_process_chat_api_response(completion, stream, functions)

# TODO: embedding is a separate builtin tool, will remove it from llm.
@tool
@handle_openai_error()
def embedding(self, input, model: str = "text-embedding-ada-002", user: str = ""):
Expand Down Expand Up @@ -238,11 +233,5 @@ def chat(
)


@tool
def embedding(connection: OpenAIConnection, input, model: str = "text-embedding-ada-002", user: str = ""):
return OpenAI(connection).embedding(input=input, model=model, user=user)


register_api_method(completion)
register_api_method(chat)
register_api_method(embedding)
75 changes: 2 additions & 73 deletions src/promptflow-tools/promptflow/tools/serpapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from promptflow.connections import SerpConnection
from promptflow.core.tools_manager import register_builtin_method, register_builtins
from promptflow.exceptions import PromptflowException
from promptflow.tools.common import to_bool
from promptflow.tools.exception import SerpAPIUserError, SerpAPISystemError
from promptflow.utils.utils import deprecated


class SafeMode(str, Enum):
Expand All @@ -28,11 +26,6 @@ def __init__(self, connection: SerpConnection):
super().__init__()
self.connection = connection

@staticmethod
@deprecated(replace="SerpAPI()")
def from_config(config: SerpConnection):
return SerpAPI(config)

def extract_error_message_from_json(self, error_data):
error_message = ""
# For request was rejected. For example, the api_key is not valid
Expand Down Expand Up @@ -66,44 +59,18 @@ def search(
self,
query: str, # this is required
location: str = None,
google_domain: str = "google.com",
gl: str = None,
hl: str = None,
lr: str = None,
tbs: str = None,
safe: SafeMode = SafeMode.OFF, # Not default to be SafeMode.OFF
nfpr: bool = False,
filter: str = None,
tbm: str = None,
start: int = 0,
num: int = 10,
ijn: int = 0,
engine: Engine = Engine.GOOGLE, # this is required
device: str = "desktop",
no_cache: bool = False,
asynch: bool = False,
output: str = "JSON",
):
from serpapi import SerpApiClient

# required parameters. https://serpapi.com/search-api.
params = {
"q": query,
"location": location,
"google_domain": google_domain,
"gl": gl,
"hl": hl,
"lr": lr,
"tbs": tbs,
"filter": filter,
"tbm": tbm,
"device": device,
"no_cache": to_bool(no_cache),
"async": to_bool(asynch),
"api_key": self.connection.api_key,
"output": output,
}

if isinstance(engine, Engine):
params["engine"] = engine.value
else:
Expand All @@ -117,31 +84,21 @@ def search(
else:
params["safeSearch"] = "Strict"

if to_bool(nfpr):
params["nfpr"] = True
if int(start) > 0:
params["start"] = int(start)
if int(num) > 0:
# to combine multiple engines togather, we use "num" as the parameter for such purpose
if params["engine"].lower() == "google":
params["num"] = int(num)
else:
params["count"] = int(num)
if int(ijn) > 0:
params["ijn"] = int(ijn)

search = SerpApiClient(params)

# get response
try:
response = search.get_response()
if response.status_code == requests.codes.ok:
if output.lower() == "json":
# Keep the same as SerpAPIClient.get_json()
return json.loads(response.text)
else:
# Keep the same as SerpAPIClient.get_html()
return response.text
# default output is json
return json.loads(response.text)
else:
# Step I: Try to get accurate error message at best
error_message = self.safe_extract_error_message(response)
Expand All @@ -168,44 +125,16 @@ def search(
connection: SerpConnection,
query: str, # this is required
location: str = None,
google_domain: str = "google.com",
gl: str = None,
hl: str = None,
lr: str = None,
tbs: str = None,
safe: SafeMode = SafeMode.OFF, # Not default to be SafeMode.OFF
nfpr: bool = False,
filter: str = None,
tbm: str = None,
start: int = 0,
num: int = 10,
ijn: int = 0,
engine: Engine = Engine.GOOGLE, # this is required
device: str = "desktop",
no_cache: bool = False,
asynch: bool = False,
output: str = "JSON",
):
return SerpAPI(connection).search(
query=query,
location=location,
google_domain=google_domain,
gl=gl,
hl=hl,
lr=lr,
tbs=tbs,
safe=safe,
nfpr=nfpr,
filter=filter,
tbm=tbm,
start=start,
num=num,
ijn=ijn,
engine=engine,
device=device,
no_cache=no_cache,
asynch=asynch,
output=output,
)


Expand Down
54 changes: 51 additions & 3 deletions src/promptflow-tools/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,55 @@
import pytest
from pytest_mock import MockerFixture # noqa: E402

PROMOTFLOW_ROOT = Path(__file__) / "../.."
from promptflow.core.connection_manager import ConnectionManager
from promptflow.tools.aoai import AzureOpenAI

PROMOTFLOW_ROOT = Path(__file__).absolute().parents[1]
CONNECTION_FILE = (PROMOTFLOW_ROOT / "connections.json").resolve().absolute().as_posix()
root_str = str(PROMOTFLOW_ROOT.resolve().absolute())
if root_str not in sys.path:
sys.path.insert(0, root_str)


PROMOTFLOW_ROOT = Path(__file__).absolute().parents[1]
# connection
@pytest.fixture(autouse=True)
def use_secrets_config_file(mocker: MockerFixture):
mocker.patch.dict(os.environ, {"PROMPTFLOW_CONNECTIONS": CONNECTION_FILE})


@pytest.fixture
def use_secrets_config_file(mocker: MockerFixture):
def azure_open_ai_connection():
return ConnectionManager().get("azure_open_ai_connection")


@pytest.fixture
def aoai_provider(azure_open_ai_connection) -> AzureOpenAI:
aoai_provider = AzureOpenAI(azure_open_ai_connection)
return aoai_provider


@pytest.fixture
def open_ai_connection():
return ConnectionManager().get("open_ai_connection")


@pytest.fixture
def serp_connection():
return ConnectionManager().get("serp_connection")


@pytest.fixture(autouse=True)
def skip_if_no_key(request, mocker):
mocker.patch.dict(os.environ, {"PROMPTFLOW_CONNECTIONS": CONNECTION_FILE})
if request.node.get_closest_marker('skip_if_no_key'):
conn_name = request.node.get_closest_marker('skip_if_no_key').args[0]
connection = request.getfixturevalue(conn_name)
# if dummy placeholder key, skip.
if "-api-key" in connection.api_key:
pytest.skip('skipped because no key')


# example prompts
@pytest.fixture
def example_prompt_template() -> str:
with open(PROMOTFLOW_ROOT / "tests/test_configs/prompt_templates/marketing_writer/prompt.jinja2") as f:
Expand All @@ -40,3 +74,17 @@ def example_prompt_template_with_function() -> str:
with open(PROMOTFLOW_ROOT / "tests/test_configs/prompt_templates/prompt_with_function.jinja2") as f:
prompt_template = f.read()
return prompt_template


# functions
@pytest.fixture
def functions():
return [
{
"name": "get_current_weather",
"parameters": {
"type": "object",
"properties": {},
},
}
]
3 changes: 3 additions & 0 deletions src/promptflow-tools/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
skip_if_no_key: skip the test if actual api key is not provided.
Loading

0 comments on commit c787b11

Please sign in to comment.