
Commit e6a8577

eldakms authored and mahilleb-msft committed
Fix for Examples/Image/Classification/ResNet/Python/TrainResNet_CIFAR10_Distributed.py
1 parent 7892679 commit e6a8577

1 file changed (+4, -5 lines)

Examples/Image/Classification/ResNet/Python/TrainResNet_CIFAR10_Distributed.py

+4 -5
@@ -11,7 +11,7 @@
 
 from cntk.utils import *
 from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error
-from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITE_SAMPLES
+from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs, INFINITE_SAMPLES, FULL_DATA_SWEEP
 from cntk import Trainer, cntk_py, distributed
 from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_as_time_constant_schedule, UnitType
 from _cntk_py import set_computation_network_trace_level
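
For orientation: FULL_DATA_SWEEP is the cntk.io sentinel meaning "serve each sample exactly once, then report exhaustion", and this import is what lets the corrected test-reader line in the last hunk of the diff resolve. A minimal sketch of the usage the import enables, with the variable names taken from that hunk:

    # FULL_DATA_SWEEP: one pass over the test set; next_minibatch() then
    # returns empty results, which terminates the evaluation loop below.
    test_reader = create_reader(test_data, mean, False, FULL_DATA_SWEEP)
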
@@ -56,7 +56,7 @@ def create_reader(map_file, mean_file, train, total_data_size, distributed_after
 
 
 # Train and evaluate the network.
-def train_and_evaluate(create_train_reader, create_test_reader, network_name, max_epochs, create_dist_learner, scale_up=False):
+def train_and_evaluate(create_train_reader, test_reader, network_name, max_epochs, create_dist_learner, scale_up=False):
 
     set_computation_network_trace_level(0)
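
The signature change swaps a reader factory (create_test_reader) for a ready-made reader object (test_reader): the caller now builds the test source once, and the function only consumes it. A sketch of the calling convention before and after, with the "before" call reconstructed from the old signature:

    # Before: hand over a factory; train_and_evaluate built the reader itself.
    # train_and_evaluate(create_train_reader, create_test_reader, network_name,
    #                    epochs, create_dist_learner, scale_up)

    # After: construct the reader up front and pass the object in.
    test_reader = create_reader(test_data, mean, False, FULL_DATA_SWEEP)
    train_and_evaluate(create_train_reader, test_reader, network_name, epochs,
                       create_dist_learner, scale_up)
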

@@ -85,7 +85,7 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     # ResNet110 samples-per-second is ~7x of single GPU, comparing to ~3x without scaling
     # up. However, bigger minibatch size on the same number of samples means less updates,
     # thus leads to higher training error. This is a trade-off of speed and accuracy
-    minibatch_size = 128 * (len(distributed.Communicator.num_workers()) if scale_up else 1)
+    minibatch_size = 128 * (distributed.Communicator.num_workers() if scale_up else 1)
 
     momentum_time_constant = -minibatch_size/np.log(0.9)
     l2_reg_weight = 0.0001
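
This hunk is the substantive bug fix: distributed.Communicator.num_workers() returns the integer count of workers, not a collection, so wrapping it in len() raised a TypeError whenever scale_up was true. The corrected arithmetic, spelled out:

    # num_workers() yields an int (e.g. 4 when running under MPI with 4 ranks),
    # so it multiplies directly; calling len(...) on that int would raise
    # "TypeError: object of type 'int' has no len()".
    workers = distributed.Communicator.num_workers()
    minibatch_size = 128 * (workers if scale_up else 1)  # e.g. 512 on 4 workers
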
@@ -135,7 +135,6 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     sample_count = 0
     minibatch_index = 0
 
-    test_reader=create_test_reader(epoch_size)
     while True:
         data = test_reader.next_minibatch(minibatch_size, input_map=input_map)
         if not data: break;
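
With the reader now passed in by the caller, constructing it inside the evaluation phase becomes redundant and is dropped; note the removed line also sized the test reader with epoch_size, the training epoch length, rather than a full sweep of the test set. The remaining loop simply drains the one-pass source:

    # next_minibatch() returns an empty result once the FULL_DATA_SWEEP source
    # is exhausted, so the loop exits after scoring every test sample once.
    while True:
        data = test_reader.next_minibatch(minibatch_size, input_map=input_map)
        if not data:
            break
        # ... accumulate the classification error for this minibatch ...
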
@@ -176,7 +175,7 @@ def train_and_evaluate(create_train_reader, create_test_reader, network_name, ma
     mean=os.path.join(data_path, 'CIFAR-10_mean.xml')
 
     create_train_reader=lambda data_size: create_reader(train_data, mean, True, data_size, distributed_after_samples)
-    test_reader=create_reader(test, mean, False, FULL_DATA_SWEEP)
+    test_reader=create_reader(test_data, mean, False, FULL_DATA_SWEEP)
 
     train_and_evaluate(create_train_reader, test_reader, network_name, epochs, create_dist_learner, scale_up)
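
The final hunk fixes two latent errors in one line: test was an undefined name (the test map file lives in test_data), and FULL_DATA_SWEEP only resolves thanks to the import added in the first hunk. The train_data/test_data definitions sit outside the diff context; a hypothetical reconstruction, with the map-file names assumed rather than taken from the diff:

    # Assumed setup mirroring the 'mean' line shown above; the actual
    # definitions are outside this diff, so the file names are illustrative.
    train_data = os.path.join(data_path, 'train_map.txt')
    test_data = os.path.join(data_path, 'test_map.txt')
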
