forked from microsoft/promptflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_with_pdf_test.py
96 lines (82 loc) · 3.06 KB
/
chat_with_pdf_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import unittest
import promptflow
from base_test import BaseTest
from promptflow.exceptions import ValidationException
class TestChatWithPDF(BaseTest):
    """End-to-end tests for the chat-with-pdf flow.

    Every test talks to a real ``promptflow.PFClient``. The flow path, the
    2k/3k context configs and the run helpers (``create_chat_run``,
    ``run_eval_with_config``, ``check_run_basics``) are inherited from
    ``BaseTest``.
    """

    # Number of rows each bulk run's details are expected to contain.
    # NOTE(review): presumably the number of lines in the data file that
    # BaseTest.create_chat_run feeds to the run — confirm if that data changes.
    EXPECTED_BULK_RUN_LINES = 3

    def setUp(self):
        """Create a fresh PFClient for each test."""
        super().setUp()
        self.pf = promptflow.PFClient()

    def tearDown(self) -> None:
        return super().tearDown()

    def _assert_bulk_run_completed(self, run):
        """Wait for *run* to finish, then check its status and row count."""
        self.pf.stream(run)  # wait for completion
        self.assertEqual(run.status, "Completed")
        details = self.pf.get_details(run)
        self.assertEqual(details.shape[0], self.EXPECTED_BULK_RUN_LINES)

    def test_run_chat_with_pdf(self):
        """A single flow test answers a factual question about the BERT paper."""
        result = self.pf.test(
            flow=self.flow_path,
            inputs={
                "chat_history": [],
                "pdf_url": "https://arxiv.org/pdf/1810.04805.pdf",
                "question": "BERT stands for?",
                "config": self.config_2k_context,
            },
        )
        # assertIn includes the actual answer in its failure message, so no
        # separate print(result) is needed for debugging.
        self.assertIn(
            "Bidirectional Encoder Representations from Transformers",
            result["answer"],
        )

    def test_bulk_run_chat_with_pdf(self):
        """A bulk run with the default column mapping completes."""
        run = self.create_chat_run()
        self._assert_bulk_run_completed(run)

    def test_eval(self):
        """Runs and their evaluations pass basic checks for both context sizes."""
        run_2k, eval_groundedness_2k, eval_pi_2k = self.run_eval_with_config(
            self.config_2k_context,
            display_name="chat_with_pdf_2k_context",
        )
        run_3k, eval_groundedness_3k, eval_pi_3k = self.run_eval_with_config(
            self.config_3k_context,
            display_name="chat_with_pdf_3k_context",
        )

        # Same order as the individual checks this loop replaces.
        for run in (
            run_2k,
            run_3k,
            eval_groundedness_2k,
            eval_pi_2k,
            eval_groundedness_3k,
            eval_pi_3k,
        ):
            self.check_run_basics(run)

    def test_bulk_run_valid_mapping(self):
        """A bulk run with an explicit, complete column mapping completes."""
        run = self.create_chat_run(
            column_mapping={
                "question": "${data.question}",
                "pdf_url": "${data.pdf_url}",
                "chat_history": "${data.chat_history}",
                "config": self.config_2k_context,
            }
        )
        self._assert_bulk_run_completed(run)

    @unittest.skip(
        "Disabled: expected behavior when a mapping omits a column "
        "has not been confirmed"
    )
    def test_bulk_run_mapping_missing_one_column(self):
        # Intended expectation: the run is rejected before it is created.
        with self.assertRaises(ValidationException):
            self.create_chat_run(
                column_mapping={
                    "question": "${data.question}",
                    "pdf_url": "${data.pdf_url}",
                }
            )

    def test_bulk_run_invalid_mapping(self):
        """Mapping a nonexistent data column is rejected with ValidationException."""
        # In this case the run is not created at all.
        with self.assertRaises(ValidationException):
            self.create_chat_run(
                column_mapping={
                    "question": "${data.question_not_exist}",
                    "pdf_url": "${data.pdf_url}",
                    "chat_history": "${data.chat_history}",
                }
            )
# Allow running this module directly: `python chat_with_pdf_test.py`.
if __name__ == "__main__":
    unittest.main()