
Commit 9bdc801: Merge pull request #383 from rsepassi/push (v1.2.6)
2 parents: a836d66 + ba47b61
40 files changed: +932 −314 lines

.travis.yml

+1 −1

```diff
@@ -24,6 +24,6 @@ script:
   - mkdir $T2T_TRAIN_DIR
   - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
   - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
-  - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10'
+  - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10,use_last_position_only=True'
 git:
   depth: 3
```
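The `--decode_hparams` flag is a single comma-separated string of `key=value` overrides, which is why `use_last_position_only=True` can simply be appended after `num_samples=10`. A minimal sketch of parsing such a string into typed values (plain Python for illustration; the library itself builds a TensorFlow `HParams` object, and `parse_decode_hparams` is a hypothetical name):

```python
def parse_decode_hparams(spec):
  """Parse a "k1=v1,k2=v2" string into a dict, coercing bools and ints."""
  overrides = {}
  for pair in spec.split(","):
    key, _, value = pair.partition("=")
    if value in ("True", "False"):
      overrides[key] = (value == "True")
    else:
      try:
        overrides[key] = int(value)
      except ValueError:
        overrides[key] = value  # anything else stays a string
  return overrides


print(parse_decode_hparams("num_samples=10,use_last_position_only=True"))
# {'num_samples': 10, 'use_last_position_only': True}
```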

README.md

+1 −1

```diff
@@ -286,7 +286,7 @@ registrations.
 To add a new dataset, subclass
 [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
 and register it with `@registry.register_problem`. See
-[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py)
 for an example.
 
 Also see the [data generators
```
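For orientation, the registration pattern this README points at looks roughly like the sketch below (`MyNewDataset` is a placeholder; a real problem implements more methods, as `TranslateEndeWmt8k` in `translate_ende.py` does):

```python
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem
class MyNewDataset(problem.Problem):
  """Placeholder problem, registered under the snake-cased name "my_new_dataset"."""

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    # A real implementation writes TFRecord files of
    # tensorflow.Example protos into data_dir.
    raise NotImplementedError()
```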

docs/new_problem.md

+1 −1

```diff
@@ -105,7 +105,7 @@ We're almost done. `generator` generates the training and evaluation data and
 stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully
 several commonly used methods like `character_generator`, and `token_generator`
 are already written in the file
-[`wmt.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py).
+[`translate.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate.py).
 We will import `character_generator` and
 [`text_encoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py)
 to write:
```

docs/walkthrough.md

+1 −1

```diff
@@ -286,7 +286,7 @@ registrations.
 To add a new dataset, subclass
 [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
 and register it with `@registry.register_problem`. See
-[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py)
 for an example.
 
 Also see the [data generators
```

setup.py

+1 −1

```diff
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.2.5',
+    version='1.2.6',
    description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
```

tensor2tensor/bin/t2t-datagen

File mode changed: 100755 → 100644
+2 −2

```diff
@@ -82,9 +82,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
     "parsing_english_ptb8k": (
-        lambda: wmt.parsing_token_generator(
+        lambda: translate.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
-        lambda: wmt.parsing_token_generator(
+        lambda: translate.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
     "parsing_english_ptb16k": (
         lambda: wsj_parsing.parsing_token_generator(
```

tensor2tensor/bin/t2t-decoder

File mode changed: 100755 → 100644
+1

```diff
@@ -84,6 +84,7 @@ def main(_):
 
   decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
   decode_hp.add_hparam("shards", FLAGS.decode_shards)
+  decode_hp.add_hparam("shard_id", FLAGS.worker_id)
   if FLAGS.decode_interactive:
     decoding.decode_interactively(estimator, decode_hp)
   elif FLAGS.decode_from_file:
```
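With both `shards` and `shard_id` on the decode hparams, each decode worker can pick a disjoint slice of the input. A minimal sketch of such sharding (illustrative only, assuming round-robin assignment; not the actual logic in `decoding.py`):

```python
def shard_lines(lines, shards, shard_id):
  """Keep every shards-th line, starting at shard_id (round-robin split)."""
  return [line for i, line in enumerate(lines) if i % shards == shard_id]


# With 2 decode shards, the worker with shard_id=1 keeps lines 1, 3, 5, ...
assert shard_lines(["a", "b", "c", "d"], shards=2, shard_id=1) == ["b", "d"]
```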

tensor2tensor/bin/t2t-make-tf-configs

File mode changed: 100755 → 100644

tensor2tensor/bin/t2t-trainer

File mode changed: 100755 → 100644

tensor2tensor/data_generators/README.md

+3 −3

```diff
@@ -23,7 +23,7 @@ All tasks produce TFRecord files of `tensorflow.Example` protocol buffers.
 To add a new problem, subclass
 [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
 and register it with `@registry.register_problem`. See
-[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py)
 for an example.
 
 `Problem`s support data generation, training, and decoding.
@@ -37,7 +37,7 @@ for training/decoding, e.g. a vocabulary file.
 A particularly easy way to implement `Problem.generate_data` for your dataset is
 to create 2 Python generators, one for the training data and another for the
 dev data, and pass them to `generator_utils.generate_dataset_and_shuffle`. See
-[`WMTEnDeTokens8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+[`TranslateEndeWmt8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py)
 for an example of usage.
 
 The generators should yield dictionaries with string keys and values being lists
@@ -66,5 +66,5 @@ Some examples:
 
 * [Algorithmic problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py)
   and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic_test.py)
-* [WMT problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
+* [WMT En-De problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py)
   and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt_test.py)
```
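A sketch of the two-generator pattern described above (hedged: `_examples` and its yielded values are placeholders; `generate_dataset_and_shuffle`, `training_filepaths`, and `dev_filepaths` are the helpers from `generator_utils` and `Problem`):

```python
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem
class MyTokenProblem(problem.Problem):
  """Placeholder showing generate_data built from two generators."""

  def _examples(self, tmp_dir, is_training):
    # Yield dicts of string keys to lists of ints (e.g. token ids).
    yield {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    generator_utils.generate_dataset_and_shuffle(
        self._examples(tmp_dir, is_training=True),
        self.training_filepaths(data_dir, 100, shuffled=False),
        self._examples(tmp_dir, is_training=False),
        self.dev_filepaths(data_dir, 1, shuffled=False))
```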

tensor2tensor/data_generators/all_problems.py

+5 −5

```diff
@@ -29,16 +29,16 @@
 from tensor2tensor.data_generators import image
 from tensor2tensor.data_generators import imdb
 from tensor2tensor.data_generators import lm1b
+from tensor2tensor.data_generators import multinli
 from tensor2tensor.data_generators import problem_hparams
 from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli
-from tensor2tensor.data_generators import wiki
-from tensor2tensor.data_generators import translate
-from tensor2tensor.data_generators import translate_enfr
-from tensor2tensor.data_generators import translate_ende
 from tensor2tensor.data_generators import translate_encs
-from tensor2tensor.data_generators import translate_enzh
+from tensor2tensor.data_generators import translate_ende
+from tensor2tensor.data_generators import translate_enfr
 from tensor2tensor.data_generators import translate_enmk
+from tensor2tensor.data_generators import translate_enzh
+from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wsj_parsing
 
 
```
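Net effect of this hunk: `multinli` is newly imported (importing a module here is what registers its problems), the bare `translate` import is dropped, and the list returns to alphabetical order.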

tensor2tensor/data_generators/cnn_dailymail.py

+53 −32

```diff
@@ -19,9 +19,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import hashlib
 import os
 import tarfile
-import hashlib
 
 # Dependency imports
 
@@ -39,6 +39,7 @@
 
 _DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"
 
+
 # Note: using See et al. (2017) as reference for data generation
 # For more info, use the links below
 
@@ -47,23 +48,29 @@
 _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
 _TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
 
+
 # End-of-sentence marker.
 EOS = text_encoder.EOS_ID
 
+
 # Techniques for data prep from See et al. (2017)
-dm_single_close_quote = u'\u2019' # unicode
-dm_double_close_quote = u'\u201d'
-END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"', dm_single_close_quote, dm_double_close_quote, u")"] # acceptable ways to end a sentence
+dm_single_close_quote = u"\u2019"  # unicode
+dm_double_close_quote = u"\u201d"
+# Acceptable ways to end a sentence.
+END_TOKENS = [u".", u"!", u"?", u"...", u"'", u"`", u"\"",
+              dm_single_close_quote, dm_double_close_quote, u")"]
 
 
 def _maybe_download_corpora(tmp_dir, is_training):
   """Download corpora if necessary and unzip them.
 
   Args:
     tmp_dir: directory containing dataset.
+    is_training: whether we're in training mode or not.
 
   Returns:
-    list of all files generated and path to file containing train/dev/test split info.
+    List of all files generated and path to file containing
+      train/dev/test split info.
   """
   cnn_filename = "cnn_stories.tgz"
   cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
@@ -85,43 +92,52 @@ def _maybe_download_corpora(tmp_dir, is_training):
   all_files = cnn_files + dailymail_files
 
   if is_training:
-    urls_path = generator_utils.maybe_download(tmp_dir, "all_train.txt", _TRAIN_URLS)
+    urls_path = generator_utils.maybe_download(
+        tmp_dir, "all_train.txt", _TRAIN_URLS)
   else:
-    urls_path = generator_utils.maybe_download(tmp_dir, "all_val.txt", _DEV_URLS)
+    urls_path = generator_utils.maybe_download(
+        tmp_dir, "all_val.txt", _DEV_URLS)
 
   return all_files, urls_path
 
+
 def example_splits(url_file, all_files):
+  """Generate splits of the data."""
   def generate_hash(inp):
-      """Generate a sha1 hash to match the raw url to the filename extracted"""
-      h = hashlib.sha1()
-      h.update(inp)
-      return h.hexdigest()
+    """Generate a sha1 hash to match the raw url to the filename extracted."""
+    h = hashlib.sha1()
+    h.update(inp)
+    return h.hexdigest()
 
-  all_files_map = {f.split("/")[-1]:f for f in all_files}
+  all_files_map = {f.split("/")[-1]: f for f in all_files}
 
   urls = []
   for line in tf.gfile.Open(url_file):
-    urls.append(line.strip().encode('utf-8'))
+    urls.append(line.strip().encode("utf-8"))
 
   filelist = []
   for url in urls:
-      url_hash = generate_hash(url)
-      filename = url_hash + ".story"
-      if filename not in all_files_map:
-          tf.logging.info("Missing file: %s" % url)
-          continue
-      filelist.append(all_files_map[filename])
+    url_hash = generate_hash(url)
+    filename = url_hash + ".story"
+    if filename not in all_files_map:
+      tf.logging.info("Missing file: %s" % url)
+      continue
+    filelist.append(all_files_map[filename])
 
   tf.logging.info("Found %d examples" % len(filelist))
 
   return filelist
 
+
 def example_generator(tmp_dir, is_training, sum_token):
+  """Generate examples."""
   def fix_run_on_sents(line):
-    if u"@highlight" in line: return line
-    if line=="": return line
-    if line[-1] in END_TOKENS: return line
+    if u"@highlight" in line:
+      return line
+    if not line:
+      return line
+    if line[-1] in END_TOKENS:
+      return line
     return line + u"."
 
   all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
@@ -133,28 +149,33 @@ def fix_run_on_sents(line):
     summary = []
     reading_highlights = False
     for line in tf.gfile.Open(story_file, "rb"):
-      line = unicode(line.strip(), "utf-8") if six.PY2 else line.strip().decode("utf-8")
+      if six.PY2:
+        line = unicode(line.strip(), "utf-8")
+      else:
+        line = line.strip().decode("utf-8")
       line = fix_run_on_sents(line)
-      if line == "":
-          continue
+      if not line:
+        continue
       elif line.startswith(u"@highlight"):
-        if len(story) == 0: break # No article text
-          reading_highlights = True
+        if not story:
+          break  # No article text.
+        reading_highlights = True
       elif reading_highlights:
-          summary.append(line)
+        summary.append(line)
       else:
-          story.append(line)
 
-    if len(story) == 0 or len(summary) == 0:
-      continue
+        story.append(line)
+
+    if (not story) or not summary:
+      continue
 
     yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
+
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
-  return story[:split_pos], story[split_pos+split_str_len:] # story, summary
+  return story[:split_pos], story[split_pos+split_str_len:]  # story, summary
 
 
 @registry.register_problem
```
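The `generate_hash` helper exists because the CNN/Daily Mail extraction names each story file after the SHA-1 hex digest of its source URL, so matching a URL list against files on disk is a pure hash lookup. A standalone illustration (the URL is made up):

```python
import hashlib

# Hypothetical URL; the real lists come from _TRAIN_URLS / _DEV_URLS.
url = b"http://www.cnn.com/2015/01/01/example/index.html"
story_filename = hashlib.sha1(url).hexdigest() + ".story"
# story_filename is 40 hex characters plus ".story", matching the name of
# the extracted file under cnn/stories/ or dailymail/stories/.
```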

tensor2tensor/data_generators/generator_utils.py

+2 −4

```diff
@@ -263,6 +263,7 @@ def gunzip_file(gz_path, new_path):
       for line in gz_file:
         new_file.write(line)
 
+
 def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
                                 generator):
   """Inner implementation for vocab generators.
@@ -301,10 +302,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
   return vocab
 
 
-def get_or_generate_vocab(data_dir,
-                          tmp_dir,
-                          vocab_filename,
-                          vocab_size,
+def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
                           sources):
   """Generate a vocabulary from the datasets in sources."""
 
```

tensor2tensor/data_generators/ice_parsing.py

+5 −3

```diff
@@ -32,7 +32,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators.translate import tabbed_generator
+from tensor2tensor.data_generators import translate
 from tensor2tensor.utils import registry
 
 
@@ -51,15 +51,17 @@ def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
       data_dir, tmp_dir, filename, 1,
       prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
   pair_filepath = os.path.join(tmp_dir, filename)
-  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
+  return translate.tabbed_generator(pair_filepath, source_vocab, target_vocab,
+                                    EOS)
 
 
 def tabbed_parsing_character_generator(tmp_dir, train):
   """Generate source and target data from a single file."""
   character_vocab = text_encoder.ByteTextEncoder()
   filename = "parsing_{0}.pairs".format("train" if train else "dev")
   pair_filepath = os.path.join(tmp_dir, filename)
-  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
+  return translate.tabbed_generator(pair_filepath, character_vocab,
+                                    character_vocab, EOS)
 
 
 @registry.register_problem
```
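Importing the module (`translate`) rather than the bare symbol (`tabbed_generator`) follows the usual Google Python style preference for module-level imports and keeps each call site explicit about where the helper lives.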

tensor2tensor/data_generators/image.py

+2 −2

```diff
@@ -227,7 +227,7 @@ def feature_encoders(self, data_dir):
     # This vocab file must be present within the data directory.
     vocab_filename = os.path.join(data_dir, "charset_size134.txt")
     return {
-        "inputs": text_encoder.TextEncoder(),
+        "inputs": text_encoder.ImageEncoder(),
         "targets": text_encoder.SubwordTextEncoder(vocab_filename)
     }
 
@@ -273,7 +273,7 @@ def class_labels(self):
   def feature_encoders(self, data_dir):
     del data_dir
     return {
-        "inputs": text_encoder.TextEncoder(),
+        "inputs": text_encoder.ImageEncoder(),
         "targets": text_encoder.ClassLabelEncoder(self.class_labels)
     }
 
```