2
2
import fire
3
3
import cv2
4
4
import matplotlib .pyplot as plt
5
- from keras .layers import Input , Dense , LeakyReLU , Dropout
6
- from keras .models import Model , load_model , save_model
5
+ from keras .layers import Dense , LeakyReLU , Reshape , Conv2DTranspose , Conv2D , Dropout , Flatten
6
+ from keras .models import load_model , save_model , Sequential
7
7
from keras .datasets import mnist
8
+ from keras .optimizers import Adam
8
9
9
10
10
11
class GAN :
11
12
def __init__ (self ):
12
13
self .discriminator = None
13
14
self .generator = None
14
15
self .gan = None
15
- self .gan_input = (100 , )
16
- self .data_shape = (28 , 28 )
17
- self .generator_input = self .gan_input
18
- self .generator_output = (np .prod (self .data_shape ), )
16
+ self .gan_input = 100
17
+ self .data_shape = (28 , 28 , 1 )
19
18
self .X_train = None
20
19
self .X_test = None
21
20
self .batch_size = 32
@@ -24,51 +23,56 @@ def __init__(self):
24
23
self .test_count = 9
25
24
26
25
def create_discriminator (self ):
27
- input = Input (self .generator_output )
28
-
29
- x = Dense (1024 )(input )
30
- x = LeakyReLU (0.2 )(x )
31
- x = Dropout (0.3 )(x )
32
- x = Dense (512 )(x )
33
- x = LeakyReLU (0.2 )(x )
34
- x = Dropout (0.3 )(x )
35
- x = Dense (256 )(x )
36
- x = LeakyReLU (0.2 )(x )
37
-
38
- output = Dense (1 , activation = 'sigmoid' )(x )
39
-
40
- self .discriminator = Model (input = input , output = output )
41
- self .discriminator .compile (loss = 'binary_crossentropy' , optimizer = 'adam' )
26
+ model = Sequential ()
27
+ model .add (Conv2D (64 , (3 , 3 ), strides = (2 , 2 ), padding = 'same' , input_shape = self .data_shape ))
28
+ model .add (LeakyReLU (alpha = 0.2 ))
29
+ model .add (Dropout (0.4 ))
30
+ model .add (Conv2D (64 , (3 , 3 ), strides = (2 , 2 ), padding = 'same' ))
31
+ model .add (LeakyReLU (alpha = 0.2 ))
32
+ model .add (Dropout (0.4 ))
33
+ model .add (Flatten ())
34
+ model .add (Dense (1 , activation = 'sigmoid' ))
35
+ # compile model
36
+ opt = Adam (lr = 0.0002 , beta_1 = 0.5 )
37
+ model .compile (loss = 'binary_crossentropy' , optimizer = opt , metrics = ['accuracy' ])
38
+
39
+ self .discriminator = model
42
40
43
41
def create_generator (self ):
44
- input = Input (self .gan_input )
45
-
46
- x = Dense (256 )(input )
47
- x = LeakyReLU (0.2 )(x )
48
- x = Dense (512 )(x )
49
- x = LeakyReLU (0.2 )(x )
50
- x = Dense (1024 )(x )
51
- x = LeakyReLU (0.2 )(x )
52
-
53
- output = Dense (self .generator_output [0 ], activation = 'tanh' )(x )
54
-
55
- self .generator = Model (input = input , output = output )
56
- self .generator .compile (loss = 'binary_crossentropy' , optimizer = 'adam' )
42
+ model = Sequential ()
43
+ # foundation for 7x7 image
44
+ n_nodes = 128 * 7 * 7
45
+ model .add (Dense (n_nodes , input_dim = self .gan_input ))
46
+ model .add (LeakyReLU (alpha = 0.2 ))
47
+ model .add (Reshape ((7 , 7 , 128 )))
48
+ # upsample to 14x14
49
+ model .add (Conv2DTranspose (128 , (4 , 4 ), strides = (2 , 2 ), padding = 'same' ))
50
+ model .add (LeakyReLU (alpha = 0.2 ))
51
+ # upsample to 28x28
52
+ model .add (Conv2DTranspose (128 , (4 , 4 ), strides = (2 , 2 ), padding = 'same' ))
53
+ model .add (LeakyReLU (alpha = 0.2 ))
54
+ model .add (Conv2D (1 , (7 , 7 ), activation = 'sigmoid' , padding = 'same' ))
55
+
56
+ self .generator = model
57
57
58
58
def create_gan (self ):
59
- input = Input (self .gan_input )
60
-
59
+ # make weights in the discriminator not trainable
61
60
self .discriminator .trainable = False
62
- x = self .generator (input )
63
- output = self .discriminator (x )
64
-
65
- self .gan = Model (input = input , output = output )
66
- self .gan .compile (loss = 'binary_crossentropy' , optimizer = 'adam' )
61
+ # connect them
62
+ model = Sequential ()
63
+ # add generator
64
+ model .add (self .generator )
65
+ # add the discriminator
66
+ model .add (self .discriminator )
67
+ # compile model
68
+ opt = Adam (lr = 0.0002 , beta_1 = 0.5 )
69
+ model .compile (loss = 'binary_crossentropy' , optimizer = opt )
70
+
71
+ self .gan = model
67
72
68
73
def load_data (self ):
69
- (self .X_train , _ ), (self .X_test , _ ) = mnist .load_data ()
70
- self .X_train = (self .X_train .astype (np .float32 ) - 127.5 ) / 127.5
71
- print (self .X_train .shape )
74
+ (self .X_train , labels ), (self .X_test , _ ) = mnist .load_data ()
75
+ self .X_train = self .X_train .astype (np .float32 ) / 255.0
72
76
73
77
def train (self ):
74
78
self .load_data ()
@@ -77,19 +81,23 @@ def train(self):
77
81
self .create_gan ()
78
82
79
83
for i in range (self .epochs ):
80
- for k in range (int (self .X_train .shape [0 ]/ self .batch_size )):
84
+ for k in range (int (self .X_train .shape [0 ] / self .batch_size )):
81
85
noise = np .random .normal (0 , 1 , (self .batch_size , 100 ))
82
- minibatch_x = self .X_train [k * self .batch_size :(k + 1 ) * self .batch_size ]
83
- minibatch_x = np .reshape (minibatch_x , ( self . batch_size , 784 ) )
86
+ minibatch_x = self .X_train [k * self .batch_size :(k + 1 ) * self .batch_size ]
87
+ minibatch_x = np .expand_dims (minibatch_x , axis = - 1 )
84
88
minibatch_y = np .ones (self .batch_size ) - 0.01
85
89
generated_x = self .generator .predict (noise )
86
90
generated_y = np .zeros (self .batch_size )
87
91
92
+ minibatch_y = np .expand_dims (minibatch_y , axis = - 1 )
93
+ generated_y = np .expand_dims (generated_y , axis = - 1 )
94
+
88
95
self .discriminator .trainable = True
89
96
self .discriminator .train_on_batch (minibatch_x , minibatch_y )
90
97
self .discriminator .train_on_batch (generated_x , generated_y )
91
98
92
- noise = np .random .normal (0 , 1 , (self .batch_size , 100 , ))
99
+ # noise = self.generate_latent_points(100, self.batch_size)
100
+ noise = np .random .normal (0 , 1 , (self .batch_size , 100 ))
93
101
gan_y = np .ones (self .batch_size )
94
102
95
103
self .gan .train_on_batch (noise , gan_y )
@@ -98,31 +106,34 @@ def train(self):
98
106
self .sample_gan (i )
99
107
save_model (self .generator , self .generator_model_path )
100
108
101
- print ("Epoch: " , i + 1 )
109
+ print ("Epoch: " , i + 1 )
102
110
103
111
def sample_gan (self , epoch ):
104
112
noise = np .random .normal (0 , 1 , (1 , 100 ))
105
113
img = self .generator .predict (noise )
106
- img = np .reshape (img , (28 , 28 ))
107
- img = img * 255
114
+ img = np .squeeze (img , axis = 0 )
115
+ img = np .squeeze (img , axis = - 1 )
116
+ img = img * 255.0
108
117
cv2 .imwrite ('gan_generated/img_{}.png' .format (epoch ), img )
109
118
110
119
def plot_results (self , generated ):
111
120
fig = plt .figure (figsize = (28 , 28 ))
112
121
columns = np .sqrt (self .test_count )
113
122
rows = np .sqrt (self .test_count )
114
- for i in range (1 , columns * rows + 1 ):
123
+ for i in range (1 , int ( columns ) * int ( rows ) ):
115
124
fig .add_subplot (rows , columns , i )
116
- plt .imshow (generated [i ])
125
+ plt .imshow (generated [i ], cmap = 'gray_r' )
117
126
plt .show ()
118
127
119
128
def test (self ):
120
129
generator = load_model (self .generator_model_path )
121
130
generated = []
122
- for _ in range (self .test_count ):
123
- noise = np .random .normal (0 , 1 , (1 , 100 ))
124
- generated .append (generator .predict (noise ))
125
-
131
+ for i in range (self .test_count ):
132
+ noise = self .generate_latent_points (100 , self .batch_size )
133
+ img = generator .predict (noise )
134
+ img = np .squeeze (img , axis = 0 )
135
+ img = np .squeeze (img , axis = - 1 )
136
+ generated .append (img * 255.0 )
126
137
self .plot_results (generated )
127
138
128
139
0 commit comments