Skip to content

Commit 485f8c9

Browse files
author
Lukas Herman
committed
Add auto encoder model
1 parent d422163 commit 485f8c9

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

5_Auto_Encoder.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import tensorflow as tf
2+
from tensorflow.examples.tutorials.mnist import input_data
3+
import sys
4+
from os.path import join, splitext
5+
from layers import layers
6+
import config
7+
8+
train_dir = config.TRAIN_DIR
9+
test_dir = config.TEST_DIR
10+
n_trains = config.N_TRAINS
11+
batch_size = config.BATCH_SIZE
12+
13+
width = config.WIDTH
14+
height = config.HEIGHT
15+
channels = config.CHANNELS
16+
flat = config.FLAT
17+
n_classes = config.N_CLASSES
18+
19+
k = 1000
20+
l = 500
21+
m = 30
22+
n = l
23+
o = k
24+
25+
mnist = input_data.read_data_sets('data', one_hot=True)
26+
27+
def get_dict(train=True, batch=True):
28+
if train:
29+
if batch:
30+
batch_x, _ = mnist.train.next_batch(batch_size)
31+
return {x:batch_x}
32+
else:
33+
return {x:mnist.train.images}
34+
else:
35+
if batch:
36+
batch_x, _ = mnist.test.next_batch(batch_size)
37+
return {x:batch_x}
38+
else:
39+
return {x:mnist.test.images}
40+
41+
with tf.name_scope('InputLayer'):
42+
x = tf.placeholder(tf.float32, shape=[None, flat], name='x')
43+
44+
with tf.name_scope('NetworkModel'):
45+
with tf.name_scope('Encoder'):
46+
y1 = layers.ae_layer(x, flat, k)
47+
y2 = layers.ae_layer(y1, k, l)
48+
y3 = layers.ae_layer(y2, l, m)
49+
with tf.name_scope('Decoder'):
50+
y4 = layers.ae_layer(y3, m, n)
51+
y5 = layers.ae_layer(y4, n, o)
52+
y = layers.ae_layer(y5, o, flat)
53+
54+
with tf.name_scope('Train'):
55+
loss = tf.reduce_mean(tf.pow(y-x, 2), name='loss')
56+
train = tf.train.AdamOptimizer().minimize(loss)
57+
58+
with tf.name_scope('Accuracy'):
59+
accuracy = 1 - loss
60+
61+
# Add image summaries
62+
x_img = tf.reshape(x, [-1, height, width, channels]) # input
63+
y_img = tf.reshape(y, [-1, height, width, channels]) # reconstructed
64+
tf.summary.image('InputImage', x_img)
65+
tf.summary.image('OutputImage', y_img)
66+
67+
# Add scalar summaries
68+
tf.summary.scalar('Loss', loss)
69+
tf.summary.scalar('Accuracy', accuracy)
70+
71+
init_op = tf.global_variables_initializer()
72+
summary_op = tf.summary.merge_all()
73+
74+
with tf.Session() as sess:
75+
# Open protocol for writing files
76+
train_writer = tf.summary.FileWriter(train_dir)
77+
train_writer.add_graph(sess.graph)
78+
test_writer = tf.summary.FileWriter(test_dir)
79+
80+
sess.run(init_op)
81+
for n_train in range(1, n_trains+1):
82+
print("Training {}...".format(n_train))
83+
_ = sess.run([train], feed_dict=get_dict(train=True, batch=True))
84+
if n_train % 100 == 0:
85+
# Train
86+
s = sess.run(summary_op, feed_dict=get_dict(train=True, batch=False))
87+
train_writer.add_summary(s, n_train)
88+
# Test
89+
s = sess.run(summary_op, feed_dict=get_dict(train=False, batch=False))
90+
test_writer.add_summary(s, n_train)

__pycache__/config.cpython-36.pyc

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)