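# Conditional GAN (CGAN) for MNIST: both the generator and the discriminator
# are conditioned on the digit label through a learned embedding. The class is
# exposed as a command-line tool via python-fire (see the bottom of the file).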
import os

import numpy as np
import fire
import cv2
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
from torchvision.transforms import transforms
from torchvision.datasets import MNIST
from torch.optim import Adam
from multiprocessing import set_start_method

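# Allocate new tensors on the GPU by default; the script assumes a
# CUDA-capable device is available.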
torch.set_default_tensor_type(torch.cuda.FloatTensor)

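# CUDA cannot be re-initialised in fork()ed subprocesses, so DataLoader
# workers have to use the 'spawn' start method.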
try:
    set_start_method('spawn')
except RuntimeError:
    pass


class CGAN:
    def __init__(self):
        self.discriminator = Discriminator()
        self.generator = Generator()
        self.gan_input = 100  # dimensionality of the latent noise vector
        self.batch_size = 32
        self.test_count = 9  # samples drawn by test(); should be a perfect square for the plot grid
        self.classes = 10

        self.train_mnist_dataloader = None
        self.test_mnist_dataloader = None
        self.mnist_epochs = 50
        self.discriminator_opt = Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.generator_opt = Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.loss = nn.BCELoss()

        self.generator_model_path = 'models/cgan.hdf5'

    def load_data(self):
        transform = transforms.Compose(
            [transforms.ToTensor()]
        )

        train_set = MNIST(root='./data/mnist', train=True, download=True, transform=transform)
        self.train_mnist_dataloader = torch.utils.data.DataLoader(train_set,
                                                                  batch_size=self.batch_size,
                                                                  shuffle=True, num_workers=1)

        test_set = MNIST(root='./data/mnist', train=False, download=True, transform=transform)
        self.test_mnist_dataloader = torch.utils.data.DataLoader(test_set, batch_size=self.batch_size,
                                                                 shuffle=False, num_workers=1)

    def train(self):
        self.load_data()
        for epoch in range(self.mnist_epochs):

            for i, data in enumerate(self.train_mnist_dataloader, 0):
                real, real_labels = data
                # the DataLoader yields CPU tensors; move them next to the models
                real = real.cuda()
                real_labels = real_labels.cuda()
                batch_size = real.size(0)  # the last batch may be smaller than self.batch_size

                # one-sided label smoothing: real targets are 0.9 instead of 1.0
                real_target = torch.ones(size=(batch_size, 1), requires_grad=False) - 0.1
                fake_target = torch.zeros(size=(batch_size, 1), requires_grad=False)

                # generator update: push the discriminator's output on fakes towards "real"
                self.generator_opt.zero_grad()
                noise = torch.randn(size=(batch_size, self.gan_input))
                noise_labels = torch.randint(0, self.classes, size=(batch_size, 1))
                fake = self.generator(noise, noise_labels)
                g_loss = self.loss(self.discriminator(fake, noise_labels), real_target)
                g_loss.backward()
                self.generator_opt.step()

                # discriminator update: detach the fakes so no gradient flows into the generator
                self.discriminator_opt.zero_grad()
                real_loss = self.loss(self.discriminator(real, real_labels), real_target)
                fake_loss = self.loss(self.discriminator(fake.detach(), noise_labels), fake_target)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                self.discriminator_opt.step()

                print("Generator Loss: ", g_loss.item())
                print("Discriminator Loss: ", d_loss.item())

            # sample one image and checkpoint the generator after every epoch
            self.sample_and_save_gan(epoch)
            print("Epoch: ", epoch + 1)

        print('Finished Training')

    def sample_and_save_gan(self, epoch):
        noise = torch.randn(size=(1, self.gan_input))
        # randint's upper bound is exclusive, so pass self.classes to cover every digit
        labels = torch.randint(self.classes, size=(1, 1))
        img = self.generator(noise, labels)
        img = img.cpu().detach().numpy()
        img = np.squeeze(img, axis=0)  # drop the batch dimension
        img = np.squeeze(img, axis=0)  # drop the channel dimension
        img = img * 255.0
        os.makedirs('gan_generated', exist_ok=True)
        cv2.imwrite('gan_generated/img_{}.png'.format(epoch), img)
        os.makedirs(os.path.dirname(self.generator_model_path), exist_ok=True)
        torch.save(self.generator.state_dict(), self.generator_model_path)

    def plot_results(self, generated):
        fig = plt.figure(figsize=(28, 28))
        # lay the samples out on a square grid; add_subplot needs integer arguments
        columns = int(np.sqrt(self.test_count))
        rows = int(np.sqrt(self.test_count))
        generated = generated.cpu().detach().numpy()
        generated = np.squeeze(generated, axis=1)
        generated = generated * 255.0
        for i in range(1, columns * rows + 1):
            fig.add_subplot(rows, columns, i)
            plt.imshow(generated[i - 1], cmap='gray_r')
        plt.show()

    def test(self):
        self.load_generator()
        noise = torch.randn(size=(self.test_count, self.gan_input))
        labels = torch.randint(self.classes, size=(self.test_count, 1))
        print(labels)
        generated = self.generator(noise, labels)
        self.plot_results(generated)

    def load_generator(self):
        self.generator = Generator()
        self.generator.load_state_dict(torch.load(self.generator_model_path))
        self.generator.eval()


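# The discriminator embeds the class label, projects it to a 28x28 map and
# stacks it onto the image as a second input channel; two stride-2
# convolutions then downsample 28x28 -> 14x14 -> 7x7 before a sigmoid
# real/fake score.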
class Discriminator(nn.Module):
    def __init__(self, n_classes=10):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
        self.drop_out1 = nn.Dropout(0.4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
        self.drop_out2 = nn.Dropout(0.4)
        self.fc1 = nn.Linear(7 * 7 * 64, 1)
        self.sigmoid = nn.Sigmoid()

        self.embedding1 = nn.Embedding(n_classes, 50)
        self.fc1_label = nn.Linear(50, 784)

    def forward(self, x, y):
        y = self.embedding1(y)
        y = self.fc1_label(y)
        y = y.view(-1, 1, 28, 28)

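        # stack the projected label map onto the image as a second channel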
        x = torch.cat([x, y], dim=1)

        x = self.conv1(x)
        x = self.leaky_relu1(x)
        x = self.drop_out1(x)
        x = self.conv2(x)
        x = self.leaky_relu2(x)
        x = self.drop_out2(x)
        x = torch.flatten(x, start_dim=1)
        x = self.sigmoid(self.fc1(x))

        return x


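# The generator projects the 100-dim noise vector to a 128x7x7 feature map,
# appends a 1x7x7 label map as an extra channel, and upsamples
# 7x7 -> 14x14 -> 28x28 with two stride-2 transposed convolutions.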
class Generator(nn.Module):
    def __init__(self, n_classes=10):
        super(Generator, self).__init__()
        self.fc1_out = 128 * 7 * 7
        self.fc1 = nn.Linear(100, self.fc1_out)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
        self.conv1 = nn.ConvTranspose2d(129, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
        self.conv2 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu3 = nn.LeakyReLU(negative_slope=0.2)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

        self.embedding1 = nn.Embedding(n_classes, 50)
        self.fc1_label = nn.Linear(50, 49)

    def forward(self, x, y):
        x = self.fc1(x)
        x = self.leaky_relu1(x)
        x = x.view(-1, 128, 7, 7)

        y = self.embedding1(y)
        y = self.fc1_label(y)
        y = y.view(-1, 1, 7, 7)

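        # append the label map as an extra (129th) channel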
        x = torch.cat([x, y], dim=1)

        x = self.conv1(x)
        x = self.leaky_relu2(x)
        x = self.conv2(x)
        x = self.leaky_relu3(x)
        x = self.sigmoid(self.conv3(x))

        return x


if __name__ == "__main__":
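    # python-fire turns the public methods into subcommands, e.g.
    # (assuming this file is saved as cgan.py):
    #   python cgan.py train
    #   python cgan.py test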
    fire.Fire(CGAN)