Commit 9544fe7

Rewrite to use TF 1.1.0, support ML engine, updated results
1 parent 96fe374 commit 9544fe7

24 files changed: +814 -503 lines changed

.dockerignore

+2
@@ -0,0 +1,2 @@
+data/
+logs/

.editorconfig

-9
This file was deleted.

.gitignore

+1-1
@@ -1,4 +1,4 @@
-datasets/
+data/
 logs/
 .env/
 *.pyc

README.md

+20-17
@@ -1,9 +1,9 @@
 # Recurrent Entity Networks
 
 This repository contains an independent TensorFlow implementation of recurrent entity networks from [Tracking the World State with
-Recurrent Entity Networks](https://openreview.net/forum?id=rJTKKKqeg). This paper introduces the first method to solve all of the bAbI tasks using 10k training examples. The author's original Torch implementation is available [here](https://github.com/facebook/MemNN/tree/master/EntNet-babi).
+Recurrent Entity Networks](https://arxiv.org/abs/1612.03969). This paper introduces the first method to solve all of the bAbI tasks using 10k training examples. The author's original Torch implementation is now available [here](https://github.com/facebook/MemNN/tree/master/EntNet-babi).
 
-<img src="images/diagram.png" alt="Diagram of recurrent entity network" width="886" height="658">
+<img src="assets/diagram.png" alt="Diagram of recurrent entity network architecture" width="886" height="658">
 
 ## Results
 
@@ -16,37 +16,40 @@ Task | EntNet (paper) | EntNet (repo)
 3: 3 supporting facts | 4.1 | ?
 4: 2 argument relations | 0 | 0
 5: 3 argument relations | 0.3 | ?
-6: yes/no questions | 0.2 | 0.1
-7: counting | 0 | ?
-8: lists/sets | 0.5 | ?
-9: simple negation | 0.1 | 0.7
-10: indefinite knowledge | 0.6 | 0.1
+6: yes/no questions | 0.2 | 0
+7: counting | 0 | 0
+8: lists/sets | 0.5 | 0
+9: simple negation | 0.1 | 0
+10: indefinite knowledge | 0.6 | 0
 11: basic coreference | 0.3 | 0
 12: conjunction | 0 | 0
 13: compound coreference | 1.3 | 0
-14: time reasoning | 0 | 4.5
+14: time reasoning | 0 | 0
 15: basic deduction | 0 | 0
-16: basic induction | 0.2 | 54.0 ([#5](../../issues/5))
+16: basic induction | 0.2 | 0
 17: positional reasoning | 0.5 | 1.7
 18: size reasoning | 0.3 | 1.5
-19: path finding | 2.3 | 41.9 ([#5](../../issues/5))
-20: agents motivation | 0 | 0.2
+19: path finding | 2.3 | 0
+20: agents motivation | 0 | 0
 **Failed Tasks** | 0 | ?
 **Mean Error** | 0.5 | ?
 
+NOTE: Some of these tasks (16 and 19, in particular) required a change in learning rate schedule to reliably converge.
+
 ## Setup
 
-1. Download the datasets by running [download_datasets.sh](download_datasets.sh) or from [The bAbI Project](https://research.facebook.com/research/babi/).
-2. Run [prep_datasets.py](prep_datasets.py) which will convert the datasets into [TFRecords](https://www.tensorflow.org/versions/r0.11/how_tos/reading_data/index.html#standard_tensorflow_format).
+1. Download the datasets by running [download_babi.sh](download_babi.sh) or from [The bAbI Project](https://research.facebook.com/research/babi/).
+2. Run [prep_data.py](entity_networks/prep_data.py) which will convert the datasets into [TFRecords](https://www.tensorflow.org/programmers_guide/reading_data#standard_tensorflow_format).
 3. Run `python -m entity_networks.main` to begin training on QA1.
-4. Run `./run_all.sh` to train on all tasks.
 
-## Dependencies
+## Major Dependencies
+
+- TensorFlow v1.1.0
 
-- TensorFlow v0.11
+(For additional dependencies see [requirements.txt](requirements.txt))
 
 ## Thanks!
 
 - Thanks to Mikael Henaff for providing details about their paper over Thanksgiving break. :)
 - Thanks to Andy Zhang ([@zhangandyx](https://twitter.com/zhangandyx)) for helping me troubleshoot numerical instabilities.
-- Thanks to Mike Young for providing results on some of the longer tasks.
+- Thanks to Mike Young (@mikalyoung) for providing results on some of the longer tasks.
File renamed without changes.
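
The README results table still shows `?` for the repo's **Failed Tasks** and **Mean Error** rows. A minimal sketch of how those two summary rows could be derived from per-task error rates follows. It is not part of this commit: the `summarize` helper and the 5% failure threshold are assumptions (the threshold is the usual bAbI convention), and the sample data covers only the rows visible in the hunk (tasks 3 to 20, "EntNet (paper)" column), so its mean is not the full 20-task figure.

```python
# Error rates (%) copied from the "EntNet (paper)" column of the README hunk.
# Tasks 1 and 2 fall outside the hunk, so this is NOT the full 20-task table.
PAPER_ERRORS = {
    3: 4.1, 4: 0.0, 5: 0.3, 6: 0.2, 7: 0.0, 8: 0.5, 9: 0.1, 10: 0.6,
    11: 0.3, 12: 0.0, 13: 1.3, 14: 0.0, 15: 0.0, 16: 0.2, 17: 0.5,
    18: 0.3, 19: 2.3, 20: 0.0,
}

# Assumption: a task counts as failed when its error exceeds 5%.
FAILURE_THRESHOLD = 5.0


def summarize(errors, threshold=FAILURE_THRESHOLD):
    """Return (sorted list of failed task ids, mean error) for a task->error dict."""
    failed = sorted(task for task, err in errors.items() if err > threshold)
    mean_error = sum(errors.values()) / len(errors)
    return failed, mean_error


if __name__ == '__main__':
    failed, mean_error = summarize(PAPER_ERRORS)
    print('Failed tasks:', failed)
    print('Mean error over {} listed tasks: {:.2f}'.format(len(PAPER_ERRORS), mean_error))
```

Applied to the pre-commit repo values for tasks 16 and 19 (54.0 and 41.9), the same helper would flag both as failed, which matches the NOTE about those two tasks needing a different learning rate schedule.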

download_babi.sh

+11
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+if [ ! -d ./datasets ]; then
+    mkdir -p ./datasets
+fi
+
+BABI_TASKS=datasets/babi_tasks_data_1_20_v1.2.tar.gz
+
+if [ ! -f $BABI_TASKS ]; then
+    wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz -O $BABI_TASKS
+fi
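
For environments without `wget`, the same download-if-missing logic could be sketched in Python. This is only an illustration, not something the commit ships: the `download_babi` function and its injectable `fetch` parameter are hypothetical, and the code assumes Python 3. The URL and destination path are taken from the script above.

```python
import os
import urllib.request

# URL and destination path copied from download_babi.sh.
BABI_URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'
BABI_TASKS = os.path.join('datasets', 'babi_tasks_data_1_20_v1.2.tar.gz')


def download_babi(url=BABI_URL, dest=BABI_TASKS, fetch=urllib.request.urlretrieve):
    """Fetch the archive unless it already exists; return True if it was fetched.

    `fetch` is a hypothetical hook (any callable taking url and dest) so the
    logic can be exercised without network access.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:
        # Equivalent of `mkdir -p`: no error if the directory already exists.
        os.makedirs(dest_dir, exist_ok=True)
    if os.path.isfile(dest):
        # Equivalent of the `[ ! -f $BABI_TASKS ]` guard: skip existing files.
        return False
    fetch(url, dest)
    return True


if __name__ == '__main__':
    print('downloaded' if download_babi() else 'already present')
```

Like the shell version, a second run is a no-op once the archive is on disk.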

download_datasets.sh

-21
This file was deleted.

entity_networks/activations.py

-15
This file was deleted.

entity_networks/create_instances.py

+91
@@ -0,0 +1,91 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import json
+import random
+import argparse
+import tensorflow as tf
+
+from tqdm import tqdm
+
+from entity_networks.inputs import generate_input_fn
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--data-dir',
+        help='Directory containing data',
+        default='data/babi/records/')
+    args = parser.parse_args()
+
+    tasks_dir = 'tasks/'
+
+    if not os.path.exists(tasks_dir):
+        os.makedirs(tasks_dir)
+
+    task_names = [
+        'qa1_single-supporting-fact',
+        'qa2_two-supporting-facts',
+        'qa3_three-supporting-facts',
+        'qa4_two-arg-relations',
+        'qa5_three-arg-relations',
+        'qa6_yes-no-questions',
+        'qa7_counting',
+        'qa8_lists-sets',
+        'qa9_simple-negation',
+        'qa10_indefinite-knowledge',
+        'qa11_basic-coreference',
+        'qa12_conjunction',
+        'qa13_compound-coreference',
+        'qa14_time-reasoning',
+        'qa15_basic-deduction',
+        'qa16_basic-induction',
+        'qa17_positional-reasoning',
+        'qa18_size-reasoning',
+        'qa19_path-finding',
+        'qa20_agents-motivations',
+    ]
+
+    for task_name in tqdm(task_names):  # task_names is a list, so iterate it directly
+        metadata_path = os.path.join(args.data_dir, '{}_10k.json'.format(task_name))
+        with open(metadata_path) as metadata_file:
+            metadata = json.load(metadata_file)
+
+        filename = os.path.join(args.data_dir, '{}_10k_{}.tfrecords'.format(task_name, 'test'))
+        input_fn = generate_input_fn(
+            filename=filename,
+            metadata=metadata,
+            batch_size=1,  # assumption: one example per step, since only element [0] is read below
+            num_epochs=1,
+            shuffle=False)
+
+        with tf.Graph().as_default():
+            features, answer = input_fn()
+
+            story = features['story']
+            query = features['query']
+
+            instances = []
+
+            with tf.train.SingularMonitoredSession() as sess:
+                while not sess.should_stop():  # stops once the single epoch is exhausted
+                    story_, query_, answer_ = sess.run([story, query, answer])
+
+                    instance = {
+                        'story': story_[0].tolist(),
+                        'query': query_[0].tolist(),
+                        'answer': answer_[0].tolist(),
+                    }
+
+                    instances.append(instance)
+
+            metadata['instances'] = random.sample(instances, k=10)  # keep 10 random test examples
+
+            output_path = os.path.join(tasks_dir, '{}.json'.format(task_name))
+            with open(output_path, 'w') as f:
+                f.write(json.dumps(metadata))
+
+if __name__ == '__main__':
+    main()
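
One detail in the script above: `random.sample(instances, k=10)` raises `ValueError` when fewer than 10 instances were collected. A minimal sketch of a more defensive variant of the sampling and serialization step follows, with no TensorFlow dependency. The `attach_instances` helper and its `seed` parameter are hypothetical names introduced here for illustration, not part of the commit.

```python
import json
import random


def attach_instances(metadata, instances, k=10, seed=None):
    """Return a copy of `metadata` with up to `k` randomly sampled instances.

    Hypothetical helper: caps k at len(instances) so small inputs do not
    raise ValueError, and accepts a seed so the sample is reproducible.
    """
    rng = random.Random(seed)
    sampled = rng.sample(instances, min(k, len(instances)))
    result = dict(metadata)  # shallow copy; the caller's dict is left untouched
    result['instances'] = sampled
    return result


if __name__ == '__main__':
    # Toy instances shaped like the ones the script builds (lists of token ids).
    toy = [{'story': [[i]], 'query': [i], 'answer': i} for i in range(25)]
    payload = attach_instances({'task': 'toy'}, toy, k=10, seed=0)
    print(len(payload['instances']))
    print(json.dumps(payload)[:60])
```

Fixing the seed makes the exported `tasks/*.json` files stable across runs, which may or may not be desirable depending on how the instances are consumed.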

entity_networks/dataset.py

-51
This file was deleted.
