This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 54622a5

Author: Lukasz Kaiser
Improving data generation (removing some problems too) and adding eval printouts.

PiperOrigin-RevId: 161265195
1 parent 5a06e7a commit 54622a5

12 files changed: 93 additions, 93 deletions

Diff for: README.md (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ t2t-datagen \
   --num_shards=100 \
   --problem=$PROBLEM
 
-mv $TMP_DIR/tokens.vocab.32768 $DATA_DIR
+cp $TMP_DIR/tokens.vocab.* $DATA_DIR
 
 # Train
 # * If you run out of memory, add --hparams='batch_size=2048' or even 1024.

Diff for: setup.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.10',
+    version='1.0.12',
     description='Tensor2Tensor',
     author='Google Inc.',
    author_email='[email protected]',

Diff for: tensor2tensor/bin/t2t-datagen (42 additions, 63 deletions)

@@ -90,25 +90,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_reverse_nlplike_decimal8K": (
         lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
                                                       10, 1.300),
-        lambda: algorithmic.reverse_generator_nlplike(8000, 700, 10000,
+        lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
                                                       10, 1.300)),
     "algorithmic_reverse_nlplike_decimal32K": (
         lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
                                                       10, 1.050),
-        lambda: algorithmic.reverse_generator_nlplike(32000, 700, 10000,
+        lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
                                                       10, 1.050)),
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
-    "algorithmic_algebra_simplify": (
-        lambda: algorithmic_math.algebra_simplify(8, 0, 2, 100000),
-        lambda: algorithmic_math.algebra_simplify(8, 3, 3, 10000)),
-    "algorithmic_calculus_integrate": (
-        lambda: algorithmic_math.calculus_integrate(8, 0, 2, 100000),
-        lambda: algorithmic_math.calculus_integrate(8, 3, 3, 10000)),
-    "wmt_parsing_characters": (
-        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, True),
-        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),

@@ -133,10 +124,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
-    "wmt_enfr_tokens_128k": (
-        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
-        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
-    ),
     "wmt_ende_characters": (
         lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),

@@ -151,10 +138,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
-    "wmt_ende_tokens_128k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
-    ),
     "image_mnist_tune": (
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),

@@ -227,33 +210,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             40000,
             vocab_filename="tokens.vocab.%d" % 2**15,
             vocab_size=2**15)),
-    "image_mscoco_tokens_128k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17)),
-    "image_mscoco_tokens_128k_test": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            80000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            False,
-            40000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17)),
     "snli_32k": (
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),

@@ -340,10 +296,31 @@ def set_random_seed():
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  if FLAGS.problem not in _SUPPORTED_PROBLEM_GENERATORS:
+
+  # Calculate the list of problems to generate.
+  problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+  if FLAGS.problem and FLAGS.problem[-1] == "*":
+    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
+  elif FLAGS.problem:
+    problems = [p for p in problems if p == FLAGS.problem]
+  else:
+    problems = []
+  # Remove TIMIT if paths are not given.
+  if not FLAGS.timit_paths:
+    problems = [p for p in problems if "timit" not in p]
+  # Remove parsing if paths are not given.
+  if not FLAGS.parsing_path:
+    problems = [p for p in problems if "parsing" not in p]
+  # Remove en-de BPE if paths are not given.
+  if not FLAGS.ende_bpe_path:
+    problems = [p for p in problems if "ende_bpe" not in p]
+
+  if not problems:
     problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n * " + problems_str + "\n")
+    error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
+                  "--timit_paths, --ende_bpe_path and --parsing_path.")
     raise ValueError(error_msg)
 
   if not FLAGS.data_dir:

@@ -352,26 +329,28 @@ def main(_):
                     "Data will be written to default data_dir=%s.",
                     FLAGS.data_dir)
 
-  set_random_seed()
+  tf.logging.info("Generating problems:\n * %s\n" % "\n * ".join(problems))
+  for problem in problems:
+    set_random_seed()
 
-  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[FLAGS.problem]
+    training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
 
-  tf.logging.info("Generating training data for %s.", FLAGS.problem)
-  train_output_files = generator_utils.generate_files(
-      training_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-train",
-      FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.generate_files(
+        training_gen(), problem + UNSHUFFLED_SUFFIX + "-train",
+        FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
 
-  tf.logging.info("Generating development data for %s.", FLAGS.problem)
-  dev_output_files = generator_utils.generate_files(
-      dev_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_output_files = generator_utils.generate_files(
+        dev_gen(), problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
 
-  tf.logging.info("Shuffling data...")
-  for fname in train_output_files + dev_output_files:
-    records = generator_utils.read_records(fname)
-    random.shuffle(records)
-    out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
-    generator_utils.write_records(records, out_fname)
-    tf.gfile.Remove(fname)
+    tf.logging.info("Shuffling data...")
+    for fname in train_output_files + dev_output_files:
+      records = generator_utils.read_records(fname)
+      random.shuffle(records)
+      out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
+      generator_utils.write_records(records, out_fname)
+      tf.gfile.Remove(fname)
 
 
 if __name__ == "__main__":
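
The main behavioral change above is that t2t-datagen now accepts a trailing "*" in --problem to select every supported problem with that prefix, and drops problems that require extra flags (TIMIT, parsing, ende_bpe) when those flags are unset. A minimal, self-contained sketch of that selection logic, using made-up problem names and flag values:

# Standalone sketch of the new problem-selection logic (illustrative names only).
SUPPORTED = ["algorithmic_algebra_inverse", "wmt_ende_tokens_32k",
             "wmt_parsing_tokens_8k", "timit_characters"]

def select_problems(problem_flag, timit_paths="", parsing_path="", ende_bpe_path=""):
  problems = list(sorted(SUPPORTED))
  if problem_flag and problem_flag[-1] == "*":
    problems = [p for p in problems if p.startswith(problem_flag[:-1])]
  elif problem_flag:
    problems = [p for p in problems if p == problem_flag]
  else:
    problems = []
  if not timit_paths:
    problems = [p for p in problems if "timit" not in p]
  if not parsing_path:
    problems = [p for p in problems if "parsing" not in p]
  if not ende_bpe_path:
    problems = [p for p in problems if "ende_bpe" not in p]
  return problems

print(select_problems("wmt_*"))  # ['wmt_ende_tokens_32k'] (parsing dropped without a path)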

Diff for: tensor2tensor/data_generators/algorithmic_math.py (2 additions)

@@ -582,4 +582,6 @@ def calculus_integrate(alphabet_size=26,
       }
     except:  # pylint:disable=bare-except
       continue
+    if nbr_case % 10000 == 0:
+      print(" calculus_integrate: generating case %d." % nbr_case)
     nbr_case += 1

Diff for: tensor2tensor/data_generators/generator_utils.py (14 additions, 7 deletions)

@@ -46,10 +46,14 @@ def to_example(dictionary):
     elif isinstance(v[0], float):
       features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
     elif isinstance(v[0], six.string_types):
+      if not six.PY2:  # Convert in python 3.
+        v = [bytes(x, "utf-8") for x in v]
+      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
+    elif isinstance(v[0], bytes):
       features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
     else:
-      raise ValueError("Value is neither an int nor a float; v: %s type: %s" %
-                       (str(v[0]), str(type(v[0]))))
+      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
+                       (k, str(v[0]), str(type(v[0]))))
   return tf.train.Example(features=tf.train.Features(feature=features))
 
 

@@ -111,7 +115,7 @@ def generate_files(generator,
 
   counter, shard = 0, 0
   for case in generator:
-    if counter % 100000 == 0:
+    if counter > 0 and counter % 100000 == 0:
       tf.logging.info("Generating case %d for %s." % (counter, output_name))
     counter += 1
     if max_cases and counter > max_cases:

@@ -176,6 +180,9 @@ def gunzip_file(gz_path, new_path):
     gz_path: path to the zipped file.
     new_path: path to where the file will be unzipped.
   """
+  if tf.gfile.Exists(new_path):
+    tf.logging.info("File %s already exists, skipping unpacking" % new_path)
+    return
   tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
   with gzip.open(gz_path, "rb") as gz_file:
     with io.open(new_path, "wb") as new_file:

@@ -221,7 +228,7 @@ def gunzip_file(gz_path, new_path):
 def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
   """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
-  if os.path.exists(vocab_filepath):
+  if tf.gfile.Exists(vocab_filepath):
     tf.logging.info("Found vocab file: %s", vocab_filepath)
     vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab

@@ -246,7 +253,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
       # For some datasets a second extraction is necessary.
       if ".gz" in lang_file:
         new_filepath = os.path.join(tmp_dir, lang_file[:-3])
-        if os.path.exists(new_filepath):
+        if tf.gfile.Exists(new_filepath):
           tf.logging.info("Subdirectory %s already exists, skipping unpacking"
                           % filepath)
         else:

@@ -275,7 +282,7 @@ def read_records(filename):
   records = []
   for record in reader:
     records.append(record)
-    if len(records) % 10000 == 0:
+    if len(records) % 100000 == 0:
       tf.logging.info("read: %d", len(records))
   return records
 

@@ -284,6 +291,6 @@ def write_records(records, out_filename):
   writer = tf.python_io.TFRecordWriter(out_filename)
   for count, record in enumerate(records):
     writer.write(record)
-    if count % 10000 == 0:
+    if count > 0 and count % 100000 == 0:
       tf.logging.info("write: %d", count)
   writer.close()
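
The to_example change makes string features work under Python 3: unicode strings are encoded to UTF-8 bytes before being packed into a BytesList, and raw bytes values are now accepted directly. A small sketch of the same pattern (not tensor2tensor code, just the conversion idea applied to TensorFlow's Example protos):

import six
import tensorflow as tf

def string_feature(values):
  # Encode unicode strings as UTF-8 bytes on Python 3; bytes pass through.
  if values and isinstance(values[0], six.string_types) and not six.PY2:
    values = [bytes(v, "utf-8") for v in values]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

example = tf.train.Example(features=tf.train.Features(
    feature={"inputs": string_feature(["hello", "world"])}))
print(len(example.SerializeToString()))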

Diff for: tensor2tensor/data_generators/image.py (1 addition, 1 deletion)

@@ -68,7 +68,7 @@ def image_generator(images, labels):
     yield {
         "image/encoded": [enc_string],
         "image/format": ["png"],
-        "image/class/label": [label],
+        "image/class/label": [int(label)],
         "image/height": [height],
         "image/width": [width]
     }
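
The int(label) cast is presumably there because labels arrive as NumPy integers, which are not instances of Python int on Python 3 and so would not match to_example's integer branch. A tiny illustration of the difference:

import numpy as np

label = np.int64(7)                 # hypothetical label as produced by NumPy
print(isinstance(label, int))       # False on Python 3
print(isinstance(int(label), int))  # True: the cast yields a plain Python int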

Diff for: tensor2tensor/data_generators/wmt.py (14 additions, 13 deletions)

@@ -25,10 +25,19 @@
 
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.data_generators import wsj_parsing
 
 import tensorflow as tf
 
 
+tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir."
                       "Download from https://drive.google.com/open?"
                       "id=0B_bZck-ksdkpM25jRUN2X2UxMm8")
+
+
+FLAGS = tf.flags.FLAGS
+
+
 # End-of-sentence marker (should correspond to the position of EOS in the
 # RESERVED_TOKENS list in text_encoder.py)
 EOS = 1

@@ -100,7 +109,7 @@ def _get_wmt_ende_dataset(directory, filename):
   # We expect that this file has been downloaded from:
   # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed
   # in `directory`.
-  corpus_file = os.path.join(directory, "wmt16_en_de.tar.gz")
+  corpus_file = os.path.join(directory, FLAGS.ende_bpe_path)
   with tarfile.open(corpus_file, "r:gz") as corpus_tar:
     corpus_tar.extractall(directory)
   return train_path

@@ -265,18 +274,10 @@ def enfr_character_generator(tmp_dir, train):
                              character_vocab, EOS)
 
 
-def parsing_character_generator(tmp_dir, train):
-  character_vocab = text_encoder.ByteTextEncoder()
-  filename = "parsing_%s" % ("train" if train else "dev")
-  text_filepath = os.path.join(tmp_dir, filename + ".text")
-  tags_filepath = os.path.join(tmp_dir, filename + ".tags")
-  return character_generator(text_filepath, tags_filepath, character_vocab, EOS)
-
-
 def parsing_token_generator(tmp_dir, train, vocab_size):
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
       tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
-  filename = "parsing_%s" % ("train" if train else "dev")
-  text_filepath = os.path.join(tmp_dir, filename + ".text")
-  tags_filepath = os.path.join(tmp_dir, filename + ".tags")
-  return token_generator(text_filepath, tags_filepath, symbolizer_vocab, EOS)
+  filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
+  tree_filepath = os.path.join(tmp_dir, filename)
+  return wsj_parsing.token_generator(tree_filepath,
+                                     symbolizer_vocab, symbolizer_vocab, EOS)
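
With parsing_character_generator removed, parsing data is now located through the new --parsing_path flag rather than a hard-coded "parsing_train"/"parsing_dev" prefix. The filename construction is plain string formatting; for example (the flag value "wsj" here is hypothetical):

parsing_path = "wsj"  # hypothetical value of --parsing_path
for train in (True, False):
  print("%s_%s.trees" % (parsing_path, "train" if train else "dev"))
# wsj_train.trees
# wsj_dev.trees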

Diff for: tensor2tensor/data_generators/wsj_parsing.py (9 additions, 4 deletions)

@@ -23,6 +23,12 @@
 import tensorflow as tf
 
 
+tf.flags.DEFINE_string("parsing_path", "", "Path to parsing files in tmp_dir.")
+
+
+FLAGS = tf.flags.FLAGS
+
+
 def words_and_tags_from_wsj_tree(tree_string):
   """Generates linearized trees and tokens from the wsj tree format.
 

@@ -84,9 +90,8 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
                             target_vocab_size):
   """Generator for parsing as a sequence-to-sequence task that uses tokens.
 
-  This generator assumes the files parsing_{train,dev}.wsj, which contain trees
-  in wsj format and wsj_{source,target}.tokens.vocab.<vocab_size> exist in
-  tmp_dir.
+  This generator assumes the files parsing_{train,dev}.trees, which contain
+  trees in wsj format.
 
   Args:
     tmp_dir: path to the file with source sentences.

@@ -103,7 +108,7 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
   target_symbolizer_vocab = generator_utils.get_or_generate_vocab(
       tmp_dir, "wsj_target.tokens.vocab.%d" % target_vocab_size,
       target_vocab_size)
-  filename = "parsing_%s.trees" % ("train" if train else "dev")
+  filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
   tree_filepath = os.path.join(tmp_dir, filename)
   return token_generator(tree_filepath, source_symbolizer_vocab,
                          target_symbolizer_vocab, 1)

Diff for: tensor2tensor/models/lstm.py (1 addition, 1 deletion)

@@ -268,7 +268,7 @@ def model_fn_body(self, features):
 def lstm_attention():
   """hparams for LSTM with attention."""
   hparams = common_hparams.basic_params1()
-  hparams.batch_size = 128
+  hparams.batch_size = 1024
   hparams.hidden_size = 128
   hparams.num_hidden_layers = 2

Diff for: tensor2tensor/models/transformer.py (2 additions, 2 deletions)

@@ -48,8 +48,8 @@ def model_fn_body(self, features):
     inputs = features.get("inputs")
     target_space = features.get("target_space_id")
 
-    inputs = tf.squeeze(inputs, 2)
-    targets = tf.squeeze(targets, 2)
+    inputs = common_layers.flatten4d3d(inputs)
+    targets = common_layers.flatten4d3d(targets)
 
     (encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
         inputs, target_space, hparams))
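
common_layers.flatten4d3d is more general than the tf.squeeze(..., 2) it replaces: instead of requiring the third dimension to be exactly 1, it merges dimensions 1 and 2, so image-shaped [batch, height, width, hidden] inputs also collapse to [batch, length, hidden]. A rough sketch of the reshape (an assumption about the helper's behavior, not the actual common_layers implementation, which also mixes static and dynamic shapes):

import tensorflow as tf

def flatten4d3d_sketch(x):
  # Merge the two middle dimensions of [batch, d1, d2, hidden] into one
  # "length" dimension, producing [batch, d1 * d2, hidden].
  shape = tf.shape(x)
  return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])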

Diff for: tensor2tensor/utils/t2t_model.py (2 additions)

@@ -124,6 +124,8 @@ def _create_modalities(self, problem_hparams, hparams):
     problem_hparams.input_modality = input_modality
 
     target_modality_spec = problem_hparams.target_modality
+    if isinstance(target_modality_spec, modality.Modality):
+      return
     if target_modality_name:
       _warn_changed_modality_type(target_modality_name, target_modality_spec[0],
                                   "target")

Diff for: tensor2tensor/utils/trainer_utils.py (4 additions)

@@ -69,6 +69,7 @@
 flags.DEFINE_integer("train_steps", 250000,
                      "The number of steps to run training for.")
 flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.")
+flags.DEFINE_bool("eval_print", False, "Print eval logits and predictions.")
 flags.DEFINE_integer("keep_checkpoint_max", 20,
                      "How many recent checkpoints to keep.")
 flags.DEFINE_bool("experimental_optimize_placement", False,

@@ -452,6 +453,9 @@ def nth_model(n):
     sharded_logits, total_loss = result_list[1:], result_list[0]
     if mode == tf.contrib.learn.ModeKeys.EVAL:
       logits = tf.concat(sharded_logits, 0)
+      if FLAGS.eval_print:
+        logits = tf.Print(logits, [features["inputs"], logits],
+                          "EVAL PRINT", summarize=10000)
       # For evaluation, return the logits layer as our predictions.
       run_info["predictions"] = logits
       train_op = None
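
The new --eval_print flag wires a tf.Print op into the eval graph, so the inputs and logits are dumped to the log whenever the logits tensor is evaluated. tf.Print is an identity op with a logging side effect; a minimal TF 1.x sketch of the same pattern:

import tensorflow as tf

logits = tf.constant([[0.1, 0.9], [0.8, 0.2]])
# Identity op that prints the listed tensors when it runs; summarize controls
# how many elements of each tensor are shown.
logits = tf.Print(logits, [logits], "EVAL PRINT ", summarize=10000)

with tf.Session() as sess:
  sess.run(logits)  # the values appear on stderr as a side effect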
