Skip to content

Commit cc6d8de

Browse files
committed
fix random seed
1 parent 0b0ba5e commit cc6d8de

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

Diff for: efficientdet/dataloader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def __call__(self, params, input_context=None, batch_size=None):
414414
)
415415

416416
batch_size = batch_size or params['batch_size']
417-
seed = params['tf_random_seed'] if self._debug else None
417+
seed = params.get('tf_random_seed', None)
418418
dataset = tf.data.Dataset.list_files(
419419
self._file_pattern, shuffle=self._is_training, seed=seed)
420420
if input_context:

Diff for: efficientdet/tf2/train.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def define_flags():
101101
flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name.')
102102
flags.DEFINE_bool('debug', False, 'Enable debug mode')
103103
flags.DEFINE_integer(
104-
'tf_random_seed', 111111,
104+
'tf_random_seed', None,
105105
'Fixed random seed for deterministic execution across runs for debugging.'
106106
)
107107
flags.DEFINE_bool('profile', False, 'Enable profile mode')
@@ -163,10 +163,12 @@ def main(_):
163163
for gpu in tf.config.list_physical_devices('GPU'):
164164
tf.config.experimental.set_memory_growth(gpu, True)
165165

166+
if FLAGS.tf_random_seed:
167+
tf.random.set_seed(FLAGS.tf_random_seed)
168+
166169
if FLAGS.debug:
167170
tf.debugging.set_log_device_placement(True)
168171
os.environ['TF_DETERMINISTIC_OPS'] = '1'
169-
tf.random.set_seed(FLAGS.tf_random_seed)
170172
logging.set_verbosity(logging.DEBUG)
171173

172174
if FLAGS.strategy == 'tpu':

0 commit comments

Comments
 (0)