1
1
import json
2
2
import os
3
3
from unittest .mock import MagicMock , Mock , patch
4
- from flask import Flask , request
4
+ from flask import Flask
5
5
import pytest
6
- import requests
7
6
import urllib
8
7
9
- from app import (extract_value , fetchUserGroups , format_as_ndjson ,
8
+ from app import (extract_value , fetchUserGroups ,
10
9
formatApiResponseNoStreaming , formatApiResponseStreaming ,
11
10
generateFilterString , is_chat_model , parse_multi_columns ,
12
11
prepare_body_headers_with_data , should_use_data ,
13
- stream_with_data , conversation_with_data , draft_document_generate )
12
+ stream_with_data , draft_document_generate )
14
13
15
14
AZURE_SEARCH_SERVICE = os .environ .get ("AZURE_SEARCH_SERVICE" , "" )
16
15
AZURE_OPENAI_KEY = os .environ .get ("AZURE_OPENAI_KEY" , "" )
17
16
AZURE_SEARCH_PERMITTED_GROUPS_COLUMN = os .environ .get (
18
17
"AZURE_SEARCH_PERMITTED_GROUPS_COLUMN" , ""
19
18
)
20
19
20
+
21
21
def test_parse_multi_columns ():
22
22
assert parse_multi_columns ("a|b|c" ) == ["a" , "b" , "c" ]
23
23
assert parse_multi_columns ("a,b,c" ) == ["a" , "b" , "c" ]
@@ -160,9 +160,9 @@ def test_generateFilterString(mock_fetchUserGroups):
160
160
userToken = "fake_token"
161
161
162
162
filter_string = generateFilterString (userToken )
163
- print ("filter string" ,filter_string )
164
163
assert filter_string == "None/any(g:search.in(g, '1, 2'))"
165
164
165
+
166
166
def test_prepare_body_headers_with_data ():
167
167
# Create a mock request
168
168
mock_request = MagicMock ()
@@ -208,19 +208,6 @@ def test_prepare_body_headers_with_data():
208
208
assert headers ["x-ms-useragent" ] == "GitHubSampleWebApp/PublicAPI/3.0.0"
209
209
210
210
211
- def test_invalid_datasource_type ():
212
- mock_request = MagicMock ()
213
- mock_request .json = {"messages" : ["Hello, world!" ], "index_name" : "grants" }
214
-
215
-
216
- with patch ("app.DATASOURCE_TYPE" , "InvalidType" ):
217
- with pytest .raises (Exception ) as exc_info :
218
- prepare_body_headers_with_data (mock_request )
219
- assert "DATASOURCE_TYPE is not configured or unknown: InvalidType" in str (
220
- exc_info .value
221
- )
222
-
223
-
224
211
def test_invalid_datasource_type ():
225
212
mock_request = MagicMock ()
226
213
mock_request .json = {"messages" : ["Hello, world!" ], "index_name" : "grants" }
@@ -318,16 +305,16 @@ def test_stream_with_data_azure_success():
318
305
print (results , "result test case" )
319
306
assert len (results ) == 1
320
307
308
+
321
309
# Mock constants
322
310
USE_AZURE_AI_STUDIO = "true"
323
311
AZURE_OPENAI_PREVIEW_API_VERSION = "2023-06-01-preview"
324
312
DEBUG_LOGGING = False
325
-
326
313
AZURE_SEARCH_SERVICE = os .environ .get ("AZURE_SEARCH_SERVICE" , "mysearchservice" )
327
314
328
315
329
316
def test_stream_with_data_azure_error ():
330
-
317
+
331
318
body = {
332
319
"messages" : [
333
320
{
@@ -380,10 +367,9 @@ def test_stream_with_data_azure_error():
380
367
}
381
368
],
382
369
}
383
-
370
+
384
371
if USE_AZURE_AI_STUDIO .lower () == "true" :
385
- body = body
386
-
372
+ body = body
387
373
headers = {
388
374
"Content-Type" : "application/json" ,
389
375
"api-key" : "" ,
@@ -410,6 +396,7 @@ def test_stream_with_data_azure_error():
410
396
print (results , "result test case" )
411
397
assert len (results ) == 1
412
398
399
+
413
400
def test_formatApiResponseNoStreaming ():
414
401
rawResponse = {
415
402
"id" : "1" ,
@@ -463,6 +450,8 @@ def test_extract_value():
463
450
assert extract_value ("unknown" , text ) == "N/A"
464
451
465
452
app = Flask (__name__ )
453
+
454
+
466
455
app .add_url_rule ("/draft_document/generate_section" , "draft_document_generate" , draft_document_generate , methods = ["POST" ])
467
456
468
457
@@ -541,13 +530,15 @@ def test_draft_document_generate_with_context(mock_os_environ, mock_urlopen, cli
541
530
assert "content" in response_json
542
531
assert response_json ["content" ] == "Generated content with context."
543
532
533
+
544
534
@pytest .fixture
545
535
def clients ():
546
536
app = Flask (__name__ )
547
537
app .route ('/draft_document/generate_section' , methods = ['POST' ])(draft_document_generate )
548
538
client = app .test_client ()
549
539
yield client
550
540
541
+
551
542
@patch ("urllib.request.urlopen" )
552
543
@patch ("os.environ.get" )
553
544
def test_draft_document_generate_http_error (mock_env_get , mock_urlopen , client ):
@@ -582,5 +573,3 @@ def test_draft_document_generate_http_error(mock_env_get, mock_urlopen, client):
582
573
)
583
574
584
575
assert response .status_code == 200
585
-
586
-
0 commit comments