
Commit 20344b2

Moeed committed

Add Seq2Seq implementation

- Fix VAE by summing loss
- Add Seq2Seq to README
1 parent bf1d1a2 commit 20344b2

File tree: 3 files changed (+147 −14 lines)


README.md (+3 −1)

@@ -16,4 +16,6 @@ The repo includes the following algorithms:
 7. **Variational Autoencoder (VAE)**
    [Paper: [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114)]
 8. **Model Compression**
-   [Paper: [BinaryConnect: Training Deep Neural Networks with binary weights during propagations](https://arxiv.org/abs/1511.00363)]
+   [Paper: [BinaryConnect: Training Deep Neural Networks with binary weights during propagations](https://arxiv.org/abs/1511.00363)]
+9. **Neural Machine Translation**
+   [Paper: [Sequence to Sequence Learning with Neural Networks](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf)]

seq2seq.py (new file, +143)

"""
Implementation of a Sequence-to-Sequence model for English-to-German translation in PyTorch.
"""

import numpy as np
import fire
import torch
from torch.optim import Adam
from multiprocessing import set_start_method

# Default all tensors to the GPU when one is available (the unconditional
# cuda default would crash on CPU-only machines).
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

try:
    set_start_method('spawn')
except RuntimeError:
    pass


class Seq2Seq:
    def __init__(self):
        self.model = Model()
        self.corpus_en = None
        self.corpus_de = None
        self.vocab_en = None
        self.vocab_de = None
        self.vocab_len_en = None
        self.vocab_len_de = None
        self.data_en_path = "./data/en_de/train_en.dat"
        self.data_de_path = "./data/en_de/train_de.dat"
        self.embedding_dim = 256
        self.hidden_dim = 256
        self.model_name = "./models/seq2seq.h5"

        # training setup
        self.batch_size = 32
        # CrossEntropyLoss applies log-softmax internally, so the model
        # must output raw logits (see Decoder.forward below).
        self.model_loss = torch.nn.CrossEntropyLoss()
        self.model_optim = Adam(self.model.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.max_len = None

    def load_dataset(self, path):
        """Read a corpus file and build its word-level vocabulary."""
        with open(path) as fp:
            corpus = [line.strip() for line in fp]
        vocab = list(set(" ".join(corpus).split(" ")))
        vocab.extend(["<BLANK>", "<EOS>"])  # padding and end-of-sentence tokens
        vocab_len = len(vocab)

        # Track the longest sentence in words (+1 slot for <EOS>), shared
        # across both languages so the encoded corpora have equal width.
        len_corpus = max(len(sentence.split(" ")) for sentence in corpus) + 1
        if self.max_len is None or self.max_len < len_corpus:
            self.max_len = len_corpus

        return corpus, vocab, vocab_len

    def preprocess_corpus(self, corpus, lang, padding, eos):
        """Encode each sentence as a fixed-length row of vocabulary indices."""
        corpus_encoded = np.ones(shape=(len(corpus), self.max_len), dtype=np.float32) * padding
        for i, sentence in enumerate(corpus):
            words = sentence.split(" ")
            for j, word in enumerate(words):
                corpus_encoded[i, j] = self.word_vocab_encode(word, lang)
            corpus_encoded[i, len(words)] = eos

        return corpus_encoded

    def word_vocab_encode(self, word, lang):
        if lang == "en":
            return self.vocab_en.index(word)
        return self.vocab_de.index(word)

    def save_preprocessed_corpus(self):
        self.corpus_en, self.vocab_en, self.vocab_len_en = self.load_dataset(self.data_en_path)
        self.corpus_de, self.vocab_de, self.vocab_len_de = self.load_dataset(self.data_de_path)

        # <BLANK> sits at index vocab_len - 2 and <EOS> at vocab_len - 1
        # because they were appended last in load_dataset.
        self.corpus_en = self.preprocess_corpus(self.corpus_en, "en", self.vocab_len_en - 2, self.vocab_len_en - 1)
        self.corpus_de = self.preprocess_corpus(self.corpus_de, "de", self.vocab_len_de - 2, self.vocab_len_de - 1)

        np.save('./data/en_de/corpus_en', self.corpus_en)
        np.save('./data/en_de/corpus_de', self.corpus_de)

    def train(self):
        _, self.vocab_en, self.vocab_len_en = self.load_dataset(self.data_en_path)
        _, self.vocab_de, self.vocab_len_de = self.load_dataset(self.data_de_path)

        self.corpus_en = torch.tensor(np.load('./data/en_de/corpus_en.npy')).long()
        self.corpus_de = torch.tensor(np.load('./data/en_de/corpus_de.npy')).long()

        # Single optimization step on the first batch.
        self.model_optim.zero_grad()
        out = self.model(self.corpus_en[:self.batch_size], self.corpus_de[:self.batch_size])
        # CrossEntropyLoss expects (batch, classes, seq); the model returns
        # (batch, seq, classes), hence the permute.
        loss = self.model_loss(out.permute(0, 2, 1), self.corpus_de[:self.batch_size])
        print("Loss: ", loss.item())
        loss.backward()
        self.model_optim.step()


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Vocabulary sizes of the English and German training corpora.
        self.encoder = Encoder(8120)
        self.decoder = Decoder(10161)

    def forward(self, x, y):
        # Encode the source sentence, then condition the decoder on the
        # encoder's final hidden state (teacher forcing with the target y).
        x, state = self.encoder(x)
        x = self.decoder(y, state)

        return x


class Encoder(torch.nn.Module):
    def __init__(self, vocab_len_en):
        super().__init__()
        self.embedding1 = torch.nn.Embedding(vocab_len_en, 256)
        # batch_first=True because the inputs are laid out (batch, seq).
        self.lstm1 = torch.nn.LSTM(256, hidden_size=256, num_layers=2, batch_first=True)

    def forward(self, x):
        x = self.embedding1(x)
        x, state = self.lstm1(x)

        return x, state


class Decoder(torch.nn.Module):
    def __init__(self, vocab_len_de):
        super().__init__()
        self.embedding1 = torch.nn.Embedding(num_embeddings=vocab_len_de, embedding_dim=256)
        self.lstm1 = torch.nn.LSTM(input_size=256, hidden_size=256, batch_first=True)
        self.fc1 = torch.nn.Linear(in_features=256, out_features=vocab_len_de)

    def forward(self, x, state):
        x = self.embedding1(x)
        # The encoder is two layers deep but the decoder has one, so seed it
        # with the encoder's top-layer (h, c), each of shape (1, batch, 256).
        x, _ = self.lstm1(x, (state[0][-1:], state[1][-1:]))
        # Return raw logits: CrossEntropyLoss applies softmax itself, so an
        # explicit softmax here would be applied twice.
        x = self.fc1(x)

        return x


def main():
    fire.Fire(Seq2Seq)


if __name__ == "__main__":
    main()
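
Because main() hands the Seq2Seq class to fire.Fire, each public method becomes a command-line subcommand. A plausible workflow, assuming the train_en.dat and train_de.dat files exist under ./data/en_de/:

python seq2seq.py save_preprocessed_corpus   # encode both corpora and cache them as .npy arrays
python seq2seq.py train                      # run one optimization step on the cached arrays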

vae.py (+1 −13)

@@ -34,8 +34,7 @@ def __init__(self):
         self.test_mnist_dataloader = None
         self.mnist_epochs = 50
         self.model_opt = Adam(self.model.parameters(), lr=0.0002, betas=(0.5, 0.999))
-        self.generated_loss = torch.nn.BCELoss()
-        self.latent_loss = torch.nn.KLDivLoss()
+        self.generated_loss = torch.nn.BCELoss(reduction="sum")
         self.dist = Normal(torch.tensor([0.0]), torch.tensor([1.0]))

         self.model_path = 'models/vae.hdf5'

@@ -109,17 +108,6 @@ def train(self):

         print('Finished Training')

-    def test(self):
-        self.load_model()
-        sample_vector = torch.randn(self.batch_size, self.latent_vector_size)
-        generated = self.model.decode(sample_vector)
-        self.plot_results(generated)
-
-    def load_model(self):
-        self.model = Model()
-        self.model.load_state_dict(torch.load(self.model_path))
-        self.model.eval()
-

 class Model(torch.nn.Module):
     def __init__(self, n_classes=10):
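
On the "Fix VAE by summing loss" part of this commit: BCELoss defaults to reduction="mean", which divides the reconstruction term by batch size times pixel count, so against the KL term of the ELBO the reconstruction signal all but vanishes. Summing keeps both terms on the same scale. A minimal sketch of the combined objective, assuming the KL term is computed in closed form from the encoder's mu and logvar (the vae_loss helper and its signature are illustrative, not part of this repo):

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: BCE summed over all pixels and batch elements,
    # matching the torch.nn.BCELoss(reduction="sum") used in this commit.
    bce = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1)
    # (Appendix B of Kingma & Welling, 2013), summed the same way.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

This also suggests why the latent_loss line was dropped rather than kept: torch.nn.KLDivLoss compares two explicit distributions given as log-probabilities, not the closed-form Gaussian KL a VAE uses.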
