Skip to content

Commit 39fc7c4

Browse files
committed
Add a preliminary draft implementation of Progressive Neural Networks, without the core lateral-connection logic
1 parent 4ef5f68 commit 39fc7c4

File tree

2 files changed

+231
-3
lines changed

2 files changed

+231
-3
lines changed

naas.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self):
2727
self.transforms_post = None
2828

2929
# device
30-
self.device = "cpu"
30+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
3131

3232
# images
3333
self.content = None
@@ -36,7 +36,7 @@ def __init__(self):
3636

3737
# image characteristics
3838
self.images_dir = 'data/images/'
39-
self.images_size = (512, 512)
39+
self.images_size = (256, 256)
4040
self.images_shape = (3, *self.images_size)
4141

4242
# training
@@ -119,7 +119,7 @@ def closure():
119119
total_loss.backward()
120120

121121
print("Iteration: ", i + 1, "\tLoss: ", total_loss)
122-
img = self.transforms_post(self.opt.clone())
122+
img = self.transforms_post(self.opt.clone().cpu())
123123
img.save(self.images_dir + 'opt_{}.jpg'.format(i))
124124

125125
return total_loss

pnn.py

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import torch.nn as nn
import torch
import fire
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import transforms
from multiprocessing import set_start_method

# Default new tensors to GPU storage only when CUDA is actually present.
# The original unconditional call raised at import time on CPU-only machines.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

# DataLoader worker processes must be started with 'spawn' once CUDA is
# initialised in the parent; ignore the error if a start method is already set.
try:
    set_start_method('spawn')
except RuntimeError:
    pass
16+
17+
18+
class PNN:
    """Driver for the Progressive Neural Network experiments.

    Owns the data loaders, the two baseline columns (Cifar10Net and
    FashionMNISTNet), and their training / evaluation loops.  The single
    public entry point is train(), which dispatches on self.mode
    ('train_cifar10', 'test_cifar10' or 'train_mnist').
    """

    def __init__(self):
        self.mode = 'train_mnist'

        # Fall back to CPU when no GPU is available instead of hard-coding
        # 'cuda' (the original crashed on CPU-only machines).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.cifar10_train_loader = None
        self.cifar10_test_loader = None
        self.cifar10_net = None
        self.cifar10_epochs = 50
        self.cifar10_path = 'models/cifar10'

        self.cifar10_criterion = nn.CrossEntropyLoss()
        self.cifar10_optimizer = None

        self.mnist_net = None
        self.mnist_epochs = 5
        self.mnist_train_loader = None
        self.mnist_test_loader = None
        self.mnist_path = 'models/mnist'

        self.mnist_criterion = nn.CrossEntropyLoss()
        self.mnist_optimizer = None

    def load_cifar10_dataset(self):
        """Build the CIFAR-10 train/test loaders (data must already be on disk)."""
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_set = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=False,
                                                 transform=transform)
        self.cifar10_train_loader = torch.utils.data.DataLoader(train_set, batch_size=8,
                                                                shuffle=True, num_workers=1)

        test_set = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False, download=False,
                                                transform=transform)
        self.cifar10_test_loader = torch.utils.data.DataLoader(test_set, batch_size=4,
                                                               shuffle=False, num_workers=1)

    def load_fashion_mnist_dataset(self):
        """Build the Fashion-MNIST train/test loaders (data must already be on disk)."""
        transform = transforms.Compose(
            [transforms.ToTensor()])

        train_set = torchvision.datasets.FashionMNIST(root='./data/fashion-mnist', train=True, download=False,
                                                      transform=transform)
        self.mnist_train_loader = torch.utils.data.DataLoader(train_set, batch_size=8,
                                                              shuffle=True, num_workers=1)

        test_set = torchvision.datasets.FashionMNIST(root='./data/fashion-mnist', train=False, download=False,
                                                     transform=transform)
        self.mnist_test_loader = torch.utils.data.DataLoader(test_set, batch_size=4,
                                                             shuffle=False, num_workers=1)

    def train_cifar10(self):
        """Train the CIFAR-10 column for self.cifar10_epochs epochs."""
        for epoch in range(self.cifar10_epochs):

            running_loss = 0.0
            for i, data in enumerate(self.cifar10_train_loader, 0):
                inputs, labels = data
                # Move the batch to the training device.  The original
                # re-wrapped loader tensors in torch.tensor(...) — which
                # copies and detaches — and never moved the labels, so the
                # loss raised a device mismatch when running on GPU.
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # zero the parameter gradients
                self.cifar10_optimizer.zero_grad()

                # forward + backward + optimize; call the module itself
                # rather than .forward() so registered hooks run.
                outputs = self.cifar10_net(inputs)
                loss = self.cifar10_criterion(outputs, labels)
                loss.backward()
                self.cifar10_optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0

        print('Finished Training')

    def train_mnist(self):
        """Train the Fashion-MNIST column for self.mnist_epochs epochs."""
        for epoch in range(self.mnist_epochs):

            running_loss = 0.0
            for i, data in enumerate(self.mnist_train_loader, 0):
                inputs, labels = data

                # zero the parameter gradients
                self.mnist_optimizer.zero_grad()

                # forward + backward + optimize; the fully connected column
                # expects flat 784-dim vectors, so flatten each image once.
                inputs = torch.flatten(inputs.to(self.device), start_dim=1)
                labels = labels.to(self.device)
                outputs = self.mnist_net(inputs)
                loss = self.mnist_criterion(outputs, labels)
                loss.backward()
                self.mnist_optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0

        print('Finished Training')

    def test_cifar10(self):
        """Evaluate the CIFAR-10 column on its test set and print accuracy."""
        correct = 0
        total = 0
        with torch.no_grad():
            for data in self.cifar10_test_loader:
                images, labels = data
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.cifar10_net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))

    def test_mnist(self):
        """Evaluate the Fashion-MNIST column on its test set and print accuracy."""
        correct = 0
        total = 0
        with torch.no_grad():
            for data in self.mnist_test_loader:
                images, labels = data
                # BUG FIX: images were never flattened here (unlike in
                # train_mnist), so the fully connected net raised a shape
                # error on its first Linear(784, ...) layer.
                images = torch.flatten(images.to(self.device), start_dim=1)
                labels = labels.to(self.device)
                outputs = self.mnist_net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the test images: %d %%' % (
            100 * correct / total))

    def load_cifar10(self):
        """Create the CIFAR-10 column; load saved weights unless we are training it."""
        # Both branches of the original created the net — build it once, then
        # restore weights only when not training from scratch.
        self.cifar10_net = Cifar10Net().to(self.device)
        if self.mode != 'train_cifar10':
            # map_location lets a GPU-trained checkpoint load on CPU too.
            self.cifar10_net.load_state_dict(torch.load(self.cifar10_path, map_location=self.device))
            self.cifar10_net.eval()

    def load_mnist(self):
        """Create a fresh Fashion-MNIST column on the active device."""
        self.mnist_net = FashionMNISTNet().to(self.device)

    def save_cifar10(self):
        """Persist the CIFAR-10 column's weights to self.cifar10_path."""
        torch.save(self.cifar10_net.state_dict(), self.cifar10_path)

    def save_mnist(self):
        """Persist the Fashion-MNIST column's weights to self.mnist_path."""
        torch.save(self.mnist_net.state_dict(), self.mnist_path)

    def train(self):
        """Entry point: load data and models, then act according to self.mode."""
        self.load_cifar10_dataset()
        self.load_fashion_mnist_dataset()
        self.load_mnist()
        self.load_cifar10()
        if self.mode == 'train_cifar10':
            self.cifar10_optimizer = optim.SGD(self.cifar10_net.parameters(), lr=0.001, momentum=0.9)
            self.train_cifar10()
            self.save_cifar10()
        elif self.mode == 'test_cifar10':
            self.test_cifar10()
        elif self.mode == 'train_mnist':
            self.mnist_optimizer = optim.SGD(self.mnist_net.parameters(), lr=0.001, momentum=0.9)
            self.train_mnist()
            self.save_mnist()
