Commit 62e6dd8

Laura Lopez committed
Multiple outputs evaluation in EvalPerformance callback

1 parent 79e6a8f · commit 62e6dd8

File tree

3 files changed (+107, −98):
.gitignore
keras_wrapper/dataset.py
keras_wrapper/extra/callbacks.py

.gitignore (+1)

@@ -2,3 +2,4 @@
 keras_wrapper/Datasets
 keras_wrapper/Models
 .idea
+Multimodal_Keras_Wrapper.egg-info

keras_wrapper/dataset.py (+1)

@@ -1287,6 +1287,7 @@ def build_vocabulary(self, captions, id, do_split=True, min_occ=0, n_words=0, sp
         for w, k in self.extra_words.iteritems():
             dictionary[w] = k

+
         # Store dictionary and append to previously existent if needed.
         if id not in self.vocabulary:
             self.vocabulary[id] = dict()

keras_wrapper/extra/callbacks.py (+105, −98)

@@ -61,6 +61,7 @@ def __init__(self,
                  metric_name,
                  set_name,
                  batch_size,
+                 gt_pos=[],
                  each_n_epochs=1,
                  max_eval_samples=None,
                  extra_vars=dict(),
@@ -96,6 +97,7 @@ def __init__(self,
         :param metric_name: name of the performance metric
         :param set_name: list with the names of the set splits that will be evaluated
         :param batch_size: batch size used during sampling
+        :param gt_pos: position of the GT output to evaluate in the model's outputs
         :param each_n_epochs: sampling each this number of epochs or updates
         :param max_eval_samples: maximum number of samples evaluated
         :param extra_vars: dictionary of extra variables. See evaluation metrics in keras_wrapper/extra/evaluation.py for assigning the needed extra variables.
@@ -136,6 +138,7 @@ def __init__(self,
         self.is_3DLabel = is_3DLabel
         self.sampling = sampling
         self.beam_search = beam_search
+        self.gt_pos = gt_pos
         self.beam_batch_size = beam_batch_size
         self.metric_name = metric_name
         self.set_name = set_name
@@ -211,7 +214,7 @@ def evaluate(self, epoch, counter_name='epoch'):
                 }

                 params_prediction.update(checkDefaultParamsBeamSearch(self.extra_vars))
-                predictions = self.model_to_eval.predictBeamSearchNet(self.ds, params_prediction)[s]
+                predictions_all = self.model_to_eval.predictBeamSearchNet(self.ds, params_prediction)[s]
             else:
                 orig_size = self.extra_vars.get('eval_orig_size', False)
                 params_prediction = {'batch_size': self.batch_size,
@@ -225,122 +228,126 @@ def evaluate(self, epoch, counter_name='epoch'):
                     postprocess_fun = [self.ds.convert_3DLabels_to_bboxes, self.extra_vars[s]['references_orig_sizes']]
                 elif orig_size:
                     postprocess_fun = [self.ds.resize_semantic_output, self.extra_vars[s]['eval_orig_size_id']]
-                predictions = \
+                predictions_all = \
                     self.model_to_eval.predictNet(self.ds, params_prediction, postprocess_fun=postprocess_fun)[s]

-            if self.is_text:
-                if params_prediction.get('pos_unk', False):
-                    samples = predictions[0]
-                    alphas = predictions[1]
+            if not self.gt_pos:
+                predictions_all = [predictions_all]
+                self.gt_pos = [0]

-                if eval('self.ds.loaded_raw_' + s + '[0]'):
-                    sources = predictions[2]
+            for gt_pos in self.gt_pos:
+                predictions = predictions_all[gt_pos]
+
+                if self.is_text:
+                    if params_prediction.get('pos_unk', False):
+                        samples = predictions[0]
+                        alphas = predictions[1]
+
+                        if eval('self.ds.loaded_raw_' + s + '[0]'):
+                            sources = predictions[2]
+                        else:
+                            sources = []
+                            for preds in predictions[2]:
+                                for src in preds[self.input_text_id]:
+                                    sources.append(src)
+                            sources = decode_predictions_beam_search(sources,
+                                                                     self.index2word_x,
+                                                                     pad_sequences=True,
+                                                                     verbose=self.verbose)
+                        heuristic = self.extra_vars['heuristic']
                     else:
-                        sources = []
-                        for preds in predictions[2]:
-                            for src in preds[self.input_text_id]:
-                                sources.append(src)
-                        sources = decode_predictions_beam_search(sources,
-                                                                 self.index2word_x,
-                                                                 pad_sequences=True,
-                                                                 verbose=self.verbose)
-                    heuristic = self.extra_vars['heuristic']
-                else:
-                    samples = predictions
-                    alphas = None
-                    heuristic = None
-                    sources = None
-                if self.out_pred_idx is not None:
-                    samples = samples[self.out_pred_idx]
-                # Convert predictions into sentences
-                if self.beam_search:
-                    predictions = decode_predictions_beam_search(samples,
+                        samples = predictions
+                        alphas = None
+                        heuristic = None
+                        sources = None
+                    if self.out_pred_idx is not None:
+                        samples = samples[self.out_pred_idx]
+                    # Convert predictions into sentences
+                    if self.beam_search:
+                        predictions = decode_predictions_beam_search(samples,
                                                               self.index2word_y,
                                                               alphas=alphas,
                                                               x_text=sources,
                                                               heuristic=heuristic,
                                                               mapping=self.extra_vars.get('mapping', None),
                                                               verbose=self.verbose)
-                else:
-                    probs = predictions
-                    predictions = decode_predictions(predictions,
+                    else:
+                        probs = predictions
+                        predictions = decode_predictions(predictions,
                                                      1,  # always set temperature to 1
                                                      self.index2word_y,
                                                      self.sampling_type,
                                                      verbose=self.verbose)

-                # Apply detokenization function if needed
-                if self.extra_vars.get('apply_detokenization', False):
-                    predictions = map(self.extra_vars['detokenize_f'], predictions)
-
-
-            elif self.is_multilabel:
-                if self.multilabel_idx is not None:
-                    predictions = predictions[self.multilabel_idx]
-                predictions = decode_multilabel(predictions,
-                                                self.index2word_y,
-                                                min_val=self.min_pred_multilabel,
-                                                verbose=self.verbose)
-
-            # Store predictions
-            if self.write_samples:
-                # Store result
-                filepath = self.save_path + '/' + s + '_' + counter_name + '_' + str(epoch) + '.pred'  # results file
-                if self.write_type == 'list':
-                    list2file(filepath, predictions)
-                elif self.write_type == 'vqa':
-                    try:
+                    # Apply detokenization function if needed
+                    if self.extra_vars.get('apply_detokenization', False):
+                        predictions = map(self.extra_vars['detokenize_f'], predictions)
+
+
+                elif self.is_multilabel:
+                    if self.multilabel_idx is not None:
+                        predictions = predictions[self.multilabel_idx]
+                    predictions = decode_multilabel(predictions,
+                                                    self.index2word_y,
+                                                    min_val=self.min_pred_multilabel,
+                                                    verbose=self.verbose)
+
+                # Store predictions
+                if self.write_samples:
+                    # Store result
+                    filepath = self.save_path + '/' + s + '_' + counter_name + '_' + str(epoch) + '.pred'  # results file
+                    if self.write_type == 'list':
+                        list2file(filepath, predictions)
+                    elif self.write_type == 'vqa':
                         exec('refs = self.ds.Y_'+s+'[self.gt_id]')
-                    except:
-                        refs = ['N/A' for i in range(probs.shape[0])]
-                    extra_data_plot = {'reference': refs,
-                                       'probs': probs,
-                                       'vocab': self.index2word_y}
-                    list2vqa(filepath, predictions, self.extra_vars[s]['question_ids'], extra=extra_data_plot)
-                elif self.write_type == 'listoflists':
-                    listoflists2file(filepath, predictions)
-                elif self.write_type == 'numpy':
-                    numpy2file(filepath, predictions)
-                elif self.write_type == '3DLabels':
-                    # TODO:
-                    print("WRITE SAMPLES FUNCTION NOT IMPLEMENTED")
-                else:
-                    raise NotImplementedError(
-                        'The store type "' + self.write_type + '" is not implemented.')
+                        extra_data_plot = {'reference': refs,
+                                           'probs': probs}
+                        list2vqa(filepath, predictions, self.extra_vars[s]['question_ids'], extra=extra_data_plot)
+                    elif self.write_type == 'listoflists':
+                        listoflists2file(filepath, predictions)
+                    elif self.write_type == 'numpy':
+                        numpy2file(filepath, predictions)
+                    elif self.write_type == '3DLabels':
+                        # TODO:
+                        print("WRITE SAMPLES FUNCTION NOT IMPLEMENTED")
+                    else:
+                        raise NotImplementedError(
+                            'The store type "' + self.write_type + '" is not implemented.')

-            # Evaluate on each metric
-            for metric in self.metric_name:
-                if self.verbose > 0:
-                    logging.info('Evaluating on metric ' + metric)
-                filepath = self.save_path + '/' + s + '.' + metric  # results file
-
-                # Evaluate on the chosen metric
-                metrics = evaluation.select[metric](
-                    pred_list=predictions,
-                    verbose=self.verbose,
-                    extra_vars=self.extra_vars,
-                    split=s)
-
-                # Print results to file and store in model log
-                with open(filepath, 'a') as f:
-                    header = counter_name + ','
-                    line = str(epoch) + ','
-                    # Store in model log
-                    self.model_to_eval.log(s, counter_name, epoch)
-                    for metric_ in sorted(metrics):
-                        all_metrics.append(metric_)
-                        value = metrics[metric_]
-                        header += metric_ + ','
-                        line += str(value) + ','
+                # Evaluate on each metric
+                for metric in self.metric_name:
+                    if self.verbose > 0:
+                        logging.info('Evaluating on metric ' + metric)
+                    filepath = self.save_path + '/' + s + '.' + metric  # results file
+
+                    # Evaluate on the chosen metric
+                    metrics = evaluation.select[metric](
+                        pred_list=predictions,
+                        verbose=self.verbose,
+                        extra_vars=self.extra_vars,
+                        split=s)
+
+                    # Print results to file and store in model log
+                    with open(filepath, 'a') as f:
+                        header = counter_name + ','
+                        line = str(epoch) + ','
                         # Store in model log
-                        self.model_to_eval.log(s, metric_, value)
-                    if not self.written_header:
-                        f.write(header + '\n')
-                        self.written_header = True
-                    f.write(line + '\n')
+                        self.model_to_eval.log(s, counter_name, epoch)
+                        for metric_ in sorted(metrics):
+                            all_metrics.append(metric_)
+                            value = metrics[metric_]
+                            header += metric_ + ','
+                            line += str(value) + ','
+                            # Store in model log
+                            self.model_to_eval.log(s, metric_, value)
+                        if not self.written_header:
+                            f.write(header + '\n')
+                            self.written_header = True
+                        f.write(line + '\n')
+
+                    if self.verbose > 0:
+                        logging.info('Done evaluating on metric ' + metric)

-                if self.verbose > 0:
-                    logging.info('Done evaluating on metric ' + metric)

         # Plot results so far
         self.model_to_eval.plot(counter_name, set(all_metrics), self.set_name, upperbound=self.max_plot)
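
The heart of this commit is the gt_pos dispatch in the EvalPerformance callback's evaluate() method: for a model with several outputs, predictions_all holds one prediction set per output, and gt_pos selects which of them get decoded, written, and scored, while an empty gt_pos (the default) falls back to the previous single-output behaviour. Below is a minimal, self-contained Python sketch of that pattern only; evaluate_one and the toy data are illustrative stand-ins, not the library's API.

    # Standalone sketch of the multi-output dispatch introduced by this commit.
    # Only the gt_pos handling mirrors the diff above; evaluate_one is a
    # hypothetical metric callback supplied by the caller.
    def evaluate_outputs(predictions_all, gt_pos, evaluate_one):
        # Backwards-compatible path: a single-output model returns its
        # predictions directly, so wrap them and score position 0.
        if not gt_pos:
            predictions_all = [predictions_all]
            gt_pos = [0]
        results = {}
        for pos in gt_pos:
            # Each position selects the predictions of one model output.
            results[pos] = evaluate_one(predictions_all[pos])
        return results

    # Toy usage: a two-output model scored with exact-match accuracy.
    refs = ['a cat', 'a dog']
    out0 = ['a cat', 'a dog']
    out1 = ['a cat', 'a bird']
    accuracy = lambda preds: sum(p == r for p, r in zip(preds, refs)) / float(len(refs))
    print(evaluate_outputs([out0, out1], [0, 1], accuracy))
    # -> {0: 1.0, 1: 0.5}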
