from __future__ import division
from __future__ import print_function

+ import hashlib
import os
import tarfile
- import hashlib

# Dependency imports

_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"

+
# Note: using See et al. (2017) as reference for data generation
# For more info, use the links below

_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"

+
# End-of-sentence marker.
EOS = text_encoder.EOS_ID

+
# Techniques for data prep from See et al. (2017)
- dm_single_close_quote = u'\u2019'  # unicode
- dm_double_close_quote = u'\u201d'
- END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"', dm_single_close_quote, dm_double_close_quote, u")"]  # acceptable ways to end a sentence
+ dm_single_close_quote = u"\u2019"  # unicode
+ dm_double_close_quote = u"\u201d"
+ # Acceptable ways to end a sentence.
+ END_TOKENS = [u".", u"!", u"?", u"...", u"'", u"`", u"\"",
+               dm_single_close_quote, dm_double_close_quote, u")"]


def _maybe_download_corpora(tmp_dir, is_training):
  """Download corpora if necessary and unzip them.

  Args:
    tmp_dir: directory containing dataset.
+     is_training: whether we're in training mode or not.

  Returns:
-     list of all files generated and path to file containing train/dev/test split info.
+     List of all files generated and path to file containing
+       train/dev/test split info.
  """
  cnn_filename = "cnn_stories.tgz"
  cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
@@ -85,43 +92,52 @@ def _maybe_download_corpora(tmp_dir, is_training):
  all_files = cnn_files + dailymail_files

  if is_training:
-     urls_path = generator_utils.maybe_download(tmp_dir, "all_train.txt", _TRAIN_URLS)
+     urls_path = generator_utils.maybe_download(
+         tmp_dir, "all_train.txt", _TRAIN_URLS)
  else:
-     urls_path = generator_utils.maybe_download(tmp_dir, "all_val.txt", _DEV_URLS)
+     urls_path = generator_utils.maybe_download(
+         tmp_dir, "all_val.txt", _DEV_URLS)

  return all_files, urls_path
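
# A minimal usage sketch of the helper above (the tmp_dir value is hypothetical):
#
#   all_files, urls_path = _maybe_download_corpora("/tmp/t2t_datagen", is_training=True)
#
# downloads and extracts the CNN and Daily Mail story archives if needed, and
# returns the story file paths together with the train (or dev) URL list.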

+

def example_splits(url_file, all_files):
+   """Generate splits of the data."""
  def generate_hash(inp):
-     """Generate a sha1 hash to match the raw url to the filename extracted"""
-     h = hashlib.sha1()
-     h.update(inp)
-     return h.hexdigest()
+     """Generate a sha1 hash to match the raw url to the filename extracted."""
+     h = hashlib.sha1()
+     h.update(inp)
+     return h.hexdigest()

-   all_files_map = {f.split("/")[-1]:f for f in all_files}
+   all_files_map = {f.split("/")[-1]: f for f in all_files}

  urls = []
  for line in tf.gfile.Open(url_file):
-     urls.append(line.strip().encode('utf-8'))
+     urls.append(line.strip().encode("utf-8"))

  filelist = []
  for url in urls:
-     url_hash = generate_hash(url)
-     filename = url_hash + ".story"
-     if filename not in all_files_map:
-       tf.logging.info("Missing file: %s" % url)
-       continue
-     filelist.append(all_files_map[filename])
+     url_hash = generate_hash(url)
+     filename = url_hash + ".story"
+     if filename not in all_files_map:
+       tf.logging.info("Missing file: %s" % url)
+       continue
+     filelist.append(all_files_map[filename])

  tf.logging.info("Found %d examples" % len(filelist))

  return filelist
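
# Sketch of the naming convention assumed above (the URL is hypothetical): each
# extracted story file is named after the sha1 hex digest of its source URL, so
#
#   hashlib.sha1(b"http://www.example.com/some/article.html").hexdigest() + ".story"
#
# is the filename that example_splits() looks up in all_files_map for that URL.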

+

def example_generator(tmp_dir, is_training, sum_token):
+   """Generate examples."""
  def fix_run_on_sents(line):
-     if u"@highlight" in line: return line
-     if line == "": return line
-     if line[-1] in END_TOKENS: return line
+     if u"@highlight" in line:
+       return line
+     if not line:
+       return line
+     if line[-1] in END_TOKENS:
+       return line
    return line + u"."
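
  # Illustrative behaviour of the helper above (the example strings are made up):
  #   fix_run_on_sents(u"The cat sat on the mat")  ->  u"The cat sat on the mat."
  #   fix_run_on_sents(u"Is it raining?")          ->  u"Is it raining?"
  #   fix_run_on_sents(u"@highlight")              ->  u"@highlight"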

  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
@@ -133,28 +149,33 @@ def fix_run_on_sents(line):
    summary = []
    reading_highlights = False
    for line in tf.gfile.Open(story_file, "rb"):
-       line = unicode(line.strip(), "utf-8") if six.PY2 else line.strip().decode("utf-8")
+       if six.PY2:
+         line = unicode(line.strip(), "utf-8")
+       else:
+         line = line.strip().decode("utf-8")
      line = fix_run_on_sents(line)
-       if line == "":
-         continue
+       if not line:
+         continue
      elif line.startswith(u"@highlight"):
-         if len(story) == 0: break  # No article text
-         reading_highlights = True
+         if not story:
+           break  # No article text.
+         reading_highlights = True
      elif reading_highlights:
-         summary.append(line)
+         summary.append(line)
      else:
-         story.append(line)
+         story.append(line)

-     if len(story) == 0 or len(summary) == 0:
-       continue
+     if (not story) or not summary:
+       continue

    yield " ".join(story) + story_summary_split_token + " ".join(summary)

+

def _story_summary_split(story):
  split_str = u" <summary> "
  split_str_len = len(split_str)
  split_pos = story.find(split_str)
-   return story[:split_pos], story[split_pos + split_str_len:]  # story, summary
+   return story[:split_pos], story[split_pos + split_str_len:]  # story, summary
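
# Illustrative split (assuming the u" <summary> " token inserted by the generator
# above; the sentences are made up):
#
#   _story_summary_split(u"Some article text. <summary> First highlight. Second highlight.")
#   ==> (u"Some article text.", u"First highlight. Second highlight.")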


@registry.register_problem