from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import train_test_split

- from finetune import Classifier, SequenceLabeler
+ from finetune import Classifier, SequenceLabeler, Comparison, ComparisonRegressor, MultipleChoice
from finetune.base_models import TextCNN, BERTModelCased, GPT2Model, GPTModel, RoBERTa, GPT
from finetune.config import get_config
from finetune.util.metrics import (
    sequence_labeling_token_precision,
    sequence_labeling_token_recall,
)
from finetune.datasets.reuters import Reuters
- from finetune.encoding.input_encoder import get_default_context, tokenize_context, ArrayEncodedOutput
+ from finetune.encoding.input_encoder import tokenize_context, ArrayEncodedOutput


# prevent excessive warning logs
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

class TestAuxiliaryTokenization(unittest.TestCase):
-     def test_get_default_context(self):
-         context = [
-             (2, ["single", True, 23.3, 4]),
-             (4, ["double", True, 24.3, 2]),
-             (8, ["single", False, 25.3, 3]),
-         ]
-
-         expected = ["single", True, 24.3, 3]
-         self.assertEqual(get_default_context(context), expected)
-
    def test_tokenize_context(self):
        encoded_output = ArrayEncodedOutput(
            token_ids=[
@@ -60,15 +50,16 @@ def test_tokenize_context(self):
            {'token': "only", 'start': 13, 'end': 17, 'left': 20, 'bold': False},
            {'token': "$80", 'start': 18, 'end': 21, 'left': 30, 'bold': True},
        ]
-         expanded_context = tokenize_context(context, encoded_output)
+         config = get_config(**{'default_context': {'left': 0, 'bold': False}})
+         expanded_context = tokenize_context(context, encoded_output, config)
        expected = [
-             [False, 20],
+             [False, 0],
            [False, 10],
            [False, 10],
            [False, 20],
            [True, 30],
            [True, 30],
-             [False, 20]
+             [False, 0]
        ]
        print(expanded_context)
        np.testing.assert_array_equal(expected, expanded_context)
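
To make the new default_context behaviour concrete, here is a simplified, standalone sketch of what this hunk asserts (a hypothetical helper, not finetune's implementation; the token spans and the "yesterday" context entry are made up to mirror the expected array above): any token position not covered by a user-supplied context entry, such as the special start/end token positions here, receives the configured default values instead of values inferred from the data.

    # Hypothetical helper illustrating default-context fallback; not finetune's implementation.
    def expand_context(token_spans, context_entries, default_context, keys=("bold", "left")):
        expanded = []
        for span in token_spans:
            match = None
            if span is not None:
                start, end = span
                # Use the user-supplied entry whose character span covers this token, if any.
                match = next(
                    (c for c in context_entries if c["start"] <= start and end <= c["end"]),
                    None,
                )
            source = match if match is not None else default_context
            expanded.append([source[k] for k in keys])
        return expanded

    context_entries = [
        {'token': "yesterday", 'start': 0, 'end': 9, 'left': 10, 'bold': False},
        {'token': "only", 'start': 13, 'end': 17, 'left': 20, 'bold': False},
        {'token': "$80", 'start': 18, 'end': 21, 'left': 30, 'bold': True},
    ]
    # None marks special tokens, which have no character span and so fall back to the default.
    token_spans = [None, (0, 4), (5, 9), (13, 17), (18, 21), (18, 21), None]
    print(expand_context(token_spans, context_entries, {'left': 0, 'bold': False}))
    # -> [[False, 0], [False, 10], [False, 10], [False, 20], [True, 30], [True, 30], [False, 0]]
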
@@ -119,12 +110,13 @@ def default_config(self, **kwargs):
        defaults = {
            "batch_size": 2,
            "max_length": 256,
-             "n_epochs": 1000,
+             "n_epochs": 1,  # we are mostly making sure nothing errors out
            "base_model": self.base_model,
            "val_size": 0,
            "use_auxiliary_info": True,
            "context_dim": 1,
-             "val_set": (self.trainX, self.trainY, self.train_context)
+             "val_set": (self.trainX, self.trainY, self.train_context),
+             "default_context": {'bold': False}
        }
        defaults.update(kwargs)
        return dict(get_config(**defaults))
@@ -178,7 +170,6 @@ def test_sequence_labeler_no_auxiliary(self):
        Ensure model training does not error out
        Ensure model returns reasonable predictions
        """
-
        model = SequenceLabeler(**self.default_config(use_auxiliary_info=False, val_set=(self.trainX, self.trainY)))
        model.fit(self.trainX, self.trainY_seq)
        preds = model.predict(self.trainX)
@@ -190,12 +181,64 @@ def test_sequence_labeler_auxiliary(self):
        Ensure model training does not error out
        Ensure model returns reasonable predictions
        """
-
-         model = SequenceLabeler(**self.default_config())
+         # Here we want to make sure we're actually using the context
+         model = SequenceLabeler(**self.default_config(n_epochs=1500))
        model.fit(self.trainX, self.trainY_seq, context=self.train_context)
        preds = model.predict(self.trainX, context=self.train_context)
        self._evaluate_sequence_preds(preds, True)
-
+
+     def test_comparison_auxiliary(self):
+         """
+         Ensure model training does not error out
+         Ensure model returns reasonable predictions
+         """
+         model = Comparison(**self.default_config(chunk_long_sequences=False, max_length=50, batch_size=4))
+         trainX = [['i like apples', 'i like apples']] * 4
+         trainY = ['A', 'B', 'C', 'D']
+         train_context = [
+             [self.train_context[i], self.train_context[j]] for i in [0, 1] for j in [0, 1]
+         ]
+         print(train_context)
+         model.fit(trainX, trainY, context=train_context)
+         preds = model.predict(trainX, context=train_context)
+
+     def test_comparison_regressor_auxiliary(self):
+         """
+         Ensure model training does not error out
+         Ensure model returns reasonable predictions
+         """
+         model = ComparisonRegressor(**self.default_config(chunk_long_sequences=False, max_length=50, batch_size=4))
+         trainX = [['i like apples', 'i like apples']] * 4
+         trainY = [0, .5, .5, 1]
+         train_context = [
+             [self.train_context[i], self.train_context[j]] for i in [0, 1] for j in [0, 1]
+         ]
+         print(train_context)
+         model.fit(trainX, trainY, context=train_context)
+         preds = model.predict(trainX, context=train_context)
+
+     def test_multiple_choice_auxiliary(self):
+         """
+         Ensure model training does not error out
+         Ensure model returns reasonable predictions
+         """
+         model = MultipleChoice(**self.default_config(chunk_long_sequences=False, max_length=50, batch_size=4))
+         questions = ['i like apples'] * 2
+         answers = [['happy', 'sad', 'neutral', 'not satisfied'], ['happy', 'sad', 'neutral', 'not satisfied']]
+         correct_answers = ['happy', 'sad']
+         answer_context = [
+             [{'start': 0, 'end': 5, 'token': 'happy', 'bold': False}],
+             [{'start': 0, 'end': 3, 'token': 'sad', 'bold': False}],
+             [{'start': 0, 'end': 7, 'token': 'neutral', 'bold': False}],
+             [{'start': 0, 'end': 3, 'token': 'not', 'bold': False}, {'start': 4, 'end': 13, 'token': 'satisfied', 'bold': False}],
+         ]
+         # Context looks like [[{}, {}, {}], [{}], [{}], [{}], [{}, {}]], where the first inner list
+         # is for the question and the subsequent ones are for each answer.
+         train_context = [[self.train_context[0]] + answer_context] + [[self.train_context[1]] + answer_context]
+         print(train_context)
+         model.fit(questions, answers, correct_answers, context=train_context)
+         preds = model.predict(questions, answers, context=train_context)
+
    def test_save_load(self):
        """
        Ensure saving + loading does not cause errors
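
Taken together, the new tests drive context-aware training through the public API. Below is a condensed, hypothetical usage sketch under the same assumptions the tests make; the base model choice, text, labels, and context values are placeholders, not part of this diff.

    # Hypothetical end-to-end sketch mirroring the new Comparison test; values are illustrative.
    from finetune import Comparison
    from finetune.base_models import GPT

    # One auxiliary-context entry per word, using the same {'token', 'start', 'end', 'bold'}
    # schema as the tests above; 'bold' is the single context feature (context_dim=1).
    context_a = [{'token': 'i', 'start': 0, 'end': 1, 'bold': False},
                 {'token': 'like', 'start': 2, 'end': 6, 'bold': False},
                 {'token': 'apples', 'start': 7, 'end': 13, 'bold': True}]
    context_b = [{'token': 'i', 'start': 0, 'end': 1, 'bold': False},
                 {'token': 'like', 'start': 2, 'end': 6, 'bold': False},
                 {'token': 'apples', 'start': 7, 'end': 13, 'bold': False}]

    model = Comparison(
        base_model=GPT,
        use_auxiliary_info=True,
        context_dim=1,
        default_context={'bold': False},  # fills positions with no user-supplied context
        n_epochs=1,
        batch_size=4,
        chunk_long_sequences=False,
        max_length=50,
        val_size=0,
    )
    trainX = [['i like apples', 'i like apples']] * 2
    trainY = ['A', 'B']
    # One context list per text in each comparison pair, matching the tests' nesting.
    train_context = [[context_a, context_b], [context_b, context_a]]
    model.fit(trainX, trainY, context=train_context)
    preds = model.predict(trainX, context=train_context)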