
Commit 96fe374 (committed Apr 11, 2017)

Merge branch 'master' of github.com:jimfleming/recurrent-entity-networks

2 parents: 086e294 + 365f058

16 files changed: +556 −439 lines

LICENSE (+21)

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016 Jim Fleming
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

README.md (+41 −5)

@@ -1,16 +1,52 @@
 # Recurrent Entity Networks
 
-This repository contains a TensorFlow implementation of recurrent entity networks from [Tracking the World State with
-Recurrent Entity Networks](https://openreview.net/forum?id=rJTKKKqeg).
+This repository contains an independent TensorFlow implementation of recurrent entity networks from [Tracking the World State with
+Recurrent Entity Networks](https://openreview.net/forum?id=rJTKKKqeg). This paper introduces the first method to solve all of the bAbI tasks using 10k training examples. The authors' original Torch implementation is available [here](https://github.com/facebook/MemNN/tree/master/EntNet-babi).
 
-![Diagram of recurrent entity network](images/diagram.png)
+<img src="images/diagram.png" alt="Diagram of recurrent entity network" width="886" height="658">
+
+## Results
+
+Per-task percent error, comparing the results reported in the paper with those from the implementation in this repository.
+
+Task | EntNet (paper) | EntNet (repo)
+--- | --- | ---
+1: 1 supporting fact | 0 | 0
+2: 2 supporting facts | 0.1 | 3.0
+3: 3 supporting facts | 4.1 | ?
+4: 2 argument relations | 0 | 0
+5: 3 argument relations | 0.3 | ?
+6: yes/no questions | 0.2 | 0.1
+7: counting | 0 | ?
+8: lists/sets | 0.5 | ?
+9: simple negation | 0.1 | 0.7
+10: indefinite knowledge | 0.6 | 0.1
+11: basic coreference | 0.3 | 0
+12: conjunction | 0 | 0
+13: compound coreference | 1.3 | 0
+14: time reasoning | 0 | 4.5
+15: basic deduction | 0 | 0
+16: basic induction | 0.2 | 54.0 ([#5](../../issues/5))
+17: positional reasoning | 0.5 | 1.7
+18: size reasoning | 0.3 | 1.5
+19: path finding | 2.3 | 41.9 ([#5](../../issues/5))
+20: agent's motivations | 0 | 0.2
+**Failed Tasks** | 0 | ?
+**Mean Error** | 0.5 | ?
 
 ## Setup
 
 1. Download the datasets by running [download_datasets.sh](download_datasets.sh) or from [The bAbI Project](https://research.facebook.com/research/babi/).
 2. Run [prep_datasets.py](prep_datasets.py), which will convert the datasets into [TFRecords](https://www.tensorflow.org/versions/r0.11/how_tos/reading_data/index.html#standard_tensorflow_format).
-3. Run `python -m entity_networks.main` to begin training.
+3. Run `python -m entity_networks.main` to begin training on QA1.
+4. Run `./run_all.sh` to train on all tasks.
 
 ## Dependencies
 
-- TensorFlow v0.11rc0
+- TensorFlow v0.11
+
+## Thanks!
+
+- Thanks to Mikael Henaff for providing details about their paper over Thanksgiving break. :)
+- Thanks to Andy Zhang ([@zhangandyx](https://twitter.com/zhangandyx)) for helping me troubleshoot numerical instabilities.
+- Thanks to Mike Young for providing results on some of the longer tasks.

download_datasets.sh (+4)

@@ -1,5 +1,9 @@
 #!/bin/bash
 
+if [ ! -d ./datasets ]; then
+    mkdir -p ./datasets
+fi
+
 BABI_TASKS=datasets/babi_tasks_data_1_20_v1.2.tar.gz
 DIALOG_TASKS=datasets/dialog_babi_tasks_data_1_6.tgz
 CHILDRENS_BOOK=datasets/childrens_book_test.tgz

entity_networks/activations.py (+2 −4)

@@ -8,10 +8,8 @@ def prelu(features, initializer=None, scope=None):
     """
     Implementation of [Parametric ReLU](https://arxiv.org/abs/1502.01852) borrowed from Keras.
     """
-    with tf.variable_scope(scope, 'PReLU'):
-        alpha = tf.get_variable('alpha',
-            shape=features.get_shape().as_list()[1:],
-            initializer=initializer)
+    with tf.variable_scope(scope, 'PReLU', initializer=initializer):
+        alpha = tf.get_variable('alpha', features.get_shape().as_list()[1:])
     pos = tf.nn.relu(features)
     neg = alpha * (features - tf.abs(features)) * 0.5
    return pos + neg
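
The refactor moves `initializer` up into the enclosing `tf.variable_scope`, where `alpha` picks it up as the scope default. For reference, here is a minimal NumPy sketch (an illustration, not part of the commit) of what `prelu` computes; `(x - |x|) * 0.5` is simply `min(x, 0)`, so this is the standard parametric ReLU:

```python
import numpy as np

def prelu_reference(features, alpha):
    """Reference PReLU: max(0, x) + alpha * min(0, x)."""
    pos = np.maximum(features, 0.0)
    # (x - |x|) * 0.5 is zero for positive x and x for negative x, i.e. min(x, 0),
    # the same trick the TensorFlow implementation uses.
    neg = alpha * (features - np.abs(features)) * 0.5
    return pos + neg

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(prelu_reference(x, alpha=0.25))  # [-0.5   -0.125  0.     1.5  ]
```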

entity_networks/dataset.py (+43 −83)

@@ -2,90 +2,50 @@
 from __future__ import print_function
 from __future__ import division
 
+import os
+import json
 import tensorflow as tf
 
-MAX_SENTENCE_LENGTH = 7
-MAX_STORY_LENGTH = 10
-MAX_QUERY_LENGTH = 4
-
-DATASET_SIZE = 10000
-VOCAB_SIZE = 22
-
-def record_reader(filename_queue):
-    reader = tf.TFRecordReader()
-    _, serialized = reader.read(filename_queue)
-
-    features = tf.parse_single_example(serialized, features={
-        "story": tf.FixedLenFeature([MAX_STORY_LENGTH, MAX_SENTENCE_LENGTH], dtype=tf.int64),
-        "query": tf.FixedLenFeature([1, MAX_QUERY_LENGTH], dtype=tf.int64),
-        "answer": tf.FixedLenFeature([], dtype=tf.int64),
-    })
-
-    story = features['story']
-    query = features['query']
-    answer = features['answer']
-
-    return story, query, answer
-
 class Dataset(object):
 
-    def __init__(self, filename, batch_size, shuffle=False):
-        self._batch_size = batch_size
-
-        filename_queue = tf.train.string_input_producer([filename], shuffle=shuffle)
-        records = record_reader(filename_queue)
-
-        min_after_dequeue = DATASET_SIZE
-        capacity = min_after_dequeue + 100 * batch_size
-
-        if shuffle:
-            self._story_batch, self._query_batch, self._answer_batch = \
-                tf.train.shuffle_batch(records,
-                    batch_size=batch_size,
-                    min_after_dequeue=min_after_dequeue,
-                    capacity=capacity)
-        else:
-            self._story_batch, self._query_batch, self._answer_batch = \
-                tf.train.batch(records,
-                    batch_size=batch_size,
-                    capacity=capacity)
-
-    @property
-    def story_batch(self):
-        return self._story_batch
-
-    @property
-    def query_batch(self):
-        return self._query_batch
-
-    @property
-    def answer_batch(self):
-        return self._answer_batch
-
-    @property
-    def batch_size(self):
-        return self._batch_size
-
-    @property
-    def max_sentence_length(self):
-        return MAX_SENTENCE_LENGTH
-
-    @property
-    def max_story_length(self):
-        return MAX_STORY_LENGTH
-
-    @property
-    def max_query_length(self):
-        return MAX_QUERY_LENGTH
-
-    @property
-    def vocab_size(self):
-        return VOCAB_SIZE
-
-    @property
-    def size(self):
-        return DATASET_SIZE
-
-    @property
-    def num_batches(self):
-        return DATASET_SIZE // self._batch_size
+    def __init__(self, dataset_path, batch_size):
+        self.dataset_dir = os.path.dirname(dataset_path)
+        self.batch_size = batch_size
+        self.examples_per_epoch = 10000
+
+        with open(dataset_path) as f:
+            metadata = json.load(f)
+
+        self.max_sentence_length = metadata['max_sentence_length']
+        self.max_story_length = metadata['max_story_length']
+        self.max_query_length = metadata['max_query_length']
+        self.dataset_size = metadata['dataset_size']
+        self.vocab_size = metadata['vocab_size']
+        self.tokens = metadata['tokens']
+        self.datasets = metadata['datasets']
+
+    @property
+    def steps_per_epoch(self):
+        return self.batch_size * self.examples_per_epoch
+
+    def get_input_fn(self, name, num_epochs, shuffle):
+        def input_fn():
+            features = {
+                "story": tf.FixedLenFeature([self.max_story_length, self.max_sentence_length], dtype=tf.int64),
+                "query": tf.FixedLenFeature([1, self.max_query_length], dtype=tf.int64),
+                "answer": tf.FixedLenFeature([], dtype=tf.int64),
+            }
+
+            dataset_path = os.path.join(self.dataset_dir, self.datasets[name])
+            features = tf.contrib.learn.read_batch_record_features(dataset_path,
+                features=features,
+                batch_size=self.batch_size,
+                randomize_input=shuffle,
+                num_epochs=num_epochs)
+
+            story = features['story']
+            query = features['query']
+            answer = features['answer']
+
+            return {'story': story, 'query': query}, answer
+        return input_fn
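
The rewritten `Dataset` replaces the hard-coded QA1 constants with a JSON metadata file written by `prep_datasets.py`. As a rough sketch only (the token list and TFRecord file names below are hypothetical; the numeric values mirror the constants the old code hard-coded for QA1), the metadata and the new API could be exercised like this:

```python
# Illustration only: a plausible metadata file, inferred from the keys
# Dataset.__init__ reads. File names and tokens are hypothetical.
import json

metadata = {
    "max_sentence_length": 7,
    "max_story_length": 10,
    "max_query_length": 4,
    "dataset_size": 10000,
    "vocab_size": 22,
    "tokens": ["_PAD", "mary", "went", "to", "the", "kitchen"],  # truncated, hypothetical
    "datasets": {
        "train": "qa1_train.tfrecords",  # hypothetical file names
        "test": "qa1_test.tfrecords",
    },
}

with open("qa1.json", "w") as f:
    json.dump(metadata, f, indent=2)

# Hypothetical usage with the new class:
#   dataset = Dataset("qa1.json", batch_size=32)
#   input_fn = dataset.get_input_fn("train", num_epochs=None, shuffle=True)
#   features, answer = input_fn()  # batched tensors for tf.contrib.learn
```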

entity_networks/dynamic_memory_cell.py (+26 −21)

@@ -11,10 +11,12 @@ class DynamicMemoryCell(tf.nn.rnn_cell.RNNCell):
     The cell's hidden state is divided into blocks and each block's weights are tied.
     """
 
-    def __init__(self, num_blocks, num_units_per_block, activation=tf.nn.relu):
+    def __init__(self, num_blocks, num_units_per_block, keys, initializer=None, activation=tf.nn.relu):
         self._num_blocks = num_blocks # M
         self._num_units_per_block = num_units_per_block # d
+        self._keys = keys
         self._activation = activation # \phi
+        self._initializer = initializer
 
     @property
     def state_size(self):
@@ -24,14 +26,22 @@ def state_size(self):
     def output_size(self):
         return self._num_blocks * self._num_units_per_block
 
-    def get_gate(self, inputs, state_j, key_j):
+    def zero_state(self, batch_size, dtype):
         """
-        Implements the gate (a scalar for each block). Equation 2:
+        We initialize the memory to the key values.
+        """
+        zero_state = tf.concat(1, [tf.expand_dims(key, 0) for key in self._keys])
+        zero_state_batch = tf.tile(zero_state, tf.pack([batch_size, 1]))
+        return zero_state_batch
+
+    def get_gate(self, state_j, key_j, inputs):
+        """
+        Implements the gate (scalar for each block). Equation 2:
 
         g_j <- \sigma(s_t^T h_j + s_t^T w_j)
         """
         a = tf.reduce_sum(inputs * state_j, reduction_indices=[1])
-        b = tf.reduce_sum(inputs * key_j, reduction_indices=[1])
+        b = tf.reduce_sum(inputs * tf.expand_dims(key_j, 0), reduction_indices=[1])
         return tf.sigmoid(a + b)
 
     def get_candidate(self, state_j, key_j, inputs, U, V, W):
@@ -41,41 +51,36 @@ def get_candidate(self, state_j, key_j, inputs, U, V, W):
 
         h_j^~ <- \phi(U h_j + V w_j + W s_t)
         """
+        key_V = tf.matmul(tf.expand_dims(key_j, 0), V)
         state_U = tf.matmul(state_j, U)
         inputs_W = tf.matmul(inputs, W)
-        key_V = tf.matmul(tf.expand_dims(key_j, 0), V)
         return self._activation(state_U + key_V + inputs_W)
 
     def __call__(self, inputs, state, scope=None):
-        with tf.variable_scope(scope or type(self).__name__):
+        with tf.variable_scope(scope or type(self).__name__, initializer=self._initializer):
             # Split the hidden state into blocks (each U, V, W are shared across blocks).
             state = tf.split(1, self._num_blocks, state)
 
-            U = tf.get_variable('U',
-                shape=[self._num_units_per_block, self._num_units_per_block],
-                initializer=tf.random_normal_initializer(0.1))
-            V = tf.get_variable('V',
-                shape=[self._num_units_per_block, self._num_units_per_block],
-                initializer=tf.random_normal_initializer(0.1))
-            W = tf.get_variable('W',
-                shape=[self._num_units_per_block, self._num_units_per_block],
-                initializer=tf.random_normal_initializer(0.1))
+            # TODO: ortho init?
+            U = tf.get_variable('U', [self._num_units_per_block, self._num_units_per_block])
+            V = tf.get_variable('V', [self._num_units_per_block, self._num_units_per_block])
+            W = tf.get_variable('W', [self._num_units_per_block, self._num_units_per_block])
+
+            # TODO: layer norm?
 
             next_states = []
            for j, state_j in enumerate(state): # Hidden State (j)
-                key_j = tf.get_variable('key_{}'.format(j),
-                    shape=[self._num_units_per_block],
-                    initializer=tf.random_normal_initializer(0.1))
-                gate_j = self.get_gate(inputs, state_j, key_j)
+                key_j = self._keys[j]
+                gate_j = self.get_gate(state_j, key_j, inputs)
                 candidate_j = self.get_candidate(state_j, key_j, inputs, U, V, W)
 
                 # Equation 4: h_j <- h_j + g_j * h_j^~
                 # Perform an update of the hidden state (memory).
                 state_j_next = state_j + tf.expand_dims(gate_j, -1) * candidate_j
 
                 # Equation 5: h_j <- h_j / \norm{h_j}
-                # Forgot previous memories by normalization.
-                state_j_next = tf.nn.l2_normalize(state_j_next, -1)
+                # Forget previous memories by normalization.
+                state_j_next = tf.nn.l2_normalize(state_j_next, -1, epsilon=1e-7) # TODO: Is epsilon necessary?
 
                 next_states.append(state_j_next)
             state_next = tf.concat(1, next_states)
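
For readers following the equations cited in the docstrings, here is a self-contained NumPy sketch (an illustration, not the committed TensorFlow code) of one block's update, walking Equations 2 through 5 in order:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def block_update(inputs, state_j, key_j, U, V, W,
                 phi=lambda x: np.maximum(x, 0.0)):
    """One step for block j. inputs (s_t) and state_j (h_j) are [batch, d];
    key_j (w_j) is [d]; U, V, W are the shared [d, d] parameters."""
    # Equation 2: g_j <- sigma(s_t^T h_j + s_t^T w_j), one scalar gate per example.
    gate_j = sigmoid(np.sum(inputs * state_j, axis=1) + np.sum(inputs * key_j, axis=1))
    # Equation 3: h_j~ <- phi(U h_j + V w_j + W s_t), the candidate memory.
    candidate_j = phi(state_j @ U + key_j @ V + inputs @ W)
    # Equation 4: h_j <- h_j + g_j * h_j~, a gated additive update.
    state_j_next = state_j + gate_j[:, None] * candidate_j
    # Equation 5: h_j <- h_j / ||h_j||; normalization acts as the forget mechanism.
    norm = np.linalg.norm(state_j_next, axis=1, keepdims=True)
    return state_j_next / np.maximum(norm, 1e-7)

rng = np.random.RandomState(0)
batch, d = 2, 4
s_t, h_j, w_j = rng.randn(batch, d), rng.randn(batch, d), rng.randn(d)
U, V, W = (0.1 * rng.randn(d, d) for _ in range(3))
print(block_update(s_t, h_j, w_j, U, V, W).shape)  # (2, 4)
```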
