Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 02da1be

Browse files
author
Ryan Sepassi
committed
Add random seed, py3 fix, disable flaky test
PiperOrigin-RevId: 179942374
1 parent 2a07e8f commit 02da1be

File tree

7 files changed

+18
-5
lines changed

7 files changed

+18
-5
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ env:
1414
- T2T_DATA_DIR=/tmp/t2t-data
1515
- T2T_TRAIN_DIR=/tmp/t2t-train
1616
script:
17-
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
17+
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
1818
- pytest tensor2tensor/utils/registry_test.py
1919
- pytest tensor2tensor/tpu/tpu_trainer_lib_test.py
2020
- t2t-datagen 2>&1 | grep translate && echo passed

tensor2tensor/bin/t2t-trainer

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
4545
"The imported files should contain registrations, "
4646
"e.g. @registry.register_model calls, that will then be "
4747
"available to the t2t-trainer.")
48+
flags.DEFINE_integer("random_seed", 1234, "Random seed.")
4849
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
4950
flags.DEFINE_integer("iterations_per_loop", 1000,
5051
"Number of iterations in a TPU training loop.")
@@ -171,7 +172,7 @@ def execute_schedule(exp):
171172

172173
def main(_):
173174
tf.logging.set_verbosity(tf.logging.INFO)
174-
tf.set_random_seed(123)
175+
tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
175176
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
176177
log_registry()
177178

tensor2tensor/bin/t2t_trainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"The imported files should contain registrations, "
4545
"e.g. @registry.register_model calls, that will then be "
4646
"available to the t2t-trainer.")
47+
flags.DEFINE_integer("random_seed", 1234, "Random seed.")
4748
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
4849
flags.DEFINE_integer("iterations_per_loop", 1000,
4950
"Number of iterations in a TPU training loop.")
@@ -170,7 +171,7 @@ def execute_schedule(exp):
170171

171172
def main(_):
172173
tf.logging.set_verbosity(tf.logging.INFO)
173-
tf.set_random_seed(123)
174+
tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
174175
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
175176
log_registry()
176177

tensor2tensor/data_generators/algorithmic_math_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
"""Tests for tensor2tensor.data_generators.algorithmic_math."""
17+
# TODO(rsepassi): This test is flaky. Disable, remove, or update.
1718

1819
from __future__ import absolute_import
1920
from __future__ import division

tensor2tensor/data_generators/problem.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def standardize_shapes(features, batch_size=None):
947947

948948
def pad_batch(features, batch_multiple):
949949
"""Pad batch dim of features to nearest multiple of batch_multiple."""
950-
feature = features.items()[0][1]
950+
feature = list(features.items())[0][1]
951951
batch_size = tf.shape(feature)[0]
952952
mod = batch_size % batch_multiple
953953
has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32)

tensor2tensor/tpu/tpu_trainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"The imported files should contain registrations, "
4545
"e.g. @registry.register_model calls, that will then be "
4646
"available to the t2t-trainer.")
47+
flags.DEFINE_integer("random_seed", 1234, "Random seed.")
4748
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
4849
flags.DEFINE_integer("iterations_per_loop", 1000,
4950
"Number of iterations in a TPU training loop.")
@@ -170,7 +171,7 @@ def execute_schedule(exp):
170171

171172
def main(_):
172173
tf.logging.set_verbosity(tf.logging.INFO)
173-
tf.set_random_seed(123)
174+
tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
174175
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
175176
log_registry()
176177

tensor2tensor/tpu/tpu_trainer_lib.py

+9
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
from __future__ import print_function
2121

2222
import os
23+
import random
2324

2425
# Dependency imports
2526

27+
import numpy as np
28+
2629
from tensor2tensor.utils import devices
2730
from tensor2tensor.utils import expert_utils
2831
from tensor2tensor.utils import metrics_hook
@@ -336,3 +339,9 @@ def add_problem_hparams(hparams, problems):
336339

337340
hparams.problem_instances.append(problem)
338341
hparams.problems.append(p_hparams)
342+
343+
344+
def set_random_seed(seed):
345+
tf.set_random_seed(seed)
346+
random.seed(seed)
347+
np.random.seed(seed)

0 commit comments

Comments
 (0)