178+
179+
180+
class Cifar10Net(nn.Module):
181+
def __init__(self):
182+
super(Cifar10Net, self).__init__()
183+
self.conv1 = nn.Conv2d(3, 6, 5)
184+
self.pool = nn.MaxPool2d(2, 2)
185+
self.conv2 = nn.Conv2d(6, 16, 5)
186+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
187+
self.fc2 = nn.Linear(120, 84)
188+
self.fc3 = nn.Linear(84, 10)
189+
190+
def forward(self, x):
191+
x = self.pool(F.relu(self.conv1(x)))
192+
x = self.pool(F.relu(self.conv2(x)))
193+
x = x.view(-1, 16 * 5 * 5)
194+
x = F.relu(self.fc1(x))
195+
x = F.relu(self.fc2(x))
196+
x = self.fc3(x)
197+
return x
198+
199+
200+
class FashionMNISTNet(nn.Module):
    """Fully connected column for flattened 28x28 Fashion-MNIST images.

    Layer widths: 784 -> 128 -> 64 -> 32 -> 10.
    """

    def __init__(self):
        super(FashionMNISTNet, self).__init__()

        # Defining the hidden layers, 128, 64 and 32 units each
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)

        # Output layer, 10 units - one for each class
        self.fc4 = nn.Linear(32, 10)

    def forward(self, x):
        """Forward pass through the network; returns the output logits.

        BUG FIX: the original applied F.softmax here, but the training loop
        feeds this output to nn.CrossEntropyLoss, which applies log-softmax
        internally — softmaxing twice distorts the loss and slows learning.
        Predictions (argmax) are unaffected, since softmax is monotonic.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)
225+
226+
227+
# Command-line entry point: hand the PNN class to python-fire so its public
# methods (e.g. train) can be invoked from the shell.
if __name__ == '__main__':
    fire.Fire(PNN)

0 commit comments

Comments
 (0)