-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathenvironment.py
253 lines (218 loc) · 9.63 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Environment: user simulator
class ErrorEvaluator:
    """
    Abstract error evaluator consulted by the user simulator (UserSim) to simulate users' feedback.

    The base class carries no state; `compare` must be provided by a concrete evaluator.
    """

    def __init__(self):
        """Constructor of ErrorEvaluator (nothing to initialize)."""

    def compare(self, g_sql, start_idx, tag_seq, bool_return_true_selections=False,
                bool_return_true_semantic_units=False):
        """
        Compare the current (generated) SQL with the ground truth.
        :param g_sql: the ground truth SQL.
        :param start_idx: the starting pointer to compare.
        :param tag_seq: a sequence of semantic units.
        :param bool_return_true_selections: Set to True to also return the true selections for each semantic
            unit. This will be used to simulate users' selections.
        :param bool_return_true_semantic_units: Set to True to also return the true semantic units that should
            replace each old one. This will be used to simulate gold users' selections.
        :return: The ending index, the evaluation output (a sequence of True/False, indicating whether each
            generated semantic unit is correct or not), and, when one of the two flags above is True, the
            corresponding sequence of true selections / true semantic units.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class UserSim:
    """
    This is the class for user simulator.

    The simulator answers the system's questions by looking up the evaluation that its ErrorEvaluator
    produced for the generated query against the ground truth.
    """

    def __init__(self, error_evaluator):
        """
        Constructor of UserSim.
        :param error_evaluator: An instance of ErrorEvaluator.
        """
        self.user_type = "sim"
        # Number of most recent questioned units that must ALL be wrong before the user gives up ('exit').
        # Expected to be a positive integer.
        self.patience = 3
        self.error_evaluator = error_evaluator

        self.ground_truth = None
        self.tag_seq = None
        self.dec_seq = None
        self.eval_outputs = None  # evaluation output (right/wrong)
        self.true_selections = None  # the true decisions used to simulate user selection

        # interaction record
        self.q_counter = 0  # number of questions
        self.questioned_pointers = []
        self.questioned_tags = []
        self.option_selections = []
        self.feedback_records = []

    def update_truth(self, ground_truth):
        """
        Update the user with the ground truth in this session. This will be used to simulate user feedback.
        :param ground_truth: the ground truth SQL query.
        :return:
        """
        self.ground_truth = ground_truth

    def update_pred(self, tag_seq, dec_seq):
        """
        Update the user with the generated SQL query and re-evaluate it against the ground truth.
        :param tag_seq: a sequence of semantic units (for the generated SQL query).
        :param dec_seq: a sequence of decoding decisions (for the generated SQL query).
        :return:
        """
        self.tag_seq = tag_seq
        self.dec_seq = dec_seq

        # The ending index returned by the evaluator is not needed here.
        _, self.eval_outputs, self.true_selections = self.error_evaluator.compare(
            self.ground_truth, 0, self.tag_seq, bool_return_true_selections=True)

    def record_user_feedback(self, context, user_answer, bool_qa=False):
        """
        Record the interaction feedback.
        :param context: the interaction context.
        :param user_answer: the user feedback.
        :param bool_qa: Set to True if the feedback comes from QA.
        :return:
        """
        self.feedback_records.append((context, user_answer))

        if bool_qa:
            self.questioned_tags.append((context, user_answer))
            self.q_counter += 1  # number of questions + 1

    def clear_counter(self):
        """
        Clear session records.
        :return:
        """
        self.q_counter = 0
        self.questioned_pointers = []
        self.questioned_tags = []
        self.option_selections = []
        self.feedback_records = []

    def get_answer(self, pointer, answer_sheet):
        """
        Generate simulated user answers.
        :param pointer: the pointer to the questioned semantic unit.
        :param answer_sheet: a dict of {'yes'/'no': meta info}, used to simulate users' binary answers.
            The first element of each meta info tells whether that answer means the unit is right.
        :return: the user answer, or 'exit' once the user's patience is exhausted.
        """
        self.questioned_pointers.append(pointer)
        pointer_eval_output = self.eval_outputs[pointer]  # True=right, False=wrong
        reverse_answer_sheet = {bool_right: ans for ans, (bool_right, _) in answer_sheet.items()}
        answer = reverse_answer_sheet[pointer_eval_output]

        # examine user patience: look at the evaluated units up to (and including) the pointer
        # that the user has actually been asked about.
        questioned = set(self.questioned_pointers)
        valid_eval_outputs = [item for item_idx, item in enumerate(self.eval_outputs[:pointer + 1])
                              if item is not None and item_idx in questioned]
        # The window size follows self.patience (it used to be hard-coded to 3, which silently broke
        # any other patience value).
        if valid_eval_outputs[-self.patience:].count(False) == self.patience:
            answer = 'exit'

        print("User answer: %s.\n" % answer)
        return answer

    def get_selection(self, pointer, answer_sheet, sel_none_of_above):
        """
        Generate simulated user selections.
        :param pointer: the pointer to the questioned semantic unit.
        :param answer_sheet: a dict of {choice idx: corresponding decision idx}, used to simulate users' selections.
        :param sel_none_of_above: the choice index of "none of the above".
        :return: user selections (a list of indices).
        """
        pointer_truth = self.true_selections[pointer]  # ground-truth decision
        selections = []

        # if the prefix query is correct, possible true decisions exist
        if pointer_truth is not None:
            if len(pointer_truth):
                selections = [select_id for select_id, select_val in answer_sheet.items()
                              if select_val in pointer_truth]
            else:
                # an empty truth is matched by the option whose decision is None
                selections = [select_id for select_id, select_val in answer_sheet.items()
                              if select_val is None]

        if len(selections) == 0:  # none of the above
            selections.append(sel_none_of_above)

        print("User answer: %s.\n" % str(selections))
        return selections
class GoldUserSim(UserSim):
    """
    The "gold" user simulator.
    Unlike UserSim, a gold user who is asked for a selection answers with a gold label (if one exists).
    """

    def __init__(self, error_evaluator):
        """
        Constructor of GoldUserSim.
        :param error_evaluator: An instance of ErrorEvaluator.
        """
        super().__init__(error_evaluator)
        self.user_type = "gold_sim"
        self.true_semantic_units = None  # filled in by update_pred

    def update_pred(self, tag_seq, dec_seq):
        """
        Update the user with the generated SQL query and fetch the true semantic units for it.
        :param tag_seq: a sequence of semantic units (for the generated SQL query).
        :param dec_seq: a sequence of decoding decisions (for the generated SQL query).
        :return:
        """
        self.tag_seq = tag_seq
        self.dec_seq = dec_seq

        # The evaluator's first return value (the ending index) is discarded.
        _end_idx, eval_outputs, gold_units = self.error_evaluator.compare(
            self.ground_truth, 0, self.tag_seq, bool_return_true_semantic_units=True)
        self.eval_outputs = eval_outputs
        self.true_semantic_units = gold_units

    def get_gold_selection(self, pointer):
        """
        Generate gold user selections. Must be implemented by a subclass.
        :param pointer: the pointer to the questioned semantic unit.
        :return: gold_semantic_units (a list of gold semantic units),
            gold_dec_items (a list of dec_items for each gold semantic unit),
            sel_none_of_above (an integer, the option index referring to "none of the above"),
            selections (a list of indices that gold user selects).
            These lists will be used in feedback incorporation.
        :raises NotImplementedError: always, in this class.
        """
        raise NotImplementedError
class RealUser(UserSim):
    """
    This is the class for real users (used in user study).
    Answers and selections are read interactively from standard input.
    """

    def __init__(self, error_evaluator, bool_undo=True):
        """
        Constructor of RealUser.
        :param error_evaluator: An instance of ErrorEvaluator.
        :param bool_undo: Set to True to offer "undo" as an additional answer option.
        """
        UserSim.__init__(self, error_evaluator)
        self.user_type = "real"

        self.bool_undo = bool_undo
        self.undo_semantic_units = []

    def get_answer(self, pointer, *args):
        """
        Request for user answers. Keeps prompting until a recognized answer is entered.
        :param pointer: the pointer to the questioned semantic unit.
        :param args: dummy inputs.
        :return: the user answer: 'yes', 'no', 'exit' or (when undo is enabled) 'undo'.
        """
        self.questioned_pointers.append(pointer)

        if self.bool_undo:
            prompt = "Please enter yes(y)/no(n)/undo/exit: "
            valid_answers = {'yes', 'no', 'exit', 'y', 'n', 'undo'}
        else:
            prompt = "Please enter yes(y)/no(n)/exit: "
            valid_answers = {'yes', 'no', 'exit', 'y', 'n'}

        answer = input(prompt).lower().strip()
        while answer not in valid_answers:
            answer = input(prompt).lower().strip()

        # normalize the one-letter shortcuts
        if answer == 'y':
            answer = 'yes'
        elif answer == 'n':
            answer = 'no'

        return answer

    def get_selection(self, pointer, answer_sheet, sel_none_of_above):
        """
        Request for user selections. Keeps prompting until a valid list of option ids is entered.
        :param pointer: the pointer to the questioned semantic unit.
        :param answer_sheet: dummy inputs.
        :param sel_none_of_above: the choice index of "none of the above"; it cannot be combined with
            any other option.
        :return: user selections (a list of integer option ids).
        """
        def answer_parsing(answer_str):
            """Parse a comma-delimited list of option ids; return None if the input is invalid."""
            try:
                # int() tolerates surrounding whitespace, so both "1, 2" and "1,2" are accepted.
                selections = [int(sel) for sel in answer_str.split(",")]
            except ValueError:
                return None

            # "none of the above" is mutually exclusive with every other option: ask again
            # instead of crashing the session on a user typo.
            if sel_none_of_above in selections and len(selections) != 1:
                return None
            return selections

        prompt = "Please enter the option id(s) delimited by comma ', ': "
        selections = answer_parsing(input(prompt))
        while selections is None:
            selections = answer_parsing(input(prompt))

        return selections