@@ -13,7 +13,7 @@ class ImgConvNets(object):
13
13
are multiplications of 4.
14
14
"""
15
15
def __init__ (self , model , model_scope , img_height , img_width , class_count , keep_prob = 0.5 ,
16
- learning_rate = 1e-4 , lr_adaptive = True , batch_size = 32 , max_steps = 20000 ):
16
+ learning_rate = 1e-4 , lr_adaptive = True , batch_size = 32 , num_epoches = 100 ):
17
17
"""
18
18
Args:
19
19
model: Specify which model to use.
@@ -28,7 +28,7 @@ def __init__(self, model, model_scope, img_height, img_width, class_count, keep_
28
28
accuracy. If True, the given learning_rate will be ignored.
29
29
batch_size: optional. The number of samples to be used in one step of the
30
30
optimization process.
31
- max_steps : optional. The max number of iterative steps in the training process.
31
+ num_epoches : optional. The number of epoches for the training process.
32
32
"""
33
33
assert model == 'BASIC' or model == 'DCNN' or model == 'STCNN'
34
34
@@ -41,7 +41,7 @@ def __init__(self, model, model_scope, img_height, img_width, class_count, keep_
41
41
self .learning_rate = learning_rate
42
42
self .lr_adaptive = lr_adaptive
43
43
self .batch_size = batch_size
44
- self .max_steps = max_steps
44
+ self .num_epoches = num_epoches
45
45
46
46
def train (self , img_features , true_labels , train_dir , result_file ):
47
47
"""
@@ -103,43 +103,44 @@ def train(self, img_features, true_labels, train_dir, result_file):
103
103
104
104
save_file = os .path .join (train_dir , result_file )
105
105
106
- disp_step = self ._get_epoch_step_count (train_set .shape [0 ])
107
- for step in range (self .max_steps ):
108
- # Read a batch of images and labels
109
- batch_data = self ._get_next_batch (train_set , step * self .batch_size )
110
- images_feed , labels_feed = \
111
- batch_data [:, :cols ], batch_data [:, cols :].reshape (- 1 )
112
-
106
+ epoch_steps = math .ceil (train_set .shape [0 ] / self .batch_size )
107
+ for epoch in range (1 , self .num_epoches + 1 ):
113
108
lr_feed = self ._get_learning_rate (last_accu )
114
- # Run one step of the model. The return values are the activations
115
- # from the `train_op` (which is discarded) and the `loss` Op.
116
- _ , loss_val , accu_val = sess .run ([train_op , loss , accuracy ],
117
- feed_dict = {images_placeholder : images_feed ,
118
- labels_placeholder : labels_feed ,
119
- learning_rate_placeholder : lr_feed ,
120
- keep_prob_placeholder : self .keep_prob })
121
-
122
- # Check to make sure the loss is decreasing
123
- loss_list .append (loss_val )
124
- accu_list .append (accu_val )
125
- if (step % disp_step == 0 ) or (step == self .max_steps - 1 ):
126
- mean_accu = sum (accu_list )* 100 / len (accu_list )
127
- if mean_accu >= 99.68 and mean_accu > last_accu :
128
- saver .save (sess , save_file , global_step = step )
129
- elif step == self .max_steps - 1 :
130
- saver .save (sess , save_file )
131
-
132
- print ("Step {:6d}: learning_rate used = {:.6f}, average loss = {:8.4f}, "
133
- "and training accuracy min = {:6.2f}%, mean = {:6.2f}%, "
134
- "max = {:6.2f}%" .format (step , lr_feed ,
135
- sum (loss_list )/ len (loss_list ),
136
- min (accu_list )* 100 , mean_accu ,
137
- max (accu_list )* 100 ))
138
- if mean_accu >= 99.99 : break
139
-
140
- loss_list = []
141
- accu_list = []
142
- last_accu = mean_accu
109
+ for step in range (epoch_steps ):
110
+ # Read a batch of images and labels
111
+ batch_data = self ._get_next_batch (train_set , step * self .batch_size )
112
+ images_feed , labels_feed = \
113
+ batch_data [:, :cols ], batch_data [:, cols :].reshape (- 1 )
114
+
115
+ # Run one step of the model. The return values are the activations
116
+ # from the `train_op` (which is discarded) and the `loss` Op.
117
+ _ , loss_val , accu_val = sess .run ([train_op , loss , accuracy ],
118
+ feed_dict = {images_placeholder : images_feed ,
119
+ labels_placeholder : labels_feed ,
120
+ learning_rate_placeholder : lr_feed ,
121
+ keep_prob_placeholder : self .keep_prob })
122
+
123
+ # Check to make sure the loss is decreasing
124
+ loss_list .append (loss_val )
125
+ accu_list .append (accu_val )
126
+
127
+ mean_accu = sum (accu_list )* 100 / len (accu_list )
128
+ if mean_accu >= 99.68 and mean_accu > last_accu :
129
+ saver .save (sess , save_file , global_step = epoch )
130
+ elif epoch == self .num_epoches - 1 :
131
+ saver .save (sess , save_file )
132
+
133
+ print ("Epoch {:3d} completed: learning_rate used = {:.6f}, average loss = {:8.4f}, "
134
+ "and training accuracy min = {:6.2f}%, mean = {:6.2f}%, "
135
+ "max = {:6.2f}%" .format (epoch , lr_feed ,
136
+ sum (loss_list )/ len (loss_list ),
137
+ min (accu_list )* 100 , mean_accu ,
138
+ max (accu_list )* 100 ))
139
+ if mean_accu >= 99.99 : break
140
+
141
+ loss_list = []
142
+ accu_list = []
143
+ last_accu = mean_accu
143
144
144
145
def _build_inference_graph_stcnn (self , images , keep_prob ):
145
146
"""
@@ -393,14 +394,6 @@ def _build_training_graph(self, logits, labels, learning_rate):
393
394
394
395
return train_op , loss , accuracy
395
396
396
- def _get_epoch_step_count (self , train_set_size ):
397
- if self .max_steps > 10000 :
398
- epoch_step = math .ceil (train_set_size / (self .batch_size * 1000.0 )) * 1000
399
- else :
400
- epoch_step = math .ceil (train_set_size / (self .batch_size * 100.0 )) * 100
401
-
402
- return epoch_step
403
-
404
397
def _get_next_batch (self , data_set , start_index ):
405
398
cnt = data_set .shape [0 ]
406
399
0 commit comments