@@ -61,6 +61,7 @@ def __init__(self,
61
61
metric_name ,
62
62
set_name ,
63
63
batch_size ,
64
+ gt_pos = [],
64
65
each_n_epochs = 1 ,
65
66
max_eval_samples = None ,
66
67
extra_vars = dict (),
@@ -96,6 +97,7 @@ def __init__(self,
96
97
:param metric_name: name of the performance metric
97
98
:param set_name: list with the names of the set splits that will be evaluated
98
99
:param batch_size: batch size used during sampling
100
+ :param gt_pos: position of the GT output to evaluate in model's outputs
99
101
:param each_n_epochs: sampling each this number of epochs or updates
100
102
:param max_eval_samples: maximum number of samples evaluated
101
103
:param extra_vars: dictionary of extra variables. See evaluation metrics in keras_wrapper/extra/evaluation.py for assigning the needed extra variables.
@@ -136,6 +138,7 @@ def __init__(self,
136
138
self .is_3DLabel = is_3DLabel
137
139
self .sampling = sampling
138
140
self .beam_search = beam_search
141
+ self .gt_pos = gt_pos
139
142
self .beam_batch_size = beam_batch_size
140
143
self .metric_name = metric_name
141
144
self .set_name = set_name
@@ -211,7 +214,7 @@ def evaluate(self, epoch, counter_name='epoch'):
211
214
}
212
215
213
216
params_prediction .update (checkDefaultParamsBeamSearch (self .extra_vars ))
214
- predictions = self .model_to_eval .predictBeamSearchNet (self .ds , params_prediction )[s ]
217
+ predictions_all = self .model_to_eval .predictBeamSearchNet (self .ds , params_prediction )[s ]
215
218
else :
216
219
orig_size = self .extra_vars .get ('eval_orig_size' , False )
217
220
params_prediction = {'batch_size' : self .batch_size ,
@@ -225,122 +228,126 @@ def evaluate(self, epoch, counter_name='epoch'):
225
228
postprocess_fun = [self .ds .convert_3DLabels_to_bboxes , self .extra_vars [s ]['references_orig_sizes' ]]
226
229
elif orig_size :
227
230
postprocess_fun = [self .ds .resize_semantic_output , self .extra_vars [s ]['eval_orig_size_id' ]]
228
- predictions = \
231
+ predictions_all = \
229
232
self .model_to_eval .predictNet (self .ds , params_prediction , postprocess_fun = postprocess_fun )[s ]
230
233
231
- if self .is_text :
232
- if params_prediction .get ('pos_unk' , False ):
233
- samples = predictions [0 ]
234
- alphas = predictions [1 ]
234
+ if not self .gt_pos :
235
+ predictions_all = [predictions_all ]
236
+ self .gt_pos = [0 ]
235
237
236
- if eval ('self.ds.loaded_raw_' + s + '[0]' ):
237
- sources = predictions [2 ]
238
+ for gt_pos in self .gt_pos :
239
+ predictions = predictions_all [gt_pos ]
240
+
241
+ if self .is_text :
242
+ if params_prediction .get ('pos_unk' , False ):
243
+ samples = predictions [0 ]
244
+ alphas = predictions [1 ]
245
+
246
+ if eval ('self.ds.loaded_raw_' + s + '[0]' ):
247
+ sources = predictions [2 ]
248
+ else :
249
+ sources = []
250
+ for preds in predictions [2 ]:
251
+ for src in preds [self .input_text_id ]:
252
+ sources .append (src )
253
+ sources = decode_predictions_beam_search (sources ,
254
+ self .index2word_x ,
255
+ pad_sequences = True ,
256
+ verbose = self .verbose )
257
+ heuristic = self .extra_vars ['heuristic' ]
238
258
else :
239
- sources = []
240
- for preds in predictions [2 ]:
241
- for src in preds [self .input_text_id ]:
242
- sources .append (src )
243
- sources = decode_predictions_beam_search (sources ,
244
- self .index2word_x ,
245
- pad_sequences = True ,
246
- verbose = self .verbose )
247
- heuristic = self .extra_vars ['heuristic' ]
248
- else :
249
- samples = predictions
250
- alphas = None
251
- heuristic = None
252
- sources = None
253
- if self .out_pred_idx is not None :
254
- samples = samples [self .out_pred_idx ]
255
- # Convert predictions into sentences
256
- if self .beam_search :
257
- predictions = decode_predictions_beam_search (samples ,
259
+ samples = predictions
260
+ alphas = None
261
+ heuristic = None
262
+ sources = None
263
+ if self .out_pred_idx is not None :
264
+ samples = samples [self .out_pred_idx ]
265
+ # Convert predictions into sentences
266
+ if self .beam_search :
267
+ predictions = decode_predictions_beam_search (samples ,
258
268
self .index2word_y ,
259
269
alphas = alphas ,
260
270
x_text = sources ,
261
271
heuristic = heuristic ,
262
272
mapping = self .extra_vars .get ('mapping' , None ),
263
273
verbose = self .verbose )
264
- else :
265
- probs = predictions
266
- predictions = decode_predictions (predictions ,
274
+ else :
275
+ probs = predictions
276
+ predictions = decode_predictions (predictions ,
267
277
1 , # always set temperature to 1
268
278
self .index2word_y ,
269
279
self .sampling_type ,
270
280
verbose = self .verbose )
271
281
272
- # Apply detokenization function if needed
273
- if self .extra_vars .get ('apply_detokenization' , False ):
274
- predictions = map (self .extra_vars ['detokenize_f' ], predictions )
275
-
276
-
277
- elif self .is_multilabel :
278
- if self .multilabel_idx is not None :
279
- predictions = predictions [self .multilabel_idx ]
280
- predictions = decode_multilabel (predictions ,
281
- self .index2word_y ,
282
- min_val = self .min_pred_multilabel ,
283
- verbose = self .verbose )
284
-
285
- # Store predictions
286
- if self .write_samples :
287
- # Store result
288
- filepath = self .save_path + '/' + s + '_' + counter_name + '_' + str (epoch ) + '.pred' # results file
289
- if self .write_type == 'list' :
290
- list2file (filepath , predictions )
291
- elif self .write_type == 'vqa' :
292
- try :
282
+ # Apply detokenization function if needed
283
+ if self .extra_vars .get ('apply_detokenization' , False ):
284
+ predictions = map (self .extra_vars ['detokenize_f' ], predictions )
285
+
286
+
287
+ elif self .is_multilabel :
288
+ if self .multilabel_idx is not None :
289
+ predictions = predictions [self .multilabel_idx ]
290
+ predictions = decode_multilabel (predictions ,
291
+ self .index2word_y ,
292
+ min_val = self .min_pred_multilabel ,
293
+ verbose = self .verbose )
294
+
295
+ # Store predictions
296
+ if self .write_samples :
297
+ # Store result
298
+ filepath = self .save_path + '/' + s + '_' + counter_name + '_' + str (epoch ) + '.pred' # results file
299
+ if self .write_type == 'list' :
300
+ list2file (filepath , predictions )
301
+ elif self .write_type == 'vqa' :
293
302
exec ('refs = self.ds.Y_' + s + '[self.gt_id]' )
294
- except :
295
- refs = ['N/A' for i in range (probs .shape [0 ])]
296
- extra_data_plot = {'reference' : refs ,
297
- 'probs' : probs ,
298
- 'vocab' : self .index2word_y }
299
- list2vqa (filepath , predictions , self .extra_vars [s ]['question_ids' ], extra = extra_data_plot )
300
- elif self .write_type == 'listoflists' :
301
- listoflists2file (filepath , predictions )
302
- elif self .write_type == 'numpy' :
303
- numpy2file (filepath , predictions )
304
- elif self .write_type == '3DLabels' :
305
- # TODO:
306
- print ("WRITE SAMPLES FUNCTION NOT IMPLEMENTED" )
307
- else :
308
- raise NotImplementedError (
309
- 'The store type "' + self .write_type + '" is not implemented.' )
303
+ extra_data_plot = {'reference' : refs ,
304
+ 'probs' : probs }
305
+ list2vqa (filepath , predictions , self .extra_vars [s ]['question_ids' ], extra = extra_data_plot )
306
+ elif self .write_type == 'listoflists' :
307
+ listoflists2file (filepath , predictions )
308
+ elif self .write_type == 'numpy' :
309
+ numpy2file (filepath , predictions )
310
+ elif self .write_type == '3DLabels' :
311
+ # TODO:
312
+ print ("WRITE SAMPLES FUNCTION NOT IMPLEMENTED" )
313
+ else :
314
+ raise NotImplementedError (
315
+ 'The store type "' + self .write_type + '" is not implemented.' )
310
316
311
- # Evaluate on each metric
312
- for metric in self .metric_name :
313
- if self .verbose > 0 :
314
- logging .info ('Evaluating on metric ' + metric )
315
- filepath = self .save_path + '/' + s + '.' + metric # results file
316
-
317
- # Evaluate on the chosen metric
318
- metrics = evaluation .select [metric ](
319
- pred_list = predictions ,
320
- verbose = self .verbose ,
321
- extra_vars = self .extra_vars ,
322
- split = s )
323
-
324
- # Print results to file and store in model log
325
- with open (filepath , 'a' ) as f :
326
- header = counter_name + ','
327
- line = str (epoch ) + ','
328
- # Store in model log
329
- self .model_to_eval .log (s , counter_name , epoch )
330
- for metric_ in sorted (metrics ):
331
- all_metrics .append (metric_ )
332
- value = metrics [metric_ ]
333
- header += metric_ + ','
334
- line += str (value ) + ','
317
+ # Evaluate on each metric
318
+ for metric in self .metric_name :
319
+ if self .verbose > 0 :
320
+ logging .info ('Evaluating on metric ' + metric )
321
+ filepath = self .save_path + '/' + s + '.' + metric # results file
322
+
323
+ # Evaluate on the chosen metric
324
+ metrics = evaluation .select [metric ](
325
+ pred_list = predictions ,
326
+ verbose = self .verbose ,
327
+ extra_vars = self .extra_vars ,
328
+ split = s )
329
+
330
+ # Print results to file and store in model log
331
+ with open (filepath , 'a' ) as f :
332
+ header = counter_name + ','
333
+ line = str (epoch ) + ','
335
334
# Store in model log
336
- self .model_to_eval .log (s , metric_ , value )
337
- if not self .written_header :
338
- f .write (header + '\n ' )
339
- self .written_header = True
340
- f .write (line + '\n ' )
335
+ self .model_to_eval .log (s , counter_name , epoch )
336
+ for metric_ in sorted (metrics ):
337
+ all_metrics .append (metric_ )
338
+ value = metrics [metric_ ]
339
+ header += metric_ + ','
340
+ line += str (value ) + ','
341
+ # Store in model log
342
+ self .model_to_eval .log (s , metric_ , value )
343
+ if not self .written_header :
344
+ f .write (header + '\n ' )
345
+ self .written_header = True
346
+ f .write (line + '\n ' )
347
+
348
+ if self .verbose > 0 :
349
+ logging .info ('Done evaluating on metric ' + metric )
341
350
342
- if self .verbose > 0 :
343
- logging .info ('Done evaluating on metric ' + metric )
344
351
345
352
# Plot results so far
346
353
self .model_to_eval .plot (counter_name , set (all_metrics ), self .set_name , upperbound = self .max_plot )
0 commit comments