|
37 | 37 | # Dependency imports
|
38 | 38 |
|
39 | 39 | from tensor2tensor.bin import t2t_trainer
|
| 40 | +from tensor2tensor.data_generators import text_encoder |
40 | 41 | from tensor2tensor.utils import decoding
|
| 42 | +from tensor2tensor.utils import registry |
41 | 43 | from tensor2tensor.utils import trainer_lib
|
42 | 44 | from tensor2tensor.utils import usr_dir
|
43 | 45 |
|
|
flags.DEFINE_bool("decode_interactive", False,
                  "Interactive local inference mode.")
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
# When --score_file is non-empty, main() scores that file with score_file(),
# writes one score per line to --decode_to_file and returns without decoding.
# NOTE(review): "\t" below is a regular (non-raw) string escape, so the help
# text contains a literal tab character rather than the two characters "\t".
flags.DEFINE_string("score_file", "", "File to score. Each line in the file "
                    "must be in the format input \t target.")
62 | 66 |
|
63 | 67 |
|
64 | 68 | def create_hparams():
|
@@ -96,11 +100,80 @@ def decode(estimator, hparams, decode_hp):
|
96 | 100 | dataset_split="test" if FLAGS.eval_use_test_set else None)
|
97 | 101 |
|
98 | 102 |
|
def _split_score_line(line, line_number, has_inputs):
  """Split one score-file line into stripped (inputs, targets) strings.

  Args:
    line: raw line read from the score file.
    line_number: 1-based line number, used only in error messages.
    has_inputs: whether the problem has an "inputs" feature encoder.

  Returns:
    A pair (inputs, targets). inputs is None when the line has no tab
    separator, which is only permitted when has_inputs is False.

  Raises:
    ValueError: if the line has more than one tab separator, or if the
      problem has inputs but the line has no tab-separated input column.
  """
  tab_split = line.split("\t")
  if len(tab_split) > 2:
    raise ValueError("Each line must have at most one tab separator.")
  if len(tab_split) == 1:
    if has_inputs:
      # Without this check `inputs` would be unbound on the first line, or
      # silently carried over from the previous line afterwards.
      raise ValueError(
          "Line %d has no tab separator, but the problem has inputs; "
          "expected the format input \\t target." % line_number)
    return None, tab_split[0].strip()
  return tab_split[0].strip(), tab_split[1].strip()


def score_file(filename):
  """Score each line in a file and return the scores.

  Builds the model named by FLAGS.model in EVAL mode, restores the latest
  checkpoint recorded in FLAGS.output_dir, and evaluates the training loss
  on each line of the file, one example at a time.

  Args:
    filename: path of the file to score. Each line is either "target" or
      "input<TAB>target".

  Returns:
    A list with one entry per line of the file: the value of
    losses["training"] computed for that line.

  Raises:
    ValueError: if no checkpoint is found in FLAGS.output_dir, or if a line
      is malformed (see _split_score_line).
  """
  # Prepare model.
  hparams = create_hparams()
  encoders = registry.problem(FLAGS.problems).feature_encoders(FLAGS.data_dir)
  has_inputs = "inputs" in encoders

  # Prepare features for feeding into the model.
  targets_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
  features = {
      "targets": tf.reshape(targets_ph, [1, -1, 1, 1]),  # Make it 4D.
  }
  if has_inputs:
    inputs_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
    features["inputs"] = tf.reshape(inputs_ph, [1, -1, 1, 1])  # Make it 4D.

  # Prepare the model and the graph when model runs on features.
  model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL)
  _, losses = model(features)
  saver = tf.train.Saver()

  results = []
  with tf.Session() as sess:
    # Load weights from checkpoint.
    ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
    if ckpts is None or not ckpts.model_checkpoint_path:
      # get_checkpoint_state returns None when there is no checkpoint state;
      # fail with a clear message instead of an AttributeError.
      raise ValueError(
          "No checkpoint found in output_dir: %s" % FLAGS.output_dir)
    saver.restore(sess, ckpts.model_checkpoint_path)

    # Run on each line; the with-block guarantees the file gets closed.
    with open(filename) as score_lines:
      for line_number, line in enumerate(score_lines, 1):
        inputs, targets = _split_score_line(line, line_number, has_inputs)
        # Run encoders and append EOS symbol.
        feed = {
            targets_ph: encoders["targets"].encode(targets) +
                        [text_encoder.EOS_ID],
        }
        if has_inputs:
          feed[inputs_ph] = (
              encoders["inputs"].encode(inputs) + [text_encoder.EOS_ID])
        # Get the score.
        results.append(sess.run(losses["training"], feed))
  return results
| 156 | + |
| 157 | + |
99 | 158 | def main(_):
|
100 | 159 | tf.logging.set_verbosity(tf.logging.INFO)
|
| 160 | + trainer_lib.set_random_seed(FLAGS.random_seed) |
101 | 161 | usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
|
102 | 162 | FLAGS.use_tpu = False # decoding not supported on TPU
|
103 | 163 |
|
| 164 | + if FLAGS.score_file: |
| 165 | + filename = os.path.expanduser(FLAGS.score_file) |
| 166 | + if not tf.gfile.Exists(filename): |
| 167 | + raise ValueError("The file to score doesn't exist: %s" % filename) |
| 168 | + results = score_file(filename) |
| 169 | + if not FLAGS.decode_to_file: |
| 170 | + raise ValueError("To score a file, specify --decode_to_file for results.") |
| 171 | + write_file = open(os.path.expanduser(FLAGS.decode_to_file), "w") |
| 172 | + for score in results: |
| 173 | + write_file.write("%.6f\n" % score) |
| 174 | + write_file.close() |
| 175 | + return |
| 176 | + |
104 | 177 | hp = create_hparams()
|
105 | 178 | decode_hp = create_decode_hparams()
|
106 | 179 |
|
|
0 commit comments