Commit e13ac33
Adding the implementation of cGANs
1 parent 7745670 commit e13ac33

3 files changed: +203 −3 lines

conditional_gan.py

+199
@@ -0,0 +1,199 @@
import numpy as np
import fire
import cv2
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
from torchvision.transforms import transforms
from torchvision.datasets import MNIST
from torch.optim import Adam
from multiprocessing import set_start_method

# New tensors default to CUDA floats, so a CUDA-capable GPU is required.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# DataLoader workers must be spawned rather than forked once CUDA is in use.
try:
    set_start_method('spawn')
except RuntimeError:
    pass


class CGAN:
    def __init__(self):
        self.discriminator = Discriminator()
        self.generator = Generator()
        self.gan = None
        self.gan_input = 100      # size of the latent (noise) vector
        self.batch_size = 32
        self.test_count = 9       # number of images plotted in test()
        self.classes = 10         # MNIST digit classes

        self.train_mnist_dataloader = None
        self.test_mnist_dataloader = None
        self.mnist_epochs = 50
        self.discriminator_opt = Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.generator_opt = Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.loss = nn.BCELoss()

        self.generator_model_path = 'models/cgan.hdf5'

    def load_data(self):
        transform = transforms.Compose(
            [transforms.ToTensor()]
        )

        train_set = MNIST(root='./data/mnist', train=True, download=True, transform=transform)
        self.train_mnist_dataloader = torch.utils.data.DataLoader(train_set,
                                                                   batch_size=self.batch_size,
                                                                   shuffle=True, num_workers=1)

        test_set = MNIST(root='./data/mnist', train=False, download=True, transform=transform)
        self.test_mnist_dataloader = torch.utils.data.DataLoader(test_set, batch_size=self.batch_size,
                                                                 shuffle=False, num_workers=1)

    def train(self):
        self.load_data()
        for epoch in range(self.mnist_epochs):

            for i, data in enumerate(self.train_mnist_dataloader, 0):
                real, real_labels = data
                # One-sided label smoothing: real targets are 0.9 instead of 1.0.
                real_target = torch.ones(size=(self.batch_size, 1), requires_grad=False) - 0.1
                fake_target = torch.zeros(size=(self.batch_size, 1), requires_grad=False)

                # generator update
                self.generator_opt.zero_grad()
                noise = torch.tensor(np.random.normal(0, 1, (self.batch_size, 100)), dtype=torch.float)
                noise_labels = torch.tensor(np.random.randint(0, self.classes, size=(self.batch_size, 1)),
                                            dtype=torch.long)
                fake = self.generator(noise, noise_labels)
                g_loss = self.loss(self.discriminator(fake, noise_labels), real_target)
                g_loss.backward()
                self.generator_opt.step()

                # discriminator update
                self.discriminator_opt.zero_grad()
                real_loss = self.loss(self.discriminator(real.cuda().detach(), real_labels.cuda()), real_target)
                fake_loss = self.loss(self.discriminator(fake.detach(), noise_labels), fake_target)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                self.discriminator_opt.step()

            print("Generator Loss: ", g_loss.item())
            print("Discriminator Loss: ", d_loss.item())

            # save a sample image and the generator weights after every epoch
            self.sample_and_save_gan(epoch)
            print("Epoch: ", epoch + 1)

        print('Finished Training')

    def sample_and_save_gan(self, epoch):
        noise = torch.randn(size=(1, self.gan_input))
        labels = torch.randint(self.classes, size=(1, 1))
        img = self.generator(noise, labels)
        img = img.cpu().detach().numpy()
        img = np.squeeze(img, axis=0)
        img = np.squeeze(img, axis=0)
        img = img * 255.0
        print(img)
        cv2.imwrite('gan_generated/img_{}.png'.format(epoch), img)
        torch.save(self.generator.state_dict(), self.generator_model_path)

    def plot_results(self, generated):
        fig = plt.figure(figsize=(28, 28))
        columns = int(np.sqrt(self.test_count))
        rows = int(np.sqrt(self.test_count))
        generated = generated.cpu().detach().numpy()
        generated = np.squeeze(generated, axis=1)
        generated = generated * 255.0
        for i in range(1, columns * rows + 1):
            fig.add_subplot(rows, columns, i)
            plt.imshow(generated[i - 1], cmap='gray_r')
        plt.show()

    def test(self):
        self.load_generator()
        noise = torch.randn(size=(self.test_count, self.gan_input))
        labels = torch.randint(self.classes, size=(self.test_count, 1))
        print(labels)
        generated = self.generator(noise, labels)
        self.plot_results(generated)

    def load_generator(self):
        self.generator = Generator()
        self.generator.load_state_dict(torch.load(self.generator_model_path))
        self.generator.eval()


class Discriminator(nn.Module):
    def __init__(self, n_classes=10):
        super(Discriminator, self).__init__()
        # Image (1 channel) plus label map (1 channel) are stacked, hence in_channels=2.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
        self.drop_out1 = nn.Dropout(0.4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
        self.drop_out2 = nn.Dropout(0.4)
        self.fc1 = nn.Linear(7 * 7 * 64, 1)
        self.sigmoid = nn.Sigmoid()

        # Class label -> 50-dim embedding -> 28x28 "label image" channel.
        self.embedding1 = nn.Embedding(n_classes, 50)
        self.fc1_label = nn.Linear(50, 784)

    def forward(self, x, y):
        y = self.embedding1(y)
        y = self.fc1_label(y)
        y = y.view(-1, 1, 28, 28)

        x = torch.cat([x, y], dim=1)

        x = self.conv1(x)
        x = self.leaky_relu1(x)
        x = self.drop_out1(x)
        x = self.conv2(x)
        x = self.leaky_relu2(x)
        x = self.drop_out2(x)
        x = torch.flatten(x, start_dim=1)
        x = self.sigmoid(self.fc1(x))

        return x


class Generator(nn.Module):
    def __init__(self, n_classes=10):
        super(Generator, self).__init__()
        self.fc1_out = 128 * 7 * 7
        self.fc1 = nn.Linear(100, self.fc1_out)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
        # 128 noise feature maps plus 1 label map = 129 input channels.
        self.conv1 = nn.ConvTranspose2d(129, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
        self.conv2 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu3 = nn.LeakyReLU(negative_slope=0.2)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

        # Class label -> 50-dim embedding -> 7x7 map concatenated with the noise projection.
        self.embedding1 = nn.Embedding(n_classes, 50)
        self.fc1_label = nn.Linear(50, 49)

    def forward(self, x, y):
        x = self.fc1(x)
        x = self.leaky_relu1(x)
        x = x.view(-1, 128, 7, 7)

        y = self.embedding1(y)
        y = self.fc1_label(y)
        y = y.view(-1, 1, 7, 7)

        x = torch.cat([x, y], dim=1)

        # Two transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
        x = self.conv1(x)
        x = self.leaky_relu2(x)
        x = self.conv2(x)
        x = self.leaky_relu3(x)
        x = self.sigmoid(self.conv3(x))

        return x


if __name__ == "__main__":
    fire.Fire(CGAN)
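Because the class is exposed through python-fire, train() and test() can be invoked as CLI commands (python conditional_gan.py train, python conditional_gan.py test). Below is a minimal sketch of the same workflow driven from Python, assuming a CUDA GPU is available and that the gan_generated/ and models/ output directories already exist:

from conditional_gan import CGAN

cgan = CGAN()   # builds the generator and discriminator on the GPU
cgan.train()    # 50 epochs on MNIST; saves a sample image and the generator weights each epoch
cgan.test()     # reloads models/cgan.hdf5 and plots a 3x3 grid of generated digits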

gan.py

+4 −3
@@ -36,6 +36,8 @@ def create_discriminator(self):
         opt = Adam(lr=0.0002, beta_1=0.5)
         model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
 
+        model.summary()
+
         self.discriminator = model
 
     def create_generator(self):
@@ -126,11 +128,10 @@ def plot_results(self, generated):
         plt.show()
 
     def test(self):
-        generator = load_model(self.generator_model_path)
         generated = []
         for i in range(self.test_count):
-            noise = self.generate_latent_points(100, self.batch_size)
-            img = generator.predict(noise)
+            noise = np.random.normal(0, 1, (1, 100))
+            img = self.generator.predict(noise)
             img = np.squeeze(img, axis=0)
             img = np.squeeze(img, axis=-1)
             generated.append(img * 255.0)
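For reference, a sketch of how the updated test() plausibly reads as a whole after this hunk; the enclosing class, self.generator, and self.test_count are taken from the diff context, and anything beyond the shown lines is an assumption rather than the file's actual contents:

def test(self):
    # Sample one 100-dim latent vector per image and decode it with the
    # in-memory Keras generator instead of reloading the model from disk.
    generated = []
    for i in range(self.test_count):
        noise = np.random.normal(0, 1, (1, 100))
        img = self.generator.predict(noise)   # assumed output shape (1, 28, 28, 1)
        img = np.squeeze(img, axis=0)         # drop the batch dimension
        img = np.squeeze(img, axis=-1)        # drop the trailing channel dimension
        generated.append(img * 255.0)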

models/cgan.hdf5

4.46 MB
Binary file not shown.
