Skip to content

Commit 5b70430

Browse files
committed
Minor changes
1 parent 8316a28 commit 5b70430

File tree

4 files changed

+29
-164
lines changed

4 files changed

+29
-164
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
__pycache__
22
.ipynb_checkpoints
3+
logdir

evaluate.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,24 @@
2929
sess = tf.Session(config=config)
3030
tf.keras.backend.set_session(sess)
3131

32+
_available_datasets = [
33+
"mnist",
34+
"cifar10",
35+
]
3236

33-
def main(args):
34-
35-
_available_datasets = [
36-
"mnist",
37-
"cifar10",
38-
]
37+
_available_optimizers = {
38+
"rmsprop": tf.train.RMSPropOptimizer,
39+
"adam": tf.train.AdamOptimizer,
40+
"sgd": tf.train.GradientDescentOptimizer,
41+
}
3942

40-
if args.name not in _available_datasets:
43+
def main(args):
44+
if args.dataset not in _available_datasets:
4145
raise NotImplementedError
4246

4347
dataset = build_dataset(
4448
name=args.dataset,
45-
shape=[args.height,args.width],
49+
shape=[args.height, args.width],
4650
)
4751

4852
model = build_mobilenetv3(
@@ -52,11 +56,6 @@ def main(args):
5256
width_multiplier=args.width_multiplier,
5357
)
5458

55-
_available_optimizers = {
56-
"rmsprop": tf.train.RMSPropOptimizer,
57-
"adam": tf.train.AdamOptimizer,
58-
"sgd": tf.train.GradientDescentOptimizer,
59-
}
6059

6160
if args.optimizer not in _available_optimizers:
6261
raise NotImplementedError
@@ -85,11 +84,11 @@ def main(args):
8584
# Input
8685
parser.add_argument("--height", type=int, default=128)
8786
parser.add_argument("--width", type=int, default=128)
88-
parser.add_argument("--dataset", type=str, default="mnist")
87+
parser.add_argument("--dataset", type=str, default="mnist", choices=_available_datasets)
8988

9089
# Optimizer
9190
parser.add_argument("--lr", type=float, default=0.01)
92-
parser.add_argument("--optimizer", type=str, default="rmsprop", choices=["sgd", "adam", "rmsprop"])
91+
parser.add_argument("--optimizer", type=str, default="rmsprop", choices=_available_optimizers.keys())
9392

9493
# Training & validation
9594
parser.add_argument("--valid_batch_size", type=int, default=256)

mnist_mobilenetv3_small.ipynb

-133
This file was deleted.

train.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,24 @@
2929
sess = tf.Session(config=config)
3030
tf.keras.backend.set_session(sess)
3131

32+
_available_datasets = [
33+
"mnist",
34+
"cifar10",
35+
]
3236

33-
def main(args):
34-
35-
_available_datasets = [
36-
"mnist",
37-
"cifar10",
38-
]
37+
_available_optimizers = {
38+
"rmsprop": tf.train.RMSPropOptimizer,
39+
"adam": tf.train.AdamOptimizer,
40+
"sgd": tf.train.GradientDescentOptimizer,
41+
}
3942

40-
if args.name not in _available_datasets:
43+
def main(args):
44+
if args.dataset not in _available_datasets:
4145
raise NotImplementedError
4246

4347
dataset = build_dataset(
4448
name=args.dataset,
45-
shape=[args.height,args.width],
49+
shape=(args.height, args.width),
4650
train_batch_size=args.train_batch_size,
4751
valid_batch_size=args.valid_batch_size
4852
)
@@ -55,12 +59,6 @@ def main(args):
5559
l2_reg=args.l2_reg,
5660
)
5761

58-
_available_optimizers = {
59-
"rmsprop": tf.train.RMSPropOptimizer,
60-
"adam": tf.train.AdamOptimizer,
61-
"sgd": tf.train.GradientDescentOptimizer,
62-
}
63-
6462
if args.optimizer not in _available_optimizers:
6563
raise NotImplementedError
6664

@@ -96,11 +94,11 @@ def main(args):
9694
# Input
9795
parser.add_argument("--height", type=int, default=128)
9896
parser.add_argument("--width", type=int, default=128)
99-
parser.add_argument("--dataset", type=str, default="mnist")
97+
parser.add_argument("--dataset", type=str, default="mnist", choices=_available_datasets)
10098

10199
# Optimizer
102100
parser.add_argument("--lr", type=float, default=0.01)
103-
parser.add_argument("--optimizer", type=str, default="rmsprop", choices=["sgd", "adam", "rmsprop"])
101+
parser.add_argument("--optimizer", type=str, default="rmsprop", choices=_available_optimizers.keys())
104102
parser.add_argument("--l2_reg", type=float, default=1e-5)
105103

106104
# Training & validation

0 commit comments

Comments
 (0)