From c787b11d82e9cbb25bcfa8c8a40dd63f38fc0bae Mon Sep 17 00:00:00 2001 From: chjinche <49483542+chjinche@users.noreply.github.com> Date: Wed, 2 Aug 2023 13:58:37 +0800 Subject: [PATCH] [tool] refine tool unit tests and remove unused code (#49) Refine the tool unit tests and remove unused code. - Instead of a plain `skip`, which forced users to comment out the marker line to run the tests, switch to `skip_if_no_key`, which skips automatically when no API key is provided. - Remove duplicated tests; demonstrate better practice by adding assertions and removing noisy prints. - Remove unused tool code. --- src/promptflow-tools/promptflow/tools/aoai.py | 17 +- .../promptflow/tools/openai.py | 13 +- .../promptflow/tools/serpapi.py | 75 +---- src/promptflow-tools/tests/conftest.py | 54 +++- src/promptflow-tools/tests/pytest.ini | 3 + src/promptflow-tools/tests/test_aoai.py | 61 +--- .../tests/test_azure_content_safety.py | 3 - .../tests/test_azure_form_recognizer.py | 3 - src/promptflow-tools/tests/test_embedding.py | 49 ++-- .../tests/test_handle_openai_error.py | 187 ++---------- src/promptflow-tools/tests/test_openai.py | 78 +---- src/promptflow-tools/tests/test_serpapi.py | 266 ++---------------- 12 files changed, 140 insertions(+), 669 deletions(-) create mode 100644 src/promptflow-tools/tests/pytest.ini diff --git a/src/promptflow-tools/promptflow/tools/aoai.py b/src/promptflow-tools/promptflow/tools/aoai.py index ffb223c44d2..1dc4edd2c51 100644 --- a/src/promptflow-tools/promptflow/tools/aoai.py +++ b/src/promptflow-tools/promptflow/tools/aoai.py @@ -10,7 +10,6 @@ from promptflow.core.tools_manager import register_api_method, register_apis from promptflow.tools.common import render_jinja_template, handle_openai_error, parse_chat, to_bool, \ validate_functions, process_function_call, post_process_chat_api_response -from promptflow.utils.utils import deprecated class AzureOpenAI(ToolProvider): @@ -19,11 +18,6 @@ def __init__(self, connection: AzureOpenAIConnection): self.connection = connection self._connection_dict = asdict(self.connection) - @staticmethod - @deprecated(replace="AzureOpenAI()") - def from_config(config: AzureOpenAIConnection): - return AzureOpenAI(config) - def calculate_cache_string_for_completion( self, **kwargs, @@ -62,7 +56,6 @@ def completion( # TODO: remove below type conversion after client can pass json rather than string. echo = to_bool(echo) stream = to_bool(stream) - response = openai.Completion.create( prompt=prompt, engine=deployment_name, @@ -89,7 +82,6 @@ def completion( headers={"ms-azure-ai-promptflow-called-from": "aoai-tool"}, **self._connection_dict, ) - if stream: def generator(): for chunk in response: @@ -123,7 +115,7 @@ def chat( function_call: str = None, functions: list = None, **kwargs, - ) -> str: + ) -> [str, dict]: # keep_trailing_newline=True is to keep the last \n in the prompt to avoid converting "user:\t\n" to "user:". chat_str = render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True, **kwargs) messages = parse_chat(chat_str) @@ -152,6 +144,7 @@ def chat( completion = openai.ChatCompletion.create(**{**self._connection_dict, **params}) return post_process_chat_api_response(completion, stream, functions) + # TODO: embedding is a separate builtin tool, will remove it from llm.
@tool @handle_openai_error() def embedding(self, input, deployment_name: str, user: str = ""): @@ -249,11 +242,5 @@ def chat( ) -@tool -def embedding(connection: AzureOpenAIConnection, input, deployment_name: str, user: str = ""): - return AzureOpenAI(connection).embedding(input=input, deployment_name=deployment_name, user=user) - - register_api_method(completion) register_api_method(chat) -register_api_method(embedding) diff --git a/src/promptflow-tools/promptflow/tools/openai.py b/src/promptflow-tools/promptflow/tools/openai.py index 438288089aa..8809827fe35 100644 --- a/src/promptflow-tools/promptflow/tools/openai.py +++ b/src/promptflow-tools/promptflow/tools/openai.py @@ -9,7 +9,6 @@ from promptflow.core.tools_manager import register_api_method, register_apis from promptflow.tools.common import render_jinja_template, handle_openai_error, \ parse_chat, to_bool, validate_functions, process_function_call, post_process_chat_api_response -from promptflow.utils.utils import deprecated class Engine(str, Enum): @@ -29,11 +28,6 @@ def __init__(self, connection: OpenAIConnection): self.connection = connection self._connection_dict = asdict(self.connection) - @staticmethod - @deprecated(replace="OpenAI()") - def from_config(config: OpenAIConnection): - return OpenAI(config) - @tool @handle_openai_error() def completion( @@ -143,6 +137,7 @@ def chat( completion = openai.ChatCompletion.create(**{**self._connection_dict, **params}) return post_process_chat_api_response(completion, stream, functions) + # TODO: embedding is a separate builtin tool, will remove it from llm. @tool @handle_openai_error() def embedding(self, input, model: str = "text-embedding-ada-002", user: str = ""): @@ -238,11 +233,5 @@ def chat( ) -@tool -def embedding(connection: OpenAIConnection, input, model: str = "text-embedding-ada-002", user: str = ""): - return OpenAI(connection).embedding(input=input, model=model, user=user) - - register_api_method(completion) register_api_method(chat) -register_api_method(embedding) diff --git a/src/promptflow-tools/promptflow/tools/serpapi.py b/src/promptflow-tools/promptflow/tools/serpapi.py index d1b4257add6..448b040b37a 100644 --- a/src/promptflow-tools/promptflow/tools/serpapi.py +++ b/src/promptflow-tools/promptflow/tools/serpapi.py @@ -8,9 +8,7 @@ from promptflow.connections import SerpConnection from promptflow.core.tools_manager import register_builtin_method, register_builtins from promptflow.exceptions import PromptflowException -from promptflow.tools.common import to_bool from promptflow.tools.exception import SerpAPIUserError, SerpAPISystemError -from promptflow.utils.utils import deprecated class SafeMode(str, Enum): @@ -28,11 +26,6 @@ def __init__(self, connection: SerpConnection): super().__init__() self.connection = connection - @staticmethod - @deprecated(replace="SerpAPI()") - def from_config(config: SerpConnection): - return SerpAPI(config) - def extract_error_message_from_json(self, error_data): error_message = "" # For request was rejected. 
For example, the api_key is not valid @@ -66,23 +59,9 @@ def search( self, query: str, # this is required location: str = None, - google_domain: str = "google.com", - gl: str = None, - hl: str = None, - lr: str = None, - tbs: str = None, safe: SafeMode = SafeMode.OFF, # Not default to be SafeMode.OFF - nfpr: bool = False, - filter: str = None, - tbm: str = None, - start: int = 0, num: int = 10, - ijn: int = 0, engine: Engine = Engine.GOOGLE, # this is required - device: str = "desktop", - no_cache: bool = False, - asynch: bool = False, - output: str = "JSON", ): from serpapi import SerpApiClient @@ -90,20 +69,8 @@ def search( params = { "q": query, "location": location, - "google_domain": google_domain, - "gl": gl, - "hl": hl, - "lr": lr, - "tbs": tbs, - "filter": filter, - "tbm": tbm, - "device": device, - "no_cache": to_bool(no_cache), - "async": to_bool(asynch), "api_key": self.connection.api_key, - "output": output, } - if isinstance(engine, Engine): params["engine"] = engine.value else: @@ -117,18 +84,12 @@ def search( else: params["safeSearch"] = "Strict" - if to_bool(nfpr): - params["nfpr"] = True - if int(start) > 0: - params["start"] = int(start) if int(num) > 0: # to combine multiple engines togather, we use "num" as the parameter for such purpose if params["engine"].lower() == "google": params["num"] = int(num) else: params["count"] = int(num) - if int(ijn) > 0: - params["ijn"] = int(ijn) search = SerpApiClient(params) @@ -136,12 +97,8 @@ def search( try: response = search.get_response() if response.status_code == requests.codes.ok: - if output.lower() == "json": - # Keep the same as SerpAPIClient.get_json() - return json.loads(response.text) - else: - # Keep the same as SerpAPIClient.get_html() - return response.text + # default output is json + return json.loads(response.text) else: # Step I: Try to get accurate error message at best error_message = self.safe_extract_error_message(response) @@ -168,44 +125,16 @@ def search( connection: SerpConnection, query: str, # this is required location: str = None, - google_domain: str = "google.com", - gl: str = None, - hl: str = None, - lr: str = None, - tbs: str = None, safe: SafeMode = SafeMode.OFF, # Not default to be SafeMode.OFF - nfpr: bool = False, - filter: str = None, - tbm: str = None, - start: int = 0, num: int = 10, - ijn: int = 0, engine: Engine = Engine.GOOGLE, # this is required - device: str = "desktop", - no_cache: bool = False, - asynch: bool = False, - output: str = "JSON", ): return SerpAPI(connection).search( query=query, location=location, - google_domain=google_domain, - gl=gl, - hl=hl, - lr=lr, - tbs=tbs, safe=safe, - nfpr=nfpr, - filter=filter, - tbm=tbm, - start=start, num=num, - ijn=ijn, engine=engine, - device=device, - no_cache=no_cache, - asynch=asynch, - output=output, ) diff --git a/src/promptflow-tools/tests/conftest.py b/src/promptflow-tools/tests/conftest.py index 8db4eb6cb69..bacd601e73c 100644 --- a/src/promptflow-tools/tests/conftest.py +++ b/src/promptflow-tools/tests/conftest.py @@ -6,21 +6,55 @@ import pytest from pytest_mock import MockerFixture # noqa: E402 -PROMOTFLOW_ROOT = Path(__file__) / "../.." 
+from promptflow.core.connection_manager import ConnectionManager +from promptflow.tools.aoai import AzureOpenAI + +PROMOTFLOW_ROOT = Path(__file__).absolute().parents[1] CONNECTION_FILE = (PROMOTFLOW_ROOT / "connections.json").resolve().absolute().as_posix() root_str = str(PROMOTFLOW_ROOT.resolve().absolute()) if root_str not in sys.path: sys.path.insert(0, root_str) -PROMOTFLOW_ROOT = Path(__file__).absolute().parents[1] +# connection +@pytest.fixture(autouse=True) +def use_secrets_config_file(mocker: MockerFixture): + mocker.patch.dict(os.environ, {"PROMPTFLOW_CONNECTIONS": CONNECTION_FILE}) @pytest.fixture -def use_secrets_config_file(mocker: MockerFixture): +def azure_open_ai_connection(): + return ConnectionManager().get("azure_open_ai_connection") + + +@pytest.fixture +def aoai_provider(azure_open_ai_connection) -> AzureOpenAI: + aoai_provider = AzureOpenAI(azure_open_ai_connection) + return aoai_provider + + +@pytest.fixture +def open_ai_connection(): + return ConnectionManager().get("open_ai_connection") + + +@pytest.fixture +def serp_connection(): + return ConnectionManager().get("serp_connection") + + +@pytest.fixture(autouse=True) +def skip_if_no_key(request, mocker): mocker.patch.dict(os.environ, {"PROMPTFLOW_CONNECTIONS": CONNECTION_FILE}) + if request.node.get_closest_marker('skip_if_no_key'): + conn_name = request.node.get_closest_marker('skip_if_no_key').args[0] + connection = request.getfixturevalue(conn_name) + # if dummy placeholder key, skip. + if "-api-key" in connection.api_key: + pytest.skip('skipped because no key') +# example prompts @pytest.fixture def example_prompt_template() -> str: with open(PROMOTFLOW_ROOT / "tests/test_configs/prompt_templates/marketing_writer/prompt.jinja2") as f: @@ -40,3 +74,17 @@ def example_prompt_template_with_function() -> str: with open(PROMOTFLOW_ROOT / "tests/test_configs/prompt_templates/prompt_with_function.jinja2") as f: prompt_template = f.read() return prompt_template + + +# functions +@pytest.fixture +def functions(): + return [ + { + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": {}, + }, + } + ] diff --git a/src/promptflow-tools/tests/pytest.ini b/src/promptflow-tools/tests/pytest.ini new file mode 100644 index 00000000000..7c5572c702d --- /dev/null +++ b/src/promptflow-tools/tests/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + skip_if_no_key: skip the test if actual api key is not provided. 
\ No newline at end of file diff --git a/src/promptflow-tools/tests/test_aoai.py b/src/promptflow-tools/tests/test_aoai.py index 77e8c163071..28ac3e7935d 100644 --- a/src/promptflow-tools/tests/test_aoai.py +++ b/src/promptflow-tools/tests/test_aoai.py @@ -3,32 +3,19 @@ import pytest from promptflow.connections import AzureOpenAIConnection -from promptflow.core.connection_manager import ConnectionManager -from promptflow.tools.aoai import AzureOpenAI, chat, completion, embedding +from promptflow.tools.aoai import chat, completion from promptflow.utils.utils import AttrDict -@pytest.fixture -def azure_open_ai_connection() -> AzureOpenAIConnection: - return ConnectionManager().get("azure_open_ai_connection") - - -@pytest.fixture -def aoai_provider(azure_open_ai_connection) -> AzureOpenAI: - aoai_provider = AzureOpenAI(azure_open_ai_connection) - return aoai_provider - - -@pytest.mark.usefixtures("use_secrets_config_file", "aoai_provider", "azure_open_ai_connection") +@pytest.mark.usefixtures("use_secrets_config_file") class TestAOAI: def test_aoai_completion(self, aoai_provider): prompt_template = "please complete this sentence: world war II " # test whether tool can handle param "stop" with value empty list # as openai raises "[] is not valid under any of the given schemas - 'stop'" - result = aoai_provider.completion( + aoai_provider.completion( prompt=prompt_template, deployment_name="text-ada-001", stop=[], logit_bias={} ) - print("aoai.completion() result=[" + result + "]") def test_aoai_chat(self, aoai_provider, example_prompt_template, chat_history): result = aoai_provider.chat( @@ -39,7 +26,6 @@ def test_aoai_chat(self, aoai_provider, example_prompt_template, chat_history): user_input="Fill in more detalis about trend 2.", chat_history=chat_history, ) - print("aoai.chat() result=[" + result + "]") assert "details about trend 2" in result.lower() def test_aoai_chat_api(self, azure_open_ai_connection, example_prompt_template, chat_history): @@ -52,19 +38,10 @@ def test_aoai_chat_api(self, azure_open_ai_connection, example_prompt_template, user_input="Write a slogan for product X", chat_history=chat_history, ) - print(" chat() api result=[" + result + "]") assert "Product X".lower() in result.lower() - functions = [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": {}, - }, - } - ] - + def test_aoai_chat_with_function( + self, azure_open_ai_connection, example_prompt_template, chat_history, functions): result = chat( connection=azure_open_ai_connection, prompt=example_prompt_template, @@ -76,8 +53,8 @@ def test_aoai_chat_api(self, azure_open_ai_connection, example_prompt_template, functions=functions, function_call="auto" ) - result = str(result.to_dict()) - print(" chat() api result=[" + result + "]") + assert "function_call" in result + assert result["function_call"]["name"] == "get_current_weather" def test_aoai_chat_message_with_no_content(self, aoai_provider): # missing colon after role name. Sometimes following prompt may result in empty content. @@ -88,34 +65,10 @@ def test_aoai_chat_message_with_no_content(self, aoai_provider): ) # assert chat tool can handle. 
aoai_provider.chat(prompt=prompt, deployment_name="gpt-35-turbo") - # empty content after role name:\n prompt = "user:\n" aoai_provider.chat(prompt=prompt, deployment_name="gpt-35-turbo") - def test_aoai_embedding(self, aoai_provider): - input = "The food was delicious and the waiter" - result = aoai_provider.embedding(input=input, deployment_name="text-embedding-ada-002") - embedding_vector = ", ".join(str(num) for num in result) - print("aoai.embedding() result=[" + embedding_vector + "]") - - def test_aoai_embedding_api(self, azure_open_ai_connection): - input = ["The food was delicious and the waiter"] # we could use array as well, vs str - result = embedding(azure_open_ai_connection, input=input, deployment_name="text-embedding-ada-002") - embedding_vector = ", ".join(str(num) for num in result) - print("embedding() api result=[" + embedding_vector + "]") - - def test_chat_no_exception_even_no_message_content(self, aoai_provider): - # This is to confirm no exception even if no message content; For more details, please find - # https://msdata.visualstudio.com/Vienna/_workitems/edit/2377116 - prompt = ( - "user:\n what is your name\nassistant\nAs an AI language model developed by OpenAI, " - "I do not have a name. You can call me OpenAI or AI assistant. How can I assist you today?" - ) - - result = aoai_provider.chat(prompt=prompt, deployment_name="gpt-35-turbo") - print("test_chat_no_exception_even_no_message_content result=[" + result + "]") - @pytest.mark.parametrize( "params, expected", [ diff --git a/src/promptflow-tools/tests/test_azure_content_safety.py b/src/promptflow-tools/tests/test_azure_content_safety.py index 6e0ffbd9f24..b7282a71d22 100644 --- a/src/promptflow-tools/tests/test_azure_content_safety.py +++ b/src/promptflow-tools/tests/test_azure_content_safety.py @@ -1,6 +1,5 @@ import json import unittest -from pathlib import Path import pytest @@ -10,8 +9,6 @@ import tests.utils as utils -PROMOTFLOW_ROOT = Path(__file__) / "../../../../" - @pytest.fixture def content_safety_config() -> AzureContentSafetyConnection: diff --git a/src/promptflow-tools/tests/test_azure_form_recognizer.py b/src/promptflow-tools/tests/test_azure_form_recognizer.py index ae72ec98b45..6232675b48f 100644 --- a/src/promptflow-tools/tests/test_azure_form_recognizer.py +++ b/src/promptflow-tools/tests/test_azure_form_recognizer.py @@ -1,6 +1,5 @@ import json import unittest -from pathlib import Path import pytest @@ -10,8 +9,6 @@ import tests.utils as utils -PROMOTFLOW_ROOT = Path(__file__) / "../../../../" - @pytest.fixture def form_recognizer_connection() -> CustomConnection: diff --git a/src/promptflow-tools/tests/test_embedding.py b/src/promptflow-tools/tests/test_embedding.py index 917b2600ffd..bc85bb5c2c7 100644 --- a/src/promptflow-tools/tests/test_embedding.py +++ b/src/promptflow-tools/tests/test_embedding.py @@ -1,32 +1,29 @@ import pytest -from promptflow.connections import AzureOpenAIConnection, OpenAIConnection -from promptflow.core.connection_manager import ConnectionManager +from promptflow.exceptions import ErrorResponse from promptflow.tools.embedding import embedding +from promptflow.tools.exception import InvalidConnectionType -@pytest.fixture -def azure_open_ai_connection() -> [AzureOpenAIConnection]: - return ConnectionManager().get("azure_open_ai_connection") - - -@pytest.fixture -def open_ai_connection() -> [OpenAIConnection]: - return ConnectionManager().get("open_ai_connection") - - -@pytest.mark.usefixtures("use_secrets_config_file", "azure_open_ai_connection", - 
"open_ai_connection") +@pytest.mark.usefixtures("use_secrets_config_file") class TestEmbedding: - def test_aoai_embedding_api(self, azure_open_ai_connection): - input = ["The food was delicious and the waiter"] # we could use array as well, vs str - result = embedding(azure_open_ai_connection, input=input, deployment_name="text-embedding-ada-002") - embedding_vector = ", ".join(str(num) for num in result) - print("embedding() api result=[" + embedding_vector + "]") - - @pytest.mark.skip(reason="openai key not set yet") - def test_openai_embedding_api(self, open_ai_connection): - input = ["The food was delicious and the waiter"] # we could use array as well, vs str - result = embedding(open_ai_connection, input=input, model="text-embedding-ada-002") - embedding_vector = ", ".join(str(num) for num in result) - print("embedding() api result=[" + embedding_vector + "]") + def test_embedding_conn_aoai(self, azure_open_ai_connection): + result = embedding( + connection=azure_open_ai_connection, + input="The food was delicious and the waiter", + deployment_name="text-embedding-ada-002") + assert len(result) == 1536 + + @pytest.mark.skip_if_no_key("open_ai_connection") + def test_embedding_conn_oai(self, open_ai_connection): + result = embedding( + connection=open_ai_connection, + input="The food was delicious and the waiter", + model="text-embedding-ada-002") + assert len(result) == 1536 + + def test_embedding_invalid_connection_type(self, serp_connection): + with pytest.raises(InvalidConnectionType) as exc_info: + embedding(connection=serp_connection, input="hello", deployment_name="text-embedding-ada-002") + assert "UserError/ToolValidationError/InvalidConnectionType" == ErrorResponse.from_exception( + exc_info.value).error_code_hierarchy diff --git a/src/promptflow-tools/tests/test_handle_openai_error.py b/src/promptflow-tools/tests/test_handle_openai_error.py index 069ed56d6b7..01277f2b3ac 100644 --- a/src/promptflow-tools/tests/test_handle_openai_error.py +++ b/src/promptflow-tools/tests/test_handle_openai_error.py @@ -12,39 +12,16 @@ ) from pytest_mock import MockerFixture -from promptflow.connections import AzureOpenAIConnection, SerpConnection, OpenAIConnection -from promptflow.core.connection_manager import ConnectionManager from promptflow.exceptions import UserErrorException, ErrorResponse -from promptflow.tools.aoai import AzureOpenAI, chat, completion, embedding +from promptflow.tools.aoai import chat, completion + from promptflow.tools.common import handle_openai_error -from promptflow.tools.embedding import embedding as Embedding from promptflow.tools.exception import ChatAPIInvalidRole, WrappedOpenAIError, openai_error_code_ref_message, \ - to_openai_error_message, JinjaTemplateError, LLMError, InvalidConnectionType, ChatAPIFunctionRoleInvalidFormat + to_openai_error_message, JinjaTemplateError, LLMError, ChatAPIFunctionRoleInvalidFormat from promptflow.tools.openai import chat as openai_chat -@pytest.fixture -def azure_open_ai_connection() -> AzureOpenAIConnection: - return ConnectionManager().get("azure_open_ai_connection") - - -@pytest.fixture -def open_ai_connection() -> OpenAIConnection: - return ConnectionManager().get("open_ai_connection") - - -@pytest.fixture -def serp_connection() -> SerpConnection: - return ConnectionManager().get("serp_connection") - - -@pytest.fixture -def aoai_provider(azure_open_ai_connection) -> AzureOpenAI: - aoai_provider = AzureOpenAI.from_config(azure_open_ai_connection) - return aoai_provider - - 
-@pytest.mark.usefixtures("use_secrets_config_file", "aoai_provider", "azure_open_ai_connection", "serp_connection") +@pytest.mark.usefixtures("use_secrets_config_file") class TestHandleOpenAIError: def test_aoai_chat_message_invalid_format(self, aoai_provider): # chat api prompt should follow the format of "system:\nmessage1\nuser:\nmessage2". @@ -55,17 +32,9 @@ def test_aoai_chat_message_invalid_format(self, aoai_provider): assert "UserError/ToolValidationError/ChatAPIInvalidRole" == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - def test_aoai_authencation_error_with_api_key(self, azure_open_ai_connection): - """ - bad api key - - API Error Message - completion,embedding,chat AuthenticationError The same - """ - + def test_aoai_authencation_error_with_bad_api_key(self, azure_open_ai_connection): azure_open_ai_connection.api_key = "hello" prompt_template = "please complete this sentence: world war II " - raw_message = ( "Access denied due to invalid subscription key or wrong API endpoint. " "Make sure to provide a valid key for an active subscription and use a " @@ -73,83 +42,35 @@ def test_aoai_authencation_error_with_api_key(self, azure_open_ai_connection): ) error_msg = to_openai_error_message(AuthenticationError(message=raw_message)) error_code = "UserError/OpenAIError/AuthenticationError" - - with pytest.raises(WrappedOpenAIError) as exc_info: - completion(azure_open_ai_connection, prompt=prompt_template, deployment_name="text-ada-001") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: chat(azure_open_ai_connection, prompt=f"user:\n{prompt_template}", deployment_name="gpt-35-turbo") assert error_msg == exc_info.value.message assert error_code == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: - embedding(azure_open_ai_connection, input=prompt_template, deployment_name="text-embedding-ada-002") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - - def test_aoai_authencation_error_with_api_base(self, azure_open_ai_connection): + def test_aoai_connection_error_with_bad_api_base(self, azure_open_ai_connection): """ - bad api base (endpoint) - - API Error Message - completion,embedding, chat APIConnectionError The same - APIConnectionError: Error communicating with OpenAI: HTTPSConnectionPool(host='gpt-test-eus11.openai.azure.com' , port=443): Max retries exceeded with url: //openai/deployments/text-ada-001/completions? 
api-version=2022-12-01 (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 11001] getaddrinfo failed')) """ - azure_open_ai_connection.api_base = "https://gpt-test-eus11.openai.azure.com/" prompt_template = "please complete this sentence: world war II " error_code = "UserError/OpenAIError/APIConnectionError" - - with pytest.raises(WrappedOpenAIError) as exc_info: - completion(azure_open_ai_connection, prompt=prompt_template, deployment_name="text-ada-001") - assert openai_error_code_ref_message in exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: chat(azure_open_ai_connection, prompt=f"user:\n{prompt_template}", deployment_name="gpt-35-turbo") assert openai_error_code_ref_message in exc_info.value.message assert error_code == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: - embedding(azure_open_ai_connection, input=prompt_template, deployment_name="text-embedding-ada-002") - assert openai_error_code_ref_message in exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - - def test_aoai_invalid_request_error_with_api_version(self, azure_open_ai_connection): - """ - bad api version - - API Error Message - completion,embedding, chat InvalidRequestError the same - - InvalidRequestError: Resource not found - """ - + def test_aoai_invalid_request_error_with_bad_api_version(self, azure_open_ai_connection): + """InvalidRequestError: Resource not found""" azure_open_ai_connection.api_version = "2022-12-23" prompt_template = "please complete this sentence: world war II " - raw_message = "Resource not found" error_msg = to_openai_error_message(InvalidRequestError(message=raw_message, param=None)) error_code = "UserError/OpenAIError/InvalidRequestError" - - with pytest.raises(WrappedOpenAIError) as exc_info: - completion(azure_open_ai_connection, prompt=prompt_template, deployment_name="text-ada-001") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - # Chat will throw: Exception occurs: InvalidRequestError: Resource not found with pytest.raises(WrappedOpenAIError) as exc_info: chat(azure_open_ai_connection, prompt=f"user:\n{prompt_template}", deployment_name="gpt-35-turbo") @@ -157,64 +78,30 @@ def test_aoai_invalid_request_error_with_api_version(self, azure_open_ai_connect assert error_code == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - # Embedding throw: Exception occurs: InvalidRequestError: The API deployment for this resource does not exist. - # If you created the deployment within the last 5 minutes, please wait a moment and try again. - with pytest.raises(WrappedOpenAIError) as exc_info: - embedding(azure_open_ai_connection, input=prompt_template, deployment_name="text-embedding-ada-002") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - - def test_aoai_invalid_request_error_api_type(self, azure_open_ai_connection): + def test_aoai_invalid_request_error_with_bad_api_type(self, azure_open_ai_connection): """ - bad api type - - API Error Message - completion,embedding, chat InvalidAPIType the same - InvalidAPIType: The API type provided in invalid. 
Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai' """ - azure_open_ai_connection.api_type = "aml" prompt_template = "please complete this sentence: world war II " - raw_message = ( "The API type provided in invalid. Please select one of the supported API types: " "'azure', 'azure_ad', 'open_ai'" ) error_msg = to_openai_error_message(InvalidAPIType(message=raw_message)) error_code = "UserError/OpenAIError/InvalidAPIType" - - with pytest.raises(WrappedOpenAIError) as exc_info: - completion(azure_open_ai_connection, prompt=prompt_template, deployment_name="text-ada-001") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: chat(azure_open_ai_connection, prompt=f"user:\n{prompt_template}", deployment_name="gpt-35-turbo") assert error_msg == exc_info.value.message assert error_code == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: - embedding(azure_open_ai_connection, input=prompt_template, deployment_name="text-embedding-ada-002") - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - - def test_aoai_invalid_request_error_with_deployment(self, aoai_provider, azure_open_ai_connection): + def test_aoai_invalid_request_error_with_bad_deployment(self, aoai_provider): """ - bad model/deployment - - API Error Message - completion,embedding,chat InvalidRequestError The same - InvalidRequestError: The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again. """ - # This will throw InvalidRequestError prompt_template = "please complete this sentence: world war II " deployment = "hello" @@ -224,25 +111,12 @@ def test_aoai_invalid_request_error_with_deployment(self, aoai_provider, azure_o ) error_msg = to_openai_error_message(InvalidRequestError(message=raw_message, param=None)) error_code = "UserError/OpenAIError/InvalidRequestError" - - with pytest.raises(WrappedOpenAIError) as exc_info: - aoai_provider.completion(prompt=prompt_template, deployment_name=deployment) - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: aoai_provider.chat(prompt=f"user:\n{prompt_template}", deployment_name=deployment) assert error_msg == exc_info.value.message assert error_code == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - with pytest.raises(WrappedOpenAIError) as exc_info: - embedding(azure_open_ai_connection, input=prompt_template, deployment_name=deployment) - assert error_msg == exc_info.value.message - assert error_code == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - def test_rate_limit_error_insufficient_quota(self, azure_open_ai_connection, mocker: MockerFixture): dummyEx = RateLimitError("Something went wrong", json_body={"error": {"type": "insufficient_quota"}}) mock_method = mocker.patch("promptflow.tools.aoai.openai.Completion.create", side_effect=dummyEx) @@ -279,7 +153,7 @@ def test_non_retriable_connection_error(self, azure_open_ai_connection, mocker: ), ], ) - def test_retriable_openai_error_handle(self, azure_open_ai_connection, mocker: MockerFixture, dummyExceptionList): + def test_retriable_openai_error_handle(self, 
mocker: MockerFixture, dummyExceptionList): for dummyEx in dummyExceptionList: # Patch the test_method to throw the desired exception patched_test_method = mocker.patch("promptflow.tools.aoai.completion", side_effect=dummyEx) @@ -319,7 +193,7 @@ def test_retriable_openai_error_handle(self, azure_open_ai_connection, mocker: M ], ) def test_retriable_openai_error_handle_with_header( - self, azure_open_ai_connection, mocker: MockerFixture, dummyExceptionList + self, mocker: MockerFixture, dummyExceptionList ): for dummyEx in dummyExceptionList: # Patch the test_method to throw the desired exception @@ -372,9 +246,9 @@ def test_non_retriable_openai_error_handle( def test_unexpected_error_handle(self, azure_open_ai_connection, mocker: MockerFixture): dummyEx = Exception("Something went wrong") - mock_method = mocker.patch("promptflow.tools.aoai.openai.Completion.create", side_effect=dummyEx) + mock_method = mocker.patch("promptflow.tools.aoai.openai.ChatCompletion.create", side_effect=dummyEx) with pytest.raises(LLMError) as exc_info: - completion(connection=azure_open_ai_connection, prompt="hello", deployment_name="text-ada-001") + chat(connection=azure_open_ai_connection, prompt="user:\nhello", deployment_name="gpt-35-turbo") assert to_openai_error_message(dummyEx) != exc_info.value.args[0] assert "OpenAI API hits exception: Exception: Something went wrong" == exc_info.value.message assert mock_method.call_count == 1 @@ -385,7 +259,7 @@ def test_template_syntax_error_handle(self, azure_open_ai_connection, mocker: Mo dummyEx = TemplateSyntaxError(message="Something went wrong", lineno=1) mock_method = mocker.patch("jinja2.Template.__new__", side_effect=dummyEx) with pytest.raises(JinjaTemplateError) as exc_info: - completion(connection=azure_open_ai_connection, prompt="hello", deployment_name="text-ada-001") + chat(connection=azure_open_ai_connection, prompt="user:\nhello", deployment_name="gpt-35-turbo") error_message = "Failed to render jinja template: TemplateSyntaxError: Something went wrong\n line 1. " \ + "Please modify your prompt to fix the issue." 
assert error_message == exc_info.value.message @@ -393,38 +267,23 @@ def test_template_syntax_error_handle(self, azure_open_ai_connection, mocker: Mo assert "UserError/ToolValidationError/JinjaTemplateError" == ErrorResponse.from_exception( exc_info.value).error_code_hierarchy - def test_invalid_connection_type(self, serp_connection): - with pytest.raises(InvalidConnectionType) as exc_info: - Embedding(connection=serp_connection, input="hello", deployment_name="text-embedding-ada-002") - assert "UserError/ToolValidationError/InvalidConnectionType" == ErrorResponse.from_exception( - exc_info.value).error_code_hierarchy - - @pytest.mark.skip(reason="openai key not set yet") - def test_model_not_accept_functions_as_param(self, open_ai_connection, example_prompt_template): + @pytest.mark.skip_if_no_key("open_ai_connection") + def test_model_not_accept_functions_as_param( + self, open_ai_connection, example_prompt_template, functions): with pytest.raises(WrappedOpenAIError) as exc_info: openai_chat( connection=open_ai_connection, prompt=example_prompt_template, model="gpt-3.5-turbo-0301", - temperature=0, - functions=[ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": {}, - }, - } - ] + functions=functions ) assert "Current model does not support the `functions` parameter" in exc_info.value.message - @pytest.mark.skip(reason="openai key not set yet") - def test_input_invalid_function_role_prompt(self, open_ai_connection): + def test_input_invalid_function_role_prompt(self, azure_open_ai_connection): with pytest.raises(ChatAPIFunctionRoleInvalidFormat) as exc_info: - openai_chat( - connection=open_ai_connection, + chat( + connection=azure_open_ai_connection, prompt="function:\n This is function role prompt", - model="gpt-3.5-turbo" + deployment_name="gpt-35-turbo" ) assert "'name' is required if role is function," in exc_info.value.message diff --git a/src/promptflow-tools/tests/test_openai.py b/src/promptflow-tools/tests/test_openai.py index c921943ff14..5cff3b914fd 100644 --- a/src/promptflow-tools/tests/test_openai.py +++ b/src/promptflow-tools/tests/test_openai.py @@ -1,41 +1,24 @@ import pytest import json -from promptflow.connections import OpenAIConnection -from promptflow.core.connection_manager import ConnectionManager -from promptflow.tools.openai import embedding, chat, completion, OpenAI - - -@pytest.fixture -def open_ai_connection() -> [OpenAIConnection]: - return ConnectionManager().get("open_ai_connection") +from promptflow.tools.openai import chat, completion, OpenAI @pytest.fixture def openai_provider(open_ai_connection) -> OpenAI: - aoai_provider = OpenAI.from_config(open_ai_connection) - return aoai_provider + return OpenAI(open_ai_connection) -@pytest.mark.usefixtures("use_secrets_config_file", - "open_ai_connection") -@pytest.mark.skip(reason="openai key not set yet") +@pytest.mark.usefixtures("use_secrets_config_file") +@pytest.mark.skip_if_no_key("open_ai_connection") class TestOpenAI: - def test_openai_embedding_api(self, open_ai_connection): - input = ["The food was delicious and the waiter"] # we could use array as well, vs str - result = embedding(open_ai_connection, input=input, model="text-embedding-ada-002") - embedding_vector = ", ".join(str(num) for num in result) - print("embedding() api result=[" + embedding_vector + "]") - def test_openai_completion(self, openai_provider): prompt_template = "please complete this sentence: world war II " - result = openai_provider.completion(prompt=prompt_template) - 
print("openai.completion() result=[" + result + "]") + openai_provider.completion(prompt=prompt_template) def test_openai_completion_api(self, open_ai_connection): prompt_template = "please complete this sentence: world war II " - result = completion(open_ai_connection, prompt=prompt_template) - print("completion() api result=[" + result + "]") + completion(open_ai_connection, prompt=prompt_template) def test_openai_chat(self, openai_provider, example_prompt_template, chat_history): result = openai_provider.chat( @@ -46,7 +29,7 @@ def test_openai_chat(self, openai_provider, example_prompt_template, chat_histor user_input="Fill in more detalis about trend 2.", chat_history=chat_history, ) - print("openai.chat() result=[" + result + "]") + assert "details about trend 2" in result.lower() def test_openai_chat_api(self, open_ai_connection, example_prompt_template, chat_history): result = chat( @@ -58,57 +41,20 @@ def test_openai_chat_api(self, open_ai_connection, example_prompt_template, chat user_input="Write a slogan for product X", chat_history=chat_history, ) - print("chat() api result=[" + result + "]") - - functions = [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": {}, - }, - } - ] - - result = chat( - connection=open_ai_connection, - prompt=example_prompt_template, - model="gpt-3.5-turbo", - max_tokens="inF", - temperature=0, - user_input="What is the weather in Boston?", - chat_history=chat_history, - function_call="auto", - functions=functions - ) - result = str(result.to_dict()) - print("chat() api result=[" + result + "]") - - def test_openai_embedding(self, openai_provider): - input = "The food was delicious and the waiter" - result = openai_provider.embedding(input=input) - embedding_vector = ", ".join(str(num) for num in result) - print("openai.embedding() result=[" + embedding_vector + "]") + assert "Product X".lower() in result.lower() - def test_openai_prompt_with_function(self, open_ai_connection, example_prompt_template_with_function): - functions = [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": {}, - }, - } - ] + def test_openai_prompt_with_function( + self, open_ai_connection, example_prompt_template_with_function, functions): result = chat( connection=open_ai_connection, prompt=example_prompt_template_with_function, model="gpt-3.5-turbo", temperature=0, + # test input functions. functions=functions, + # test input prompt containing function role. name="get_location", result=json.dumps({"location": "Austin"}), - # assignments=assignments, question="What is the weather in Boston?", prev_question="Where is Boston?" 
) diff --git a/src/promptflow-tools/tests/test_serpapi.py b/src/promptflow-tools/tests/test_serpapi.py index 9ea8638d161..0174e6d2d44 100644 --- a/src/promptflow-tools/tests/test_serpapi.py +++ b/src/promptflow-tools/tests/test_serpapi.py @@ -1,279 +1,45 @@ -import json -from pathlib import Path - import pytest -from promptflow.connections import SerpConnection -from promptflow.core.connection_manager import ConnectionManager from promptflow.exceptions import UserErrorException -from promptflow.tools.serpapi import Engine, SafeMode, SerpAPI, search +from promptflow.tools.serpapi import Engine, SafeMode, search import tests.utils as utils -PROMOTFLOW_ROOT = Path(__file__) / "../../../../" - - -@pytest.fixture -def serpapi_config() -> SerpConnection: - return ConnectionManager().get("serp_connection") - -@pytest.fixture -def serpapi_provider(serpapi_config) -> SerpAPI: - serpAPIProvider = SerpAPI.from_config(serpapi_config) - return serpAPIProvider - - -@pytest.mark.usefixtures("use_secrets_config_file", "serpapi_provider", "serpapi_config") -@pytest.mark.skip(reason="serpapi key not set yet") +@pytest.mark.usefixtures("use_secrets_config_file") +@pytest.mark.skip_if_no_key("serp_connection") class TestSerpAPI: - def test_start_num_ijn(self, serpapi_provider): - query = "Cute cat" - num = 3 - result_dict = serpapi_provider.search(query=query, num=num, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_start_num_ijn result1:\n" + json.dumps(result_dict)) - assert len(result_dict["organic_results"]) <= num - - start = 10 - result_dict = serpapi_provider.search(query=query, num=num, start=start, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_start_num_ijn result2:\n" + json.dumps(result_dict)) - assert int(result_dict["search_parameters"]["start"]) == start - assert result_dict["search_parameters"].get("ijn") is None - assert "ijn=" not in result_dict["search_metadata"]["google_url"] - - ijn = 5 - result_dict = serpapi_provider.search(query=query, num=num, ijn=ijn, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_start_num_ijn result3:\n" + json.dumps(result_dict)) - assert int(result_dict["search_parameters"]["num"]) == num - # ijn value will be ignore if start not specified - assert result_dict["search_parameters"].get("start") is None - assert result_dict["search_parameters"].get("ijn") is None - - ijn = 5 - tbm = "isch" - result_dict = serpapi_provider.search(query=query, num=num, start=start, ijn=ijn, tbm=tbm, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_start_num_ijn result4:\n" + json.dumps(result_dict)) - assert int(result_dict["search_parameters"]["start"]) == start - # ijn value will be ignore if start not specified - assert int(result_dict["search_parameters"]["ijn"]) == ijn - assert int(result_dict["search_parameters"]["num"]) == num - assert f"ijn={ijn}" in result_dict["search_metadata"]["google_url"] - assert f"num={num}" in result_dict["search_metadata"]["google_url"] - - def test_tbm(self, serpapi_provider): - query = "cute cats" - num = 5 - - tbm = "isch" - result_dict = serpapi_provider.search(query=query, tbm=tbm, num=num, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_tbm result1:\n" + json.dumps(result_dict)) - assert result_dict["images_results"] is not None - # we could not validate the num here; for "isch" the return num not match somewhow - - 
tbm = "vid" - result_dict = serpapi_provider.search(query=query, tbm=tbm, num=num, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_tbm result2:\n" + json.dumps(result_dict)) - assert len(result_dict["video_results"]) <= num - - tbm = None - result_dict = serpapi_provider.search(query=query, tbm=tbm, num=num, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_tbm result3:\n" + json.dumps(result_dict)) - assert len(result_dict["organic_results"]) <= num - - def test_output(self, serpapi_provider): - query = "cute cats" - - result_dict = serpapi_provider.search(query=query, num=2, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_output result1:\n" + json.dumps(result_dict)) - assert "" not in result_dict - - output = "HTML" - result_dict = serpapi_provider.search(query=query, num=2, output=output, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_output result2:\n" + json.dumps(result_dict)) - assert "" in result_dict - - def test_location_gl(self, serpapi_provider): - query = "cute cats" - location = ("Texas,Austin",) - - result_dict = serpapi_provider.search(query=query, num=2, location=location, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_location_gl result1:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["location_requested"] in location - - gl = "uk" - result_dict = serpapi_provider.search(query=query, num=2, gl=gl, location=location, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_location_gl result2:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["location_requested"] in location - assert result_dict["search_parameters"]["gl"] == gl - - def test_domain(self, serpapi_provider): - query = "Cute cat" - - google_domain = None - result_dict = serpapi_provider.search(query=query, google_domain=google_domain, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_domain result1:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["google_domain"] == "google.com" - - google_domain = "google.al" - result_dict = serpapi_provider.search(query=query, google_domain=google_domain, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_domain result2:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["google_domain"] == google_domain - - # this spec would be ignored if engine is not google - google_domain = "google.al" - result_dict = serpapi_provider.search(query=query, google_domain=google_domain, engine="bing") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_domain result3:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"].get("google_domain") is None - - def test_tbs(self, serpapi_provider): - query = "cute cats" - - tbs = "dur:l" - result_dict = serpapi_provider.search(query=query, num=2, tbs=tbs, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_location_gl result1:\n" + json.dumps(result_dict)) - assert tbs == result_dict["search_parameters"]["tbs"] - - def test_safe(self, serpapi_provider): - query = "I am looking for tools to hurt animals" - - safe = SafeMode.ACTIVE - result_dict = serpapi_provider.search(query=query, num=2, safe=safe, engine="google") - 
utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_safe result1:\n" + json.dumps(result_dict)) - assert safe == result_dict["search_parameters"]["safe"] - - safe = None - result_dict = serpapi_provider.search(query=query, num=2, safe=safe, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_safe result12\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"].get("safe") is None - - # google is tolerant to bad safe input - safe = "None" - result_dict = serpapi_provider.search(query=query, num=2, safe=safe, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_safe result3\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"].get("safe") is None - - def test_nfpr_filter(self, serpapi_provider): - query = "cute cats" - - nfpr = True - filter = "videos" - result_dict = serpapi_provider.search(query=query, num=2, nfpr=nfpr, filter=filter, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_nfpr_filter result1:\n" + json.dumps(result_dict)) - assert str(nfpr) == result_dict["search_parameters"]["nfpr"] - assert filter == result_dict["search_parameters"]["filter"] - - nfpr = False - filter = "videos" - result_dict = serpapi_provider.search(query=query, num=2, nfpr=nfpr, filter=filter, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_nfpr_filter result2:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"].get("nfpr") is None - assert filter == result_dict["search_parameters"]["filter"] - - def test_device(self, serpapi_provider): - query = "cute cats" - - result_dict = serpapi_provider.search(query=query, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_device result1:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["device"] == "desktop" - - device = "mobile" - result_dict = serpapi_provider.search(query=query, device=device, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_device result1:\n" + json.dumps(result_dict)) - assert result_dict["search_parameters"]["device"] == device - - def test_cache_asynch(self, serpapi_provider): - query = "cute cats" - - no_cache = True - asynch = False - result_dict = serpapi_provider.search(query=query, no_cache=no_cache, asynch=asynch, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_cache_asynch result1:\n" + json.dumps(result_dict)) - assert result_dict["search_information"] is not None - - no_cache = True - asynch = True - result_dict = serpapi_provider.search(query=query, no_cache=no_cache, asynch=asynch, engine="google") - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_cache_asynch result2:\n" + json.dumps(result_dict)) - # no_cache and async could not be True at the same time - assert result_dict.get("search_information") is None - - def test_engine(self, serpapi_provider): + def test_engine(self, serp_connection): query = "cute cats" num = 2 - engine = Engine.GOOGLE.value - result_dict = serpapi_provider.search(query=query, num=num, safe=SafeMode.ACTIVE, engine=engine) - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_engine result1:\n" + json.dumps(result_dict)) + result_dict = search( + connection=serp_connection, query=query, num=num, safe=SafeMode.ACTIVE, engine=Engine.GOOGLE.value) + 
utils.is_json_serializable(result_dict, "serp api search()") assert result_dict["search_metadata"]["google_url"] is not None assert int(result_dict["search_parameters"]["num"]) == num assert result_dict["search_parameters"]["safe"].lower() == "active" - engine = Engine.BING.value - result_dict = serpapi_provider.search(query=query, num=num, safe=SafeMode.ACTIVE, engine=engine) - utils.is_json_serializable(result_dict, "Serp_API.search()") - print("test_engine result2:\n" + json.dumps(result_dict)) + result_dict = search( + connection=serp_connection, query=query, num=num, safe=SafeMode.ACTIVE, engine=Engine.BING.value) + utils.is_json_serializable(result_dict, "serp api search()") assert int(result_dict["search_parameters"]["count"]) == num assert result_dict["search_parameters"]["safe_search"].lower() == "strict" - def test_invalid_api_key(self, serpapi_config): - serpapi_config.api_key = "hello" + def test_invalid_api_key(self, serp_connection): + serp_connection.api_key = "hello" query = "cute cats" num = 2 engine = Engine.GOOGLE.value error_msg = "Invalid API key. Your API key should be here: https://serpapi.com/manage-api-key" with pytest.raises(UserErrorException) as exc_info: - search(connection=serpapi_config, query=query, num=num, engine=engine) + search(connection=serp_connection, query=query, num=num, engine=engine) assert error_msg == exc_info.value.args[0] - - def test_invalid_query_for_google(self, serpapi_config): - query = "" - num = 2 - engine = Engine.GOOGLE.value - error_msg = "Missing query `q` parameter." - with pytest.raises(UserErrorException) as exc_info: - search(connection=serpapi_config, query=query, num=num, engine=engine) + @pytest.mark.parametrize("engine", [Engine.GOOGLE.value, Engine.BING.value]) + def test_invalid_query(self, serp_connection, engine): query = "" num = 2 - engine = Engine.BING.value error_msg = "Missing query `q` parameter." with pytest.raises(UserErrorException) as exc_info: - search(connection=serpapi_config, query=query, num=num, engine=engine) - assert error_msg == exc_info.value.args[0] - - def test_invalid_api_key_for_html_output(self, serpapi_config): - serpapi_config.api_key = "hello" - query = "cute cats" - num = 2 - engine = Engine.GOOGLE.value - output = "html" - error_msg = "Invalid API key. Your API key should be here: https://serpapi.com/manage-api-key" - with pytest.raises(UserErrorException) as exc_info: - search(connection=serpapi_config, query=query, num=num, engine=engine, output=output) + search(connection=serp_connection, query=query, num=num, engine=engine) assert error_msg == exc_info.value.args[0]
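
For readers adapting the `skip_if_no_key` pattern introduced by this patch, below is a minimal, self-contained sketch of how the pieces in `tests/pytest.ini` and `tests/conftest.py` fit together. `DummyConnection` and the `dummy_connection` fixture are hypothetical stand-ins for the `ConnectionManager`-backed connection fixtures in the real `conftest.py`; the marker name, the `-api-key` placeholder check, and the skip message mirror the patch.

# test_skip_if_no_key_sketch.py -- minimal sketch of the skip_if_no_key pattern.
import pytest


class DummyConnection:
    """Hypothetical stand-in for a promptflow connection carrying an api_key."""

    def __init__(self, api_key: str):
        self.api_key = api_key


@pytest.fixture
def dummy_connection():
    # In the real suite this comes from ConnectionManager().get(...);
    # placeholder keys in connections.json contain the substring "-api-key".
    return DummyConnection(api_key="my-api-key")


@pytest.fixture(autouse=True)
def skip_if_no_key(request):
    # Runs before every test, but acts only on tests carrying the marker.
    marker = request.node.get_closest_marker("skip_if_no_key")
    if marker:
        # The marker argument names the connection fixture to inspect.
        connection = request.getfixturevalue(marker.args[0])
        if "-api-key" in connection.api_key:
            pytest.skip("skipped because no key")


@pytest.mark.skip_if_no_key("dummy_connection")
def test_needs_real_key(dummy_connection):
    # Skipped while only the placeholder key is configured; runs unchanged
    # once a real key is supplied -- no need to comment out a skip marker.
    assert dummy_connection.api_key

As in the patch's tests/pytest.ini, the `skip_if_no_key` marker should be registered under `markers =` so pytest does not warn about an unknown mark, or reject it when run with --strict-markers.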