Skip to content

Commit 43ca57a

Browse files
committed
cross entropy for classification
1 parent 52c6892 commit 43ca57a

10 files changed

+211
-126
lines changed

callbacks.py

+32-61
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,36 @@
1-
from keras.callbacks import Callback
1+
from keras.callbacks import TensorBoard, ModelCheckpoint
2+
import tensorflow as tf
23
import numpy as np
34

4-
class CustomModelCheckpoint(Callback):
5-
"""Save the model after every epoch.
6-
`filepath` can contain named formatting options,
7-
which will be filled the value of `epoch` and
8-
keys in `logs` (passed in `on_epoch_end`).
9-
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
10-
then the model checkpoints will be saved with the epoch number and
11-
the validation loss in the filename.
12-
# Arguments
13-
filepath: string, path to save the model file.
14-
monitor: quantity to monitor.
15-
verbose: verbosity mode, 0 or 1.
16-
save_best_only: if `save_best_only=True`,
17-
the latest best model according to
18-
the quantity monitored will not be overwritten.
19-
mode: one of {auto, min, max}.
20-
If `save_best_only=True`, the decision
21-
to overwrite the current save file is made
22-
based on either the maximization or the
23-
minimization of the monitored quantity. For `val_acc`,
24-
this should be `max`, for `val_loss` this should
25-
be `min`, etc. In `auto` mode, the direction is
26-
automatically inferred from the name of the monitored quantity.
27-
save_weights_only: if True, then only the model's weights will be
28-
saved (`model.save_weights(filepath)`), else the full model
29-
is saved (`model.save(filepath)`).
30-
period: Interval (number of epochs) between checkpoints.
31-
"""
5+
class CustomTensorBoard(TensorBoard):
6+
""" to log the loss after each batch
7+
"""
8+
def __init__(self, log_every=1, **kwargs):
9+
super(CustomTensorBoard, self).__init__(**kwargs)
10+
self.log_every = log_every
11+
self.counter = 0
12+
13+
def on_batch_end(self, batch, logs=None):
14+
self.counter+=1
15+
if self.counter%self.log_every==0:
16+
for name, value in logs.items():
17+
if name in ['batch', 'size']:
18+
continue
19+
summary = tf.Summary()
20+
summary_value = summary.value.add()
21+
summary_value.simple_value = value.item()
22+
summary_value.tag = name
23+
self.writer.add_summary(summary, self.counter)
24+
self.writer.flush()
25+
26+
super(CustomTensorBoard, self).on_batch_end(batch, logs)
3227

33-
def __init__(self, filepath, model_to_save, monitor='val_loss', verbose=0,
34-
save_best_only=False, save_weights_only=False,
35-
mode='auto', period=1):
36-
super(CustomModelCheckpoint, self).__init__()
28+
class CustomModelCheckpoint(ModelCheckpoint):
29+
""" to save the template model, not the multi-GPU model
30+
"""
31+
def __init__(self, model_to_save, **kwargs):
32+
super(CustomModelCheckpoint, self).__init__(**kwargs)
3733
self.model_to_save = model_to_save
38-
self.monitor = monitor
39-
self.verbose = verbose
40-
self.filepath = filepath
41-
self.save_best_only = save_best_only
42-
self.save_weights_only = save_weights_only
43-
self.period = period
44-
self.epochs_since_last_save = 0
45-
46-
if mode not in ['auto', 'min', 'max']:
47-
warnings.warn('ModelCheckpoint mode %s is unknown, '
48-
'fallback to auto mode.' % (mode),
49-
RuntimeWarning)
50-
mode = 'auto'
51-
52-
if mode == 'min':
53-
self.monitor_op = np.less
54-
self.best = np.Inf
55-
elif mode == 'max':
56-
self.monitor_op = np.greater
57-
self.best = -np.Inf
58-
else:
59-
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
60-
self.monitor_op = np.greater
61-
self.best = -np.Inf
62-
else:
63-
self.monitor_op = np.less
64-
self.best = np.Inf
6534

6635
def on_epoch_end(self, epoch, logs=None):
6736
logs = logs or {}
@@ -96,4 +65,6 @@ def on_epoch_end(self, epoch, logs=None):
9665
if self.save_weights_only:
9766
self.model_to_save.save_weights(filepath, overwrite=True)
9867
else:
99-
self.model_to_save.save(filepath, overwrite=True)
68+
self.model_to_save.save(filepath, overwrite=True)
69+
70+
super(CustomModelCheckpoint, self).on_batch_end(epoch, logs)

config.json

+6-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
"warmup_epochs": 3,
1919
"ignore_thresh": 0.5,
2020
"gpus": "0,1",
21-
"scales": [1,2,4],
21+
22+
"grid_scales": [1,1,1],
23+
"obj_scale": 5,
24+
"noobj_scale": 1,
25+
"xywh_scale": 1,
26+
"class_scale": 1,
2227

2328
"tensorboard_dir": "logs",
2429
"saved_weights_name": "kangaroo.h5",

predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _main_(args):
106106
else:
107107
image_paths += [input_path]
108108

109-
image_paths = [inp_file for inp_file in image_paths if (inp_file[-4:] == '.jpg' or inp_file == '.png')]
109+
image_paths = [inp_file for inp_file in image_paths if (inp_file[-4:] in ['.jpg', '.png', 'JPEG'])]
110110

