from unittest.mock import MagicMock, patch
import numpy as np
import torch


class _TestNN(torch.nn.Module):
    """Minimal stand-in network used to check that forward propagation gets called."""
    def __init__(self, input_size, output_size):
        super(_TestNN, self).__init__()
        self.decoder = torch.nn.Linear(input_size, output_size)
        self.forward_called = False

    def forward(self, nn_input, hidden):
        self.forward_called = True
        output = self.decoder(nn_input)

        return output, hidden


def _print_success_message():
    print('Tests Passed')


class AssertTest(object):
    """Wraps assert so that failures also print the unit test's parameters."""
    def __init__(self, params):
        self.assert_param_message = '\n'.join('{}: {}'.format(k, v) for k, v in params.items())

    def test(self, assert_condition, assert_message):
        assert assert_condition, assert_message + '\n\nUnit Test Function Parameters\n' + self.assert_param_message
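# Usage sketch for AssertTest (this mirrors how test_rnn below uses it; the
# parameter values here are illustrative only):
#
#   assert_test = AssertTest({'Batch Size': 50, 'Hidden Dim': 10})
#   assert_test.test(output.size() == (50, 20), 'Wrong output size.')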


def test_create_lookup_tables(create_lookup_tables):
    test_text = '''
    Moe_Szyslak Moe's Tavern Where the elite meet to drink
    Bart_Simpson Eh yeah hello is Mike there Last name Rotch
    Moe_Szyslak Hold on I'll check Mike Rotch Mike Rotch Hey has anybody seen Mike Rotch lately
    Moe_Szyslak Listen you little puke One of these days I'm gonna catch you and I'm gonna carve my name on your back with an ice pick
    Moe_Szyslak Whats the matter Homer You're not your normal effervescent self
    Homer_Simpson I got my problems Moe Give me another one
    Moe_Szyslak Homer hey you should not drink to forget your problems
    Barney_Gumble Yeah you should only drink to enhance your social skills'''

    test_text = test_text.lower()
    test_text = test_text.split()

    vocab_to_int, int_to_vocab = create_lookup_tables(test_text)

    # Check types
    assert isinstance(vocab_to_int, dict),\
        'vocab_to_int is not a dictionary.'
    assert isinstance(int_to_vocab, dict),\
        'int_to_vocab is not a dictionary.'

    # Compare lengths of dicts
    assert len(vocab_to_int) == len(int_to_vocab),\
        'Length of vocab_to_int and int_to_vocab don\'t match. ' \
        'vocab_to_int is length {}. int_to_vocab is length {}'.format(len(vocab_to_int), len(int_to_vocab))

    # Make sure the dicts have the same words
    vocab_to_int_word_set = set(vocab_to_int.keys())
    int_to_vocab_word_set = set(int_to_vocab.values())

    assert not (vocab_to_int_word_set - int_to_vocab_word_set),\
        'vocab_to_int and int_to_vocab don\'t have the same words. ' \
        '{} found in vocab_to_int, but not in int_to_vocab'.format(vocab_to_int_word_set - int_to_vocab_word_set)
    assert not (int_to_vocab_word_set - vocab_to_int_word_set),\
        'vocab_to_int and int_to_vocab don\'t have the same words. ' \
        '{} found in int_to_vocab, but not in vocab_to_int'.format(int_to_vocab_word_set - vocab_to_int_word_set)

    # Make sure the dicts have the same word ids
    vocab_to_int_word_id_set = set(vocab_to_int.values())
    int_to_vocab_word_id_set = set(int_to_vocab.keys())

    assert not (vocab_to_int_word_id_set - int_to_vocab_word_id_set),\
        'vocab_to_int and int_to_vocab don\'t contain the same word ids. ' \
        '{} found in vocab_to_int, but not in int_to_vocab'.format(vocab_to_int_word_id_set - int_to_vocab_word_id_set)
    assert not (int_to_vocab_word_id_set - vocab_to_int_word_id_set),\
        'vocab_to_int and int_to_vocab don\'t contain the same word ids. ' \
        '{} found in int_to_vocab, but not in vocab_to_int'.format(int_to_vocab_word_id_set - vocab_to_int_word_id_set)

    # Make sure the dicts make the same lookup
    mismatches = [(word, word_id, word_id, int_to_vocab[word_id])
                  for word, word_id in vocab_to_int.items() if int_to_vocab[word_id] != word]

    assert not mismatches,\
        'Found {} mismatch(es). First mismatch: vocab_to_int[{}] = {} and int_to_vocab[{}] = {}'.format(
            len(mismatches), *mismatches[0])

    assert len(vocab_to_int) > len(set(test_text))/2,\
        'The length of vocab seems too small. Found a length of {}'.format(len(vocab_to_int))

    _print_success_message()


def test_tokenize(token_lookup):
    symbols = set(['.', ',', '"', ';', '!', '?', '(', ')', '-', '\n'])
    token_dict = token_lookup()

    # Check type
    assert isinstance(token_dict, dict), \
        'Returned type is {}.'.format(type(token_dict))

    # Check symbols
    missing_symbols = symbols - set(token_dict.keys())
    unknown_symbols = set(token_dict.keys()) - symbols

    assert not missing_symbols, \
        'Missing symbols: {}'.format(missing_symbols)
    assert not unknown_symbols, \
        'Unknown symbols: {}'.format(unknown_symbols)

    # Check values type
    bad_value_type = [type(val) for val in token_dict.values() if not isinstance(val, str)]

    assert not bad_value_type,\
        'Found token as {} type.'.format(bad_value_type[0])

    # Check for spaces
    key_has_spaces = [k for k in token_dict.keys() if ' ' in k]
    val_has_spaces = [val for val in token_dict.values() if ' ' in val]

    assert not key_has_spaces,\
        'The key "{}" includes spaces. Remove spaces from keys and values'.format(key_has_spaces[0])
    assert not val_has_spaces,\
        'The value "{}" includes spaces. Remove spaces from keys and values'.format(val_has_spaces[0])

    # Check for symbols in values
    symbol_val = ()
    for symbol in symbols:
        for val in token_dict.values():
            if symbol in val:
                symbol_val = (symbol, val)

    assert not symbol_val,\
        'Don\'t use a symbol that will be replaced in your tokens. Found the symbol {} in value {}'.format(*symbol_val)

    _print_success_message()


def test_rnn(RNN, train_on_gpu):
    batch_size = 50
    sequence_length = 3
    vocab_size = 20
    output_size = 20
    embedding_dim = 15
    hidden_dim = 10
    n_layers = 2

    # create test RNN
    # params: (vocab_size, output_size, embedding_dim, hidden_dim, n_layers)
    rnn = RNN(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

    # create test input
    a = np.random.randint(vocab_size, size=(batch_size, sequence_length))
    b = torch.from_numpy(a)
    hidden = rnn.init_hidden(batch_size)

    if train_on_gpu:
        rnn.cuda()
        b = b.cuda()

    output, hidden_out = rnn(b, hidden)

    assert_test = AssertTest({
        'Input Size': vocab_size,
        'Output Size': output_size,
        'Hidden Dim': hidden_dim,
        'N Layers': n_layers,
        'Batch Size': batch_size,
        'Sequence Length': sequence_length,
        'Input': b})

    # initialization
    correct_hidden_size = (n_layers, batch_size, hidden_dim)

    if isinstance(hidden, tuple):
        # LSTM
        initial_hidden_size = hidden[0].size()
    else:
        # GRU
        initial_hidden_size = hidden.size()

    assert_condition = initial_hidden_size == correct_hidden_size
    assert_message = 'Wrong hidden state size. Expected size {}. Got size {}'.format(correct_hidden_size, initial_hidden_size)
    assert_test.test(assert_condition, assert_message)

    # output of rnn
    if isinstance(hidden_out, tuple):
        # LSTM
        returned_hidden_size = hidden_out[0].size()
    else:
        # GRU
        returned_hidden_size = hidden_out.size()

    assert_condition = returned_hidden_size == correct_hidden_size
    assert_message = 'Wrong hidden state size. Expected size {}. Got size {}'.format(correct_hidden_size, returned_hidden_size)
    assert_test.test(assert_condition, assert_message)

    correct_output_size = (batch_size, output_size)
    assert_condition = output.size() == correct_output_size
    assert_message = 'Wrong output size. Expected size {}. Got size {}'.format(correct_output_size, output.size())
    assert_test.test(assert_condition, assert_message)

    _print_success_message()


def test_forward_back_prop(RNN, forward_back_prop, train_on_gpu):
    batch_size = 200
    input_size = 20
    output_size = 10
    sequence_length = 3
    embedding_dim = 15
    hidden_dim = 10
    n_layers = 2
    learning_rate = 0.01

    # create test RNN
    rnn = RNN(input_size, output_size, embedding_dim, hidden_dim, n_layers)

    # mock the network, optimizer, and loss so the test can verify they were used
    mock_decoder = MagicMock(wraps=_TestNN(input_size, output_size))
    if train_on_gpu:
        mock_decoder.cuda()

    mock_decoder_optimizer = MagicMock(wraps=torch.optim.Adam(mock_decoder.parameters(), lr=learning_rate))
    mock_criterion = MagicMock(wraps=torch.nn.CrossEntropyLoss())

    with patch.object(torch.autograd, 'backward', wraps=torch.autograd.backward) as mock_autograd_backward:
        inp = torch.FloatTensor(np.random.rand(batch_size, input_size))
        target = torch.LongTensor(np.random.randint(output_size, size=batch_size))

        hidden = rnn.init_hidden(batch_size)

        loss, hidden_out = forward_back_prop(mock_decoder, mock_decoder_optimizer, mock_criterion, inp, target, hidden)

    if isinstance(hidden_out, tuple):
        # LSTM
        assert (hidden_out[0][0] == hidden[0][0]).sum() == batch_size*hidden_dim, 'Returned hidden state is the incorrect size.'
    else:
        # GRU
        assert (hidden_out[0] == hidden[0]).sum() == batch_size*hidden_dim, 'Returned hidden state is the incorrect size.'

    assert mock_decoder.zero_grad.called or mock_decoder_optimizer.zero_grad.called, 'Didn\'t set the gradients to 0.'
    assert mock_decoder.forward_called, 'Forward propagation not called.'
    assert mock_autograd_backward.called, 'Backward propagation not called.'
    assert mock_decoder_optimizer.step.called, 'Optimization step not performed.'
    assert type(loss) == float, 'Wrong return type. Expected {}, got {}'.format(float, type(loss))

    _print_success_message()
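

# Usage sketch: these checks are meant to be run against your own implementations,
# e.g. from a notebook or script (the module name `problem_unittests` below is an
# assumption; the function and class names are the parameters each test expects):
#
#   import problem_unittests as tests
#
#   tests.test_create_lookup_tables(create_lookup_tables)
#   tests.test_tokenize(token_lookup)
#   tests.test_rnn(RNN, train_on_gpu)
#   tests.test_forward_back_prop(RNN, forward_back_prop, train_on_gpu)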