"""
Implementation of a sequence-to-sequence model for English-to-German translation in PyTorch.
"""

import numpy as np
import fire
import torch
from torch.optim import Adam
from multiprocessing import set_start_method

# Create new tensors on the GPU by default; this assumes a CUDA device is available.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Use the 'spawn' start method for worker processes; ignore the error if it is already set.
try:
    set_start_method('spawn')
except RuntimeError:
    pass


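# Training pipeline: load_dataset builds a whitespace-token vocabulary, preprocess_corpus
# turns each sentence into a fixed-length array of token indices, save_preprocessed_corpus
# writes those arrays to disk, and train runs the encoder-decoder model on the saved arrays.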
class Seq2Seq:
    def __init__(self):
        self.model = Model()
        self.corpus_en = None
        self.corpus_de = None
        self.vocab_en = None
        self.vocab_de = None
        self.vocab_len_en = None
        self.vocab_len_de = None
        self.data_en_path = "./data/en_de/train_en.dat"
        self.data_de_path = "./data/en_de/train_de.dat"
        self.embedding_dim = 256
        self.hidden_dim = 256
        self.model_name = "./models/seq2seq.h5"

        # loss and optimiser
        self.batch_size = 32
        self.model_loss = torch.nn.CrossEntropyLoss()
        self.model_optim = Adam(self.model.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # Length (in tokens) that every encoded sentence is padded to; set in load_dataset
        self.max_len = None

    def load_dataset(self, path):
        with open(path) as fp:
            # Strip trailing newlines so the last token of one line does not merge
            # with the first token of the next when the vocabulary is built below
            corpus = [line.strip() for line in fp]
        vocab = list(set(" ".join(corpus).split(" ")))
        vocab.extend(["<BLANK>", "<EOS>"])
        vocab_len = len(vocab)

        # Longest sentence in tokens, plus one slot for the <EOS> marker
        len_corpus = max(len(sentence.split(" ")) for sentence in corpus) + 1
        if self.max_len is None or self.max_len < len_corpus:
            self.max_len = len_corpus

        return corpus, vocab, vocab_len

    def preprocess_corpus(self, corpus, lang, padding, eos):
        # Fill with the padding index, then overwrite with token indices and a trailing <EOS>
        corpus_encoded = np.ones(shape=(len(corpus), self.max_len), dtype=np.float32) * padding
        for i, sentence in enumerate(corpus):
            for j, word in enumerate(sentence.split(" ")):
                corpus_encoded[i, j] = self.word_vocab_encode(word, lang)
            corpus_encoded[i, len(sentence.split(" "))] = eos

        return corpus_encoded

    def word_vocab_encode(self, word, lang):
        if lang == "en":
            return self.vocab_en.index(word)
        else:
            return self.vocab_de.index(word)

    def save_preprocessed_corpus(self):
        self.corpus_en, self.vocab_en, self.vocab_len_en = self.load_dataset(self.data_en_path)
        self.corpus_de, self.vocab_de, self.vocab_len_de = self.load_dataset(self.data_de_path)

        # The padding and EOS indices are the last two vocabulary entries ("<BLANK>" and "<EOS>")
        self.corpus_en = self.preprocess_corpus(self.corpus_en, "en", self.vocab_len_en - 2, self.vocab_len_en - 1)
        self.corpus_de = self.preprocess_corpus(self.corpus_de, "de", self.vocab_len_de - 2, self.vocab_len_de - 1)

        np.save('./data/en_de/corpus_en', self.corpus_en)
        np.save('./data/en_de/corpus_de', self.corpus_de)

    def train(self):
        # Rebuild the vocabularies so indices match the previously saved corpora
        _, self.vocab_en, self.vocab_len_en = self.load_dataset(self.data_en_path)
        _, self.vocab_de, self.vocab_len_de = self.load_dataset(self.data_de_path)

        self.corpus_en = torch.tensor(np.load('./data/en_de/corpus_en.npy')).long()
        self.corpus_de = torch.tensor(np.load('./data/en_de/corpus_de.npy')).long()

        # Single optimization step on the first batch
        self.model_optim.zero_grad()
        out = self.model(self.corpus_en[:self.batch_size], self.corpus_de[:self.batch_size])
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten the batch and time dimensions
        target = self.corpus_de[:self.batch_size]
        loss = self.model_loss(out.reshape(-1, out.size(-1)), target.reshape(-1))
        print("Loss: ", loss.item())
        loss.backward()
        self.model_optim.step()


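# Encoder-decoder architecture: the encoder reads the English sentence with a two-layer
# LSTM, and the decoder produces German token logits from an LSTM initialised with the
# top layer of the encoder's final state.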
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Vocabulary sizes are hard-coded to match the preprocessed English and German corpora
        self.encoder = Encoder(8120)
        self.decoder = Decoder(10161)

    def forward(self, x, y):
        # Encode the source sentence, then condition the decoder on the encoder's final state
        x, state = self.encoder(x)
        x = self.decoder(y, state)

        return x


class Encoder(torch.nn.Module):
    def __init__(self, vocab_len_en):
        super().__init__()
        self.embedding1 = torch.nn.Embedding(vocab_len_en, 256)
        # batch_first=True because the corpus tensors are shaped (batch, seq_len)
        self.lstm1 = torch.nn.LSTM(256, hidden_size=256, num_layers=2, batch_first=True)

    def forward(self, x):
        x = self.embedding1(x)
        x, state = self.lstm1(x)

        return x, state


class Decoder(torch.nn.Module):
    def __init__(self, vocab_len_de):
        super().__init__()
        self.embedding1 = torch.nn.Embedding(num_embeddings=vocab_len_de, embedding_dim=256)
        self.lstm1 = torch.nn.LSTM(input_size=256, hidden_size=256, batch_first=True)
        self.fc1 = torch.nn.Linear(in_features=256, out_features=vocab_len_de)

    def forward(self, x, state):
        x = self.embedding1(x)
        # The encoder LSTM has two layers but the decoder has one, so keep only the
        # top layer of the encoder's (h, c) state as the decoder's initial state
        h, c = state[0][-1:].contiguous(), state[1][-1:].contiguous()
        x, _ = self.lstm1(x, (h, c))
        # Return raw logits; CrossEntropyLoss applies log-softmax internally
        x = self.fc1(x)

        return x


def main():
    # Expose the Seq2Seq methods (save_preprocessed_corpus, train) as CLI commands via python-fire
    fire.Fire(Seq2Seq)


if __name__ == "__main__":
    main()
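# Example CLI usage via python-fire (assuming this file is saved as seq2seq.py and the
# ./data/en_de/ training files exist):
#   python seq2seq.py save_preprocessed_corpus
#   python seq2seq.py train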