111111
# the main loop
112112
for image_path in image_paths:

train.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from yolo import create_yolov3_model, dummy_loss
99
from generator import BatchGenerator
1010
from utils.utils import normalize, evaluate, makedirs
11-
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
11+
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
1212
from keras.optimizers import Adam
13-
from callbacks import CustomModelCheckpoint
13+
from callbacks import CustomModelCheckpoint, CustomTensorBoard
1414
from utils.multi_gpu_model import multi_gpu_model
1515
import tensorflow as tf
1616
import keras
@@ -46,8 +46,8 @@ def create_training_instances(
4646
if len(labels) > 0:
4747
overlap_labels = set(labels).intersection(set(train_labels.keys()))
4848

49-
print('Seen labels: \t\t' + str(train_labels) + '\n')
50-
print('Given labels: \t\t' + str(labels))
49+
print('Seen labels: \t' + str(train_labels) + '\n')
50+
print('Given labels: \t' + str(labels))
5151

5252
# return None, None, None if some given label is not in the dataset
5353
if len(overlap_labels) < len(labels):
@@ -73,8 +73,8 @@ def create_callbacks(saved_weights_name, tensorboard_logs, model_to_save):
7373
verbose = 1
7474
)
7575
checkpoint = CustomModelCheckpoint(
76-
saved_weights_name,# + '{epoch:02d}.h5',
7776
model_to_save = model_to_save,
77+
filepath = saved_weights_name,# + '{epoch:02d}.h5',
7878
monitor = 'loss',
7979
verbose = 1,
8080
save_best_only = True,
@@ -91,7 +91,7 @@ def create_callbacks(saved_weights_name, tensorboard_logs, model_to_save):
9191
cooldown = 0,
9292
min_lr = 0
9393
)
94-
tensorboard = TensorBoard(
94+
tensorboard = CustomTensorBoard(
9595
log_dir = tensorboard_logs,
9696
write_graph = True,
9797
write_images = True,
@@ -108,7 +108,11 @@ def create_model(
108108
multi_gpu,
109109
saved_weights_name,
110110
lr,
111-
scales
111+
grid_scales,
112+
obj_scale,
113+
noobj_scale,
114+
xywh_scale,
115+
class_scale
112116
):
113117
if multi_gpu > 1:
114118
with tf.device('/cpu:0'):
@@ -120,7 +124,11 @@ def create_model(
120124
batch_size = batch_size//multi_gpu,
121125
warmup_batches = warmup_batches,
122126
ignore_thresh = ignore_thresh,
123-
scales = scales
127+
grid_scales = grid_scales,
128+
obj_scale = obj_scale,
129+
noobj_scale = noobj_scale,
130+
xywh_scale = xywh_scale,
131+
class_scale = class_scale
124132
)
125133
else:
126134
template_model, infer_model = create_yolov3_model(
@@ -131,8 +139,12 @@ def create_model(
131139
batch_size = batch_size,
132140
warmup_batches = warmup_batches,
133141
ignore_thresh = ignore_thresh,
134-
scales = scales
135-
)
142+
grid_scales = grid_scales,
143+
obj_scale = obj_scale,
144+
noobj_scale = noobj_scale,
145+
xywh_scale = xywh_scale,
146+
class_scale = class_scale
147+
)
136148

137149
# load the pretrained weight if exists, otherwise load the backend weight only
138150
if os.path.exists(saved_weights_name):
@@ -169,7 +181,7 @@ def _main_(args):
169181
config['valid']['cache_name'],
170182
config['model']['labels']
171183
)
172-
print('\nTraining on the following labels: ' + str(labels))
184+
print('\nTraining on: \t' + str(labels) + '\n')
173185

174186
###############################
175187
# Create the generators
@@ -223,7 +235,11 @@ def _main_(args):
223235
multi_gpu = multi_gpu,
224236
saved_weights_name = config['train']['saved_weights_name'],
225237
lr = config['train']['learning_rate'],
226-
scales = config['train']['scales'],
238+
grid_scales = config['train']['grid_scales'],
239+
obj_scale = config['train']['obj_scale'],
240+
noobj_scale = config['train']['noobj_scale'],
241+
xywh_scale = config['train']['xywh_scale'],
242+
class_scale = config['train']['class_scale'],
227243
)
228244

229245
###############################

utils/utils.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
175175
boxes = []
176176

177177
netout[..., :2] = _sigmoid(netout[..., :2])
178-
netout[..., 4:] = _sigmoid(netout[..., 4:])
179-
netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:]
178+
netout[..., 4] = _sigmoid(netout[..., 4])
179+
netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])
180180
netout[..., 5:] *= netout[..., 5:] > obj_thresh
181181

182182
for i in range(grid_h*grid_w):
@@ -314,4 +314,10 @@ def compute_ap(recall, precision):
314314

315315
# and sum (\Delta recall) * prec
316316
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
317-
return ap
317+
return ap
318+
319+
def _softmax(x, axis=-1):
320+
x = x - np.amax(x, axis, keepdims=True)
321+
e_x = np.exp(x)
322+
323+
return e_x / e_x.sum(axis, keepdims=True)

0 commit comments

Comments
 (0)