1
1
#!/usr/bin/env python
2
+ # coding=utf-8
2
3
# Copyright 2017 The Tensor2Tensor Authors.
3
4
#
4
5
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -62,10 +63,12 @@ flags.DEFINE_string("problem", "",
62
63
"The name of the problem to generate data for." )
63
64
flags .DEFINE_string ("exclude_problems" , "" ,
64
65
"Comma-separates list of problems to exclude." )
65
- flags .DEFINE_integer ("num_shards" , 10 , "How many shards to use." )
66
+ flags .DEFINE_integer ("num_shards" , 0 , "How many shards to use. Ignored for "
67
+ "registered Problems." )
66
68
flags .DEFINE_integer ("max_cases" , 0 ,
67
69
"Maximum number of cases to generate (unbounded if 0)." )
68
70
flags .DEFINE_integer ("random_seed" , 429459 , "Random seed to use." )
71
+ flags .DEFINE_integer ("task_id" , - 1 , "For distributed data generation." )
69
72
flags .DEFINE_string ("t2t_usr_dir" , "" ,
70
73
"Path to a Python module that will be imported. The "
71
74
"__init__.py file should include the necessary imports. "
@@ -108,6 +111,10 @@ _SUPPORTED_PROBLEM_GENERATORS = {
108
111
lambda : lm1b .generator (FLAGS .tmp_dir , True ),
109
112
lambda : lm1b .generator (FLAGS .tmp_dir , False )
110
113
),
114
+ "lm1b_characters" : (
115
+ lambda : lm1b .generator (FLAGS .tmp_dir , True , characters = True ),
116
+ lambda : lm1b .generator (FLAGS .tmp_dir , False , characters = True )
117
+ ),
111
118
"wiki_32k" : (
112
119
lambda : wiki .generator (FLAGS .tmp_dir , True ),
113
120
1000
@@ -246,7 +253,7 @@ def generate_data_for_problem(problem):
246
253
if isinstance (dev_gen , int ):
247
254
# The dev set and test sets are generated as extra shards using the
248
255
# training generator. The integer specifies the number of training
249
- # shards. FLAGS.num_shards is ignored.
256
+ # shards. FLAGS.num_shards is ignored.
250
257
num_training_shards = dev_gen
251
258
tf .logging .info ("Generating data for %s." , problem )
252
259
all_output_files = generator_utils .combined_data_filenames (
@@ -257,10 +264,11 @@ def generate_data_for_problem(problem):
257
264
else :
258
265
# usual case - train data and dev data are generated using separate
259
266
# generators.
267
+ num_shards = FLAGS .num_shards or 10
260
268
tf .logging .info ("Generating training data for %s." , problem )
261
269
train_output_files = generator_utils .train_data_filenames (
262
270
problem + generator_utils .UNSHUFFLED_SUFFIX , FLAGS .data_dir ,
263
- FLAGS . num_shards )
271
+ num_shards )
264
272
generator_utils .generate_files (training_gen (), train_output_files ,
265
273
FLAGS .max_cases )
266
274
tf .logging .info ("Generating development data for %s." , problem )
@@ -275,10 +283,14 @@ def generate_data_for_problem(problem):
275
283
276
284
277
285
def generate_data_for_registered_problem (problem_name ):
286
+ tf .logging .info ("Generating training data for %s." , problem_name )
287
+ if FLAGS .num_shards :
288
+ raise ValueError ("--num_shards should not be set for registered Problem." )
278
289
problem = registry .problem (problem_name )
290
+ task_id = None if FLAGS .task_id < 0 else FLAGS .task_id
279
291
problem .generate_data (os .path .expanduser (FLAGS .data_dir ),
280
292
os .path .expanduser (FLAGS .tmp_dir ),
281
- FLAGS . num_shards )
293
+ task_id = task_id )
282
294
283
295
284
296
if __name__ == "__main__" :
0 commit comments