1
- from keras .callbacks import Callback
1
+ from keras .callbacks import TensorBoard , ModelCheckpoint
2
+ import tensorflow as tf
2
3
import numpy as np
3
4
4
- class CustomModelCheckpoint (Callback ):
5
- """Save the model after every epoch.
6
- `filepath` can contain named formatting options,
7
- which will be filled the value of `epoch` and
8
- keys in `logs` (passed in `on_epoch_end`).
9
- For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
10
- then the model checkpoints will be saved with the epoch number and
11
- the validation loss in the filename.
12
- # Arguments
13
- filepath: string, path to save the model file.
14
- monitor: quantity to monitor.
15
- verbose: verbosity mode, 0 or 1.
16
- save_best_only: if `save_best_only=True`,
17
- the latest best model according to
18
- the quantity monitored will not be overwritten.
19
- mode: one of {auto, min, max}.
20
- If `save_best_only=True`, the decision
21
- to overwrite the current save file is made
22
- based on either the maximization or the
23
- minimization of the monitored quantity. For `val_acc`,
24
- this should be `max`, for `val_loss` this should
25
- be `min`, etc. In `auto` mode, the direction is
26
- automatically inferred from the name of the monitored quantity.
27
- save_weights_only: if True, then only the model's weights will be
28
- saved (`model.save_weights(filepath)`), else the full model
29
- is saved (`model.save(filepath)`).
30
- period: Interval (number of epochs) between checkpoints.
31
- """
5
+ class CustomTensorBoard (TensorBoard ):
6
+ """ to log the loss after each batch
7
+ """
8
+ def __init__ (self , log_every = 1 , ** kwargs ):
9
+ super (CustomTensorBoard , self ).__init__ (** kwargs )
10
+ self .log_every = log_every
11
+ self .counter = 0
12
+
13
+ def on_batch_end (self , batch , logs = None ):
14
+ self .counter += 1
15
+ if self .counter % self .log_every == 0 :
16
+ for name , value in logs .items ():
17
+ if name in ['batch' , 'size' ]:
18
+ continue
19
+ summary = tf .Summary ()
20
+ summary_value = summary .value .add ()
21
+ summary_value .simple_value = value .item ()
22
+ summary_value .tag = name
23
+ self .writer .add_summary (summary , self .counter )
24
+ self .writer .flush ()
25
+
26
+ super (CustomTensorBoard , self ).on_batch_end (batch , logs )
32
27
33
- def __init__ (self , filepath , model_to_save , monitor = 'val_loss' , verbose = 0 ,
34
- save_best_only = False , save_weights_only = False ,
35
- mode = 'auto' , period = 1 ):
36
- super (CustomModelCheckpoint , self ).__init__ ()
28
+ class CustomModelCheckpoint (ModelCheckpoint ):
29
+ """ to save the template model, not the multi-GPU model
30
+ """
31
+ def __init__ (self , model_to_save , ** kwargs ):
32
+ super (CustomModelCheckpoint , self ).__init__ (** kwargs )
37
33
self .model_to_save = model_to_save
38
- self .monitor = monitor
39
- self .verbose = verbose
40
- self .filepath = filepath
41
- self .save_best_only = save_best_only
42
- self .save_weights_only = save_weights_only
43
- self .period = period
44
- self .epochs_since_last_save = 0
45
-
46
- if mode not in ['auto' , 'min' , 'max' ]:
47
- warnings .warn ('ModelCheckpoint mode %s is unknown, '
48
- 'fallback to auto mode.' % (mode ),
49
- RuntimeWarning )
50
- mode = 'auto'
51
-
52
- if mode == 'min' :
53
- self .monitor_op = np .less
54
- self .best = np .Inf
55
- elif mode == 'max' :
56
- self .monitor_op = np .greater
57
- self .best = - np .Inf
58
- else :
59
- if 'acc' in self .monitor or self .monitor .startswith ('fmeasure' ):
60
- self .monitor_op = np .greater
61
- self .best = - np .Inf
62
- else :
63
- self .monitor_op = np .less
64
- self .best = np .Inf
65
34
66
35
def on_epoch_end (self , epoch , logs = None ):
67
36
logs = logs or {}
@@ -96,4 +65,6 @@ def on_epoch_end(self, epoch, logs=None):
96
65
if self .save_weights_only :
97
66
self .model_to_save .save_weights (filepath , overwrite = True )
98
67
else :
99
- self .model_to_save .save (filepath , overwrite = True )
68
+ self .model_to_save .save (filepath , overwrite = True )
69
+
70
+ super (CustomModelCheckpoint , self ).on_batch_end (epoch , logs )
0 commit comments