 
 from cntk.utils import *
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error
-from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITE_SAMPLES
+from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITE_SAMPLES, FULL_DATA_SWEEP
 from cntk import Trainer, cntk_py, distributed
 from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_as_time_constant_schedule, UnitType
 from _cntk_py import set_computation_network_trace_level
|
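Note on the added import: FULL_DATA_SWEEP is the epoch_size value that tells a MinibatchSource to serve its data exactly once and then return empty minibatches, which is what the evaluation loop further down relies on. A minimal sketch of the intended usage, assuming create_reader (the function referenced at line 56 below) forwards total_data_size to MinibatchSource's epoch_size; the variable names come from the last hunk:

    # Sketch, not part of the commit: the test reader makes a single pass.
    test_reader = create_reader(test_data, mean, False, FULL_DATA_SWEEP)
    data = test_reader.next_minibatch(128, input_map=input_map)
    # Once the sweep is complete, next_minibatch returns an empty dict,
    # so `if not data: break` ends the evaluation loop.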
@@ -56,7 +56,7 @@ def create_reader(map_file, mean_file, train, total_data_size, distributed_after
 
 
 # Train and evaluate the network.
-def train_and_evaluate(create_train_reader, create_test_reader, network_name, max_epochs, create_dist_learner, scale_up=False):
+def train_and_evaluate(create_train_reader, test_reader, network_name, max_epochs, create_dist_learner, scale_up=False):
 
     set_computation_network_trace_level(0)
 
|
@@ -85,7 +85,7 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     # ResNet110 samples-per-second is ~7x that of a single GPU, compared to ~3x without scaling
     # up. However, a bigger minibatch size over the same number of samples means fewer updates,
     # which leads to higher training error. This is a trade-off between speed and accuracy.
-    minibatch_size = 128 * (len(distributed.Communicator.num_workers()) if scale_up else 1)
+    minibatch_size = 128 * (distributed.Communicator.num_workers() if scale_up else 1)
 
     momentum_time_constant = -minibatch_size/np.log(0.9)
     l2_reg_weight = 0.0001
|
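The fixed line matters because distributed.Communicator.num_workers() returns a plain integer in the CNTK v2 Python API, so the old len(...) call raised a TypeError at runtime. A worked example of the scale-up arithmetic, assuming a hypothetical run with 4 MPI workers:

    # scale_up=True with 4 workers (hypothetical):
    minibatch_size = 128 * 4                     # = 512 samples per global minibatch
    momentum_time_constant = -512 / np.log(0.9)  # ~4860 samples, vs ~1215 at size 128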
@@ -135,7 +135,6 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     sample_count = 0
     minibatch_index = 0
 
-    test_reader=create_test_reader(epoch_size)
     while True:
         data = test_reader.next_minibatch(minibatch_size, input_map=input_map)
         if not data: break
|
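With the test reader now built once by the caller (see the final hunk), this loop terminates when the FULL_DATA_SWEEP source is exhausted and next_minibatch returns an empty dict. A sketch of the evaluation pattern around these lines, assuming trainer, input_map, and a hypothetical label_var stream handle are set up earlier in the script:

    metric_numer = 0.0
    metric_denom = 0
    while True:
        data = test_reader.next_minibatch(minibatch_size, input_map=input_map)
        if not data:
            break                                  # single sweep finished
        num = data[label_var].num_samples          # samples in this minibatch
        metric_numer += trainer.test_minibatch(data) * num
        metric_denom += num
    print("Final error: {:.2f}%".format(100.0 * metric_numer / metric_denom))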
@@ -176,7 +175,7 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     mean=os.path.join(data_path, 'CIFAR-10_mean.xml')
 
     create_train_reader=lambda data_size: create_reader(train_data, mean, True, data_size, distributed_after_samples)
-    test_reader=create_reader(test, mean, False, FULL_DATA_SWEEP)
+    test_reader=create_reader(test_data, mean, False, FULL_DATA_SWEEP)
 
     train_and_evaluate(create_train_reader, test_reader, network_name, epochs, create_dist_learner, scale_up)
 
|
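A design note on the resulting asymmetry: the training reader stays behind a data_size-taking factory because the total number of samples to serve is only known inside train_and_evaluate (it depends on max_epochs and the epoch size), whereas the test reader can be built once up front since FULL_DATA_SWEEP already caps it at a single pass. Presumably the factory is consumed along these lines (epoch_size is an assumption, not shown in this diff):

    # Inside train_and_evaluate (sketch):
    train_reader = create_train_reader(max_epochs * epoch_size)  # total training-sample budget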