@@ -90,25 +90,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
90
90
"algorithmic_reverse_nlplike_decimal8K" : (
91
91
lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 100000 ,
92
92
10 , 1.300 ),
93
- lambda : algorithmic .reverse_generator_nlplike (8000 , 700 , 10000 ,
93
+ lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 10000 ,
94
94
10 , 1.300 )),
95
95
"algorithmic_reverse_nlplike_decimal32K" : (
96
96
lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 100000 ,
97
97
10 , 1.050 ),
98
- lambda : algorithmic .reverse_generator_nlplike (32000 , 700 , 10000 ,
98
+ lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 10000 ,
99
99
10 , 1.050 )),
100
100
"algorithmic_algebra_inverse" : (
101
101
lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
102
102
lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
103
- "algorithmic_algebra_simplify" : (
104
- lambda : algorithmic_math .algebra_simplify (8 , 0 , 2 , 100000 ),
105
- lambda : algorithmic_math .algebra_simplify (8 , 3 , 3 , 10000 )),
106
- "algorithmic_calculus_integrate" : (
107
- lambda : algorithmic_math .calculus_integrate (8 , 0 , 2 , 100000 ),
108
- lambda : algorithmic_math .calculus_integrate (8 , 3 , 3 , 10000 )),
109
- "wmt_parsing_characters" : (
110
- lambda : wmt .parsing_character_generator (FLAGS .tmp_dir , True ),
111
- lambda : wmt .parsing_character_generator (FLAGS .tmp_dir , False )),
112
103
"wmt_parsing_tokens_8k" : (
113
104
lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
114
105
lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )),
@@ -133,10 +124,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
133
124
lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
134
125
lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
135
126
),
136
- "wmt_enfr_tokens_128k" : (
137
- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 17 ),
138
- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 17 )
139
- ),
140
127
"wmt_ende_characters" : (
141
128
lambda : wmt .ende_character_generator (FLAGS .tmp_dir , True ),
142
129
lambda : wmt .ende_character_generator (FLAGS .tmp_dir , False )),
@@ -151,10 +138,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
151
138
lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
152
139
lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
153
140
),
154
- "wmt_ende_tokens_128k" : (
155
- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 17 ),
156
- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 17 )
157
- ),
158
141
"image_mnist_tune" : (
159
142
lambda : image .mnist_generator (FLAGS .tmp_dir , True , 55000 ),
160
143
lambda : image .mnist_generator (FLAGS .tmp_dir , True , 5000 , 55000 )),
@@ -227,33 +210,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
227
210
40000 ,
228
211
vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
229
212
vocab_size = 2 ** 15 )),
230
- "image_mscoco_tokens_128k_tune" : (
231
- lambda : image .mscoco_generator (
232
- FLAGS .tmp_dir ,
233
- True ,
234
- 70000 ,
235
- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
236
- vocab_size = 2 ** 17 ),
237
- lambda : image .mscoco_generator (
238
- FLAGS .tmp_dir ,
239
- True ,
240
- 10000 ,
241
- 70000 ,
242
- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
243
- vocab_size = 2 ** 17 )),
244
- "image_mscoco_tokens_128k_test" : (
245
- lambda : image .mscoco_generator (
246
- FLAGS .tmp_dir ,
247
- True ,
248
- 80000 ,
249
- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
250
- vocab_size = 2 ** 17 ),
251
- lambda : image .mscoco_generator (
252
- FLAGS .tmp_dir ,
253
- False ,
254
- 40000 ,
255
- vocab_filename = "tokens.vocab.%d" % 2 ** 17 ,
256
- vocab_size = 2 ** 17 )),
257
213
"snli_32k" : (
258
214
lambda : snli .snli_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
259
215
lambda : snli .snli_token_generator (FLAGS .tmp_dir , False , 2 ** 15 ),
@@ -340,10 +296,31 @@ def set_random_seed():
340
296
341
297
def main (_ ):
342
298
tf .logging .set_verbosity (tf .logging .INFO )
343
- if FLAGS .problem not in _SUPPORTED_PROBLEM_GENERATORS :
299
+
300
+ # Calculate the list of problems to generate.
301
+ problems = list (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
302
+ if FLAGS .problem and FLAGS .problem [- 1 ] == "*" :
303
+ problems = [p for p in problems if p .startswith (FLAGS .problem [:- 1 ])]
304
+ elif FLAGS .problem :
305
+ problems = [p for p in problems if p == FLAGS .problem ]
306
+ else :
307
+ problems = []
308
+ # Remove TIMIT if paths are not given.
309
+ if not FLAGS .timit_paths :
310
+ problems = [p for p in problems if "timit" not in p ]
311
+ # Remove parsing if paths are not given.
312
+ if not FLAGS .parsing_path :
313
+ problems = [p for p in problems if "parsing" not in p ]
314
+ # Remove en-de BPE if paths are not given.
315
+ if not FLAGS .ende_bpe_path :
316
+ problems = [p for p in problems if "ende_bpe" not in p ]
317
+
318
+ if not problems :
344
319
problems_str = "\n * " .join (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
345
320
error_msg = ("You must specify one of the supported problems to "
346
321
"generate data for:\n * " + problems_str + "\n " )
322
+ error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
323
+ "--timit_paths, --ende_bpe_path and --parsing_path." )
347
324
raise ValueError (error_msg )
348
325
349
326
if not FLAGS .data_dir :
@@ -352,26 +329,28 @@ def main(_):
352
329
"Data will be written to default data_dir=%s." ,
353
330
FLAGS .data_dir )
354
331
355
- set_random_seed ()
332
+ tf .logging .info ("Generating problems:\n * %s\n " % "\n * " .join (problems ))
333
+ for problem in problems :
334
+ set_random_seed ()
356
335
357
- training_gen , dev_gen = _SUPPORTED_PROBLEM_GENERATORS [FLAGS . problem ]
336
+ training_gen , dev_gen = _SUPPORTED_PROBLEM_GENERATORS [problem ]
358
337
359
- tf .logging .info ("Generating training data for %s." , FLAGS . problem )
360
- train_output_files = generator_utils .generate_files (
361
- training_gen (), FLAGS . problem + UNSHUFFLED_SUFFIX + "-train" ,
362
- FLAGS .data_dir , FLAGS .num_shards , FLAGS .max_cases )
338
+ tf .logging .info ("Generating training data for %s." , problem )
339
+ train_output_files = generator_utils .generate_files (
340
+ training_gen (), problem + UNSHUFFLED_SUFFIX + "-train" ,
341
+ FLAGS .data_dir , FLAGS .num_shards , FLAGS .max_cases )
363
342
364
- tf .logging .info ("Generating development data for %s." , FLAGS . problem )
365
- dev_output_files = generator_utils .generate_files (
366
- dev_gen (), FLAGS . problem + UNSHUFFLED_SUFFIX + "-dev" , FLAGS .data_dir , 1 )
343
+ tf .logging .info ("Generating development data for %s." , problem )
344
+ dev_output_files = generator_utils .generate_files (
345
+ dev_gen (), problem + UNSHUFFLED_SUFFIX + "-dev" , FLAGS .data_dir , 1 )
367
346
368
- tf .logging .info ("Shuffling data..." )
369
- for fname in train_output_files + dev_output_files :
370
- records = generator_utils .read_records (fname )
371
- random .shuffle (records )
372
- out_fname = fname .replace (UNSHUFFLED_SUFFIX , "" )
373
- generator_utils .write_records (records , out_fname )
374
- tf .gfile .Remove (fname )
347
+ tf .logging .info ("Shuffling data..." )
348
+ for fname in train_output_files + dev_output_files :
349
+ records = generator_utils .read_records (fname )
350
+ random .shuffle (records )
351
+ out_fname = fname .replace (UNSHUFFLED_SUFFIX , "" )
352
+ generator_utils .write_records (records , out_fname )
353
+ tf .gfile .Remove (fname )
375
354
376
355
377
356
if __name__ == "__main__" :
0 commit comments