
Commit 7087807

Merge pull request #73 from rsepassi/push
v1.0.9
2 parents a2a6178 + e4fe66c commit 7087807

26 files changed: +395, -245 lines

.gitignore (-2)

@@ -1,7 +1,5 @@
 # Compiled python modules.
 *.pyc
-# Byte-compiled
-__pycache__/
 
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info

README.md (+2, -2)

@@ -57,7 +57,7 @@ t2t-trainer --registry_help
 
 PROBLEM=wmt_ende_tokens_32k
 MODEL=transformer
-HPARAMS=transformer_base
+HPARAMS=transformer_base_single_gpu
 
 DATA_DIR=$HOME/t2t_data
 TMP_DIR=/tmp/t2t_datagen
@@ -209,7 +209,7 @@ and hyperparameter set functions can compose other hyperparameter set functions.
 The **trainer** binary is the main entrypoint for training, evaluation, and
 inference. Users can easily switch between problems, models, and hyperparameter
 sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific
-hyperparameters can be overriden with the `--hparams` flag. `--schedule` and
+hyperparameters can be overridden with the `--hparams` flag. `--schedule` and
 related flags control local and distributed training/evaluation
 ([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md)).

setup.py (+1, -1)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.8',
+    version='1.0.9',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/bin/make_tf_configs.py (+2, -3)

@@ -32,7 +32,6 @@
 
 # Dependency imports
 
-import six
 import tensorflow as tf
 
 flags = tf.flags
@@ -51,7 +50,7 @@ def main(_):
 
   cluster = {"ps": ps, "worker": workers}
 
-  for task_type, jobs in six.iteritems(cluster):
+  for task_type, jobs in (("worker", workers), ("ps", ps)):
     for idx, job in enumerate(jobs):
       if task_type == "worker":
         cmd_line_flags = " ".join([
@@ -77,7 +76,7 @@ def main(_):
              "index": idx
          }
      })
-      print(tf_config + "\t" + cmd_line_flags)
+      print("'%s'\t%s" % (tf_config, cmd_line_flags))
 
 
 if __name__ == "__main__":
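
For context, a minimal, self-contained sketch (illustrative host names and placeholder flags, not the project's actual script) of what the rewritten loop does: iterating an explicit tuple instead of `six.iteritems(cluster)` drops the `six` dependency and fixes the order in which worker and ps entries are emitted, and the new `print` wraps the JSON in single quotes so it can be pasted into a shell as `TF_CONFIG='...'`.

    import json

    # Illustrative cluster layout; the real script builds these from flags.
    workers = ["host1:2222", "host2:2222"]
    ps = ["host3:2222"]
    cluster = {"ps": ps, "worker": workers}

    # Fixed iteration order (workers, then parameter servers) instead of
    # relying on dict iteration order.
    for task_type, jobs in (("worker", workers), ("ps", ps)):
      for idx, _ in enumerate(jobs):
        tf_config = json.dumps({
            "cluster": cluster,
            "task": {"type": task_type, "index": idx},
        })
        cmd_line_flags = "<flags for %s %d>" % (task_type, idx)  # placeholder
        # Single quotes let the JSON be copied verbatim into a shell.
        print("'%s'\t%s" % (tf_config, cmd_line_flags))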

tensor2tensor/bin/t2t-datagen (mode 100755 → 100644, +3, -3)

@@ -37,10 +37,10 @@ from tensor2tensor.data_generators import algorithmic_math
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image
+from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
-from tensor2tensor.data_generators import ptb
 
 import tensorflow as tf
 
@@ -319,11 +319,11 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         vocab_filename="tokens.vocab.%d" % 2**15,
         vocab_size=2**15)),
     "lmptb_10k": (
-        lambda: ptb.train_generator(
+        lambda: ptb.train_generator(
            FLAGS.tmp_dir,
            FLAGS.data_dir,
            False),
-        lambda: ptb.valid_generator()),
+        ptb.valid_generator),
 }
 
 # pylint: enable=g-long-lambda
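
The registry change above relies on each problem mapping to a pair of zero-argument callables (a training generator and a dev generator). A small sketch of that convention, with hypothetical generator functions standing in for the real `ptb` module: since `ptb.valid_generator` already takes no arguments, it can be registered directly rather than wrapped in a redundant `lambda`.

    def make_train_examples(tmp_dir, data_dir, char_level):
      # Hypothetical stand-in for ptb.train_generator(tmp_dir, data_dir, False).
      yield {"inputs": [2, 3, 4], "targets": [3, 4, 5]}


    def make_valid_examples():
      # Hypothetical stand-in for ptb.valid_generator: already zero-argument.
      yield {"inputs": [6, 7], "targets": [7, 8]}


    # Each entry is (train_generator_fn, dev_generator_fn), both callable with
    # no arguments, mirroring _SUPPORTED_PROBLEM_GENERATORS above.
    SUPPORTED_PROBLEM_GENERATORS = {
        "lmptb_10k": (
            lambda: make_train_examples("/tmp/t2t_datagen", "/tmp/t2t_data", False),
            make_valid_examples,  # no lambda wrapper needed
        ),
    }

    train_fn, dev_fn = SUPPORTED_PROBLEM_GENERATORS["lmptb_10k"]
    print(next(train_fn()))
    print(next(dev_fn()))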

tensor2tensor/bin/t2t-trainer (mode 100755 → 100644)

File mode changed.

tensor2tensor/data_generators/algorithmic.py (+5, -4)

@@ -102,7 +102,7 @@ def zipf_distribution(nbr_symbols, alpha):
       Usually for modelling natural text distribution is in
       the range [1.1-1.6].
 
-  Return:
+  Returns:
     distr_map: list of float, Zipf's distribution over nbr_symbols.
 
   """
@@ -118,7 +118,7 @@ def zipf_random_sample(distr_map, sample_len):
     distr_map: list of float, Zipf's distribution over nbr_symbols.
     sample_len: integer, length of sequence to generate.
 
-  Return:
+  Returns:
     sample: list of integer, Zipf's random sample over nbr_symbols.
 
   """
@@ -131,8 +131,8 @@ def zipf_random_sample(distr_map, sample_len):
   return [t+1 if t > 0 else t+2 for t in np.searchsorted(distr_map, u)]
 
 
-def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, \
-                              scale_std_dev=100, alpha=1.5):
+def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases,
+                              scale_std_dev=100, alpha=1.5):
   """Generator for the reversing nlp-like task on sequences of symbols.
 
   The length of the sequence is drawn from a Gaussian(Normal) distribution
@@ -141,6 +141,7 @@ def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, \
   nbr_cases sequences have been produced.
 
   Args:
+    nbr_symbols: integer, number of symbols.
     max_length: integer, maximum length of sequences to generate.
     nbr_cases: the number of cases to generate.
     scale_std_dev: float, Normal distribution's standard deviation scale factor

tensor2tensor/data_generators/algorithmic_test.py (+5, -6)

@@ -41,14 +41,13 @@ def testReverseGenerator(self):
       self.assertEqual(list(reversed(d["inputs"])) + [1], d["targets"])
     self.assertEqual(counter, 10)
 
-  def testZipfDistribution(self):
-    # Following Zipf's Law with alpha equals 1: the first in rank is two times
-    # more probable/frequent that the second in rank, three times more prob/freq
-    # that the third in rank and so on.
+  def testZipfDistribution(self):
+    # Following Zipf's Law with alpha equals 1: the first in rank is two times
+    # more probable/frequent that the second in rank, three times more prob/freq
+    # that the third in rank and so on.
     d = algorithmic.zipf_distribution(10, 1.0001)
     for i in xrange(len(d[1:])-1):
-      self.assertEqual("%.4f" % (abs(d[i+1]-d[i+2])*(i+2)), \
-                       "%.4f" % d[1])
+      self.assertEqual("%.4f" % (abs(d[i+1]-d[i+2])*(i+2)), "%.4f" % d[1])
 
   def testReverseGeneratorNlpLike(self):
     counter = 0
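
The comment kept in the test states the property being checked. As a quick worked example, independent of the library's own `zipf_distribution` helper: with alpha close to 1, the unnormalized Zipf weight of rank k is 1/k, so rank 1 is about twice as probable as rank 2 and three times as probable as rank 3.

    import numpy as np

    alpha = 1.0001
    ranks = np.arange(1, 11)
    weights = 1.0 / np.power(ranks, alpha)  # Zipf weights ~ 1 / k**alpha
    probs = weights / weights.sum()

    # With alpha ~= 1, p(rank 1) / p(rank k) ~= k.
    print(probs[0] / probs[1])  # ~2.0
    print(probs[0] / probs[2])  # ~3.0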

tensor2tensor/data_generators/generator_utils.py (mode 100755 → 100644, +2, -1)

@@ -244,7 +244,8 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
       if ".gz" in lang_file:
         new_filepath = os.path.join(tmp_dir, lang_file[:-3])
         if os.path.exists(new_filepath):
-          tf.logging.info("Subdirectory %s already exists, skipping unpacking" % filepath)
+          tf.logging.info("Subdirectory %s already exists, skipping unpacking"
+                          % filepath)
         else:
           tf.logging.info("Unpacking subdirectory %s" % filepath)
           gunzip_file(filepath, new_filepath)

tensor2tensor/data_generators/problem_hparams.py (+16, -18)

@@ -340,24 +340,6 @@ def lm1b_16k(model_hparams):
   p.target_space_id = 3
   return p
 
-def lmptb_10k(model_hparams):
-  """Penn Tree Bank language-modeling benchmark, 10k token vocabulary."""
-  p = default_problem_hparams()
-  p.input_modality = {}
-  p.target_modality = (registry.Modalities.SYMBOL, 10000)
-
-  vocabulary = text_encoder.TokenTextEncoder(
-      os.path.join(model_hparams.data_dir,
-                   "lmptb_10k.vocab"))
-
-  p.vocabulary = {
-    "inputs": vocabulary,
-    "targets": vocabulary,
-  }
-
-  p.input_space_id = 3
-  p.target_space_id = 3
-  return p
 
 def lm1b_64k(model_hparams):
   """Billion-word language-modeling benchmark, 64k subtoken vocabulary."""
@@ -374,6 +356,22 @@ def lm1b_64k(model_hparams):
   p.target_space_id = 3
   return p
 
+
+def lmptb_10k(model_hparams):
+  """Penn Tree Bank language-modeling benchmark, 10k token vocabulary."""
+  p = default_problem_hparams()
+  p.input_modality = {}
+  p.target_modality = (registry.Modalities.SYMBOL, 10000)
+  vocabulary = text_encoder.TokenTextEncoder(
+      os.path.join(model_hparams.data_dir, "lmptb_10k.vocab"))
+  p.vocabulary = {
+      "targets": vocabulary,
+  }
+  p.input_space_id = 3
+  p.target_space_id = 3
+  return p
+
+
 def wmt_enfr_characters(unused_model_hparams):
   """English to French translation benchmark."""
   p = default_problem_hparams()
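
Besides being moved after `lm1b_64k`, the rewritten `lmptb_10k` drops the `"inputs"` entry from `p.vocabulary`, matching its empty `input_modality`: a pure language-modeling problem has targets only. A rough sketch of the shape of such an hparams object (attribute names follow the diff above; the container class and the `"symbol"` modality tag here are only illustrative stand-ins):

    class ProblemHParams(object):
      """Illustrative stand-in for the object built by default_problem_hparams()."""

      def __init__(self):
        self.input_modality = {}
        self.target_modality = None
        self.vocabulary = {}
        self.input_space_id = 0
        self.target_space_id = 0


    def lmptb_10k_sketch(vocab):
      p = ProblemHParams()
      p.input_modality = {}                  # language modeling: no inputs
      p.target_modality = ("symbol", 10000)  # 10k-token target vocabulary
      p.vocabulary = {"targets": vocab}      # only a target-side vocabulary
      p.input_space_id = 3
      p.target_space_id = 3
      return p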

tensor2tensor/data_generators/ptb.py (+33, -42)

@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import os
 import sys
 import tarfile
-import collections
 
 # Dependency imports
 
@@ -34,68 +34,62 @@
 EOS = text_encoder.EOS
 PTB_URL = "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
 
+
 def _read_words(filename):
-  """Reads words from a file.
-  It returns a list of words without '\n'
-  Originally from:
-  https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/reader.py
-  """
+  """Reads words from a file."""
   with tf.gfile.GFile(filename, "r") as f:
     if sys.version_info[0] >= 3:
       return f.read().replace("\n", " ").split()
     else:
       return f.read().decode("utf-8").replace("\n", " ").split()
-
-
+
 
 def _build_vocab(filename, vocab_path, vocab_size):
-  """Reads a file a build a vocabulary of `vocab_size` words to
-  as a list of words to `filename`
-  The vocabulary is sorted by occurence count and has one word per line
-  Originally from:
-  https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/reader.py
+  """Reads a file to build a vocabulary of `vocab_size` most common words.
+
+  The vocabulary is sorted by occurence count and has one word per line.
+  Originally from:
+  https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/reader.py
+
+  Args:
+    filename: file to read list of words from.
+    vocab_path: path where to save the vocabulary.
+    vocab_size: size of the vocablulary to generate.
   """
   data = _read_words(filename)
-
   counter = collections.Counter(data)
   count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
-  words, _ = list(zip(*count_pairs))
+  words, _ = list(zip(*count_pairs))
   words = words[:vocab_size]
-
-  with open(vocab_path, 'w') as f:
+  with open(vocab_path, "w") as f:
     f.write("\n".join(words))
 
+
 def _get_token_encoder(vocab_dir, filename):
-  """Reads from file and returns a `TokenTextEncoder` based on the vocabulary
-  """
+  """Reads from file and returns a `TokenTextEncoder` for the vocabulary."""
   vocab_name = "lmptb_10k.vocab"
   vocab_path = os.path.join(vocab_dir, vocab_name)
-
-
   _build_vocab(filename, vocab_path, 10000)
-
   return text_encoder.TokenTextEncoder(vocab_path)
-
+
 
 class PTB(object):
+  """A class for generating PTB data."""
+
   def __init__(self, tmp_dir, data_dir, char=False):
     assert not char, "char mode for PTB is not yet implemented"
     self.char = char
     self.data_dir = data_dir
-    #self.num_steps = num_steps
 
     url = PTB_URL
-
     filename = os.path.basename(url)
-    compressed_filepath = generator_utils.maybe_download(tmp_dir,
-                                                         filename,
-                                                         url)
-
+    compressed_filepath = generator_utils.maybe_download(
+        tmp_dir, filename, url)
     ptb_files = []
     ptb_char_files = []
     with tarfile.open(compressed_filepath, "r:gz") as tgz:
       files = []
-      # selecting only relevant files
+      # Selecting only relevant files.
       for m in tgz.getmembers():
         if "ptb" in m.name and ".txt" in m.name:
           if "char" in m.name:
@@ -120,7 +114,6 @@ def __init__(self, tmp_dir, data_dir, char=False):
 
     assert hasattr(self, "train"), "Training file not found"
     assert hasattr(self, "valid"), "Validation file not found"
-
     self.encoder = _get_token_encoder(data_dir, self.train)
 
   def train_generator(self):
@@ -132,27 +125,25 @@ def valid_generator(self):
   def _generator(self, filename):
     with tf.gfile.GFile(filename, "r") as f:
       for line in f:
-        line = " ".join(line.replace('\n', EOS).split())
+        line = " ".join(line.replace("\n", EOS).split())
         tok = self.encoder.encode(line)
-        x = tok[:-1]
-        y = tok[1:]
-
-        yield {"inputs": x,
-               "targets": y}
+        yield {"inputs": tok[:-1], "targets": tok[1:]}
+
 
 # Using a object "singleton"
 # `train_generator` must be called before
 # `valid_generator` in order to work
 _ptb = {}
+
+
 def train_generator(*args, **kwargs):
-  """The train data generator to be called
-  """
+  """The train data generator to be called."""
   global _ptb
   _ptb = PTB(*args, **kwargs)
   return _ptb.train_generator()
 
+
 def valid_generator():
-  """Validation (aka. dev) data generator
-  """
-  global _ptb
+  """Validation (aka. dev) data generator."""
+  global _ptb  # pylint:disable=global-variable-not-assigned
   return _ptb.valid_generator()
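
The module-level comments retained in the diff describe an implicit protocol: `train_generator(...)` builds the shared `PTB` object (downloading the dataset and building the vocabulary) as a side effect, and `valid_generator()` reuses that object, so it only works if the train generator has been requested first. A usage sketch under those assumptions (paths are illustrative, and running it downloads the PTB tarball):

    from tensor2tensor.data_generators import ptb

    tmp_dir = "/tmp/t2t_datagen"  # where the PTB tarball is downloaded
    data_dir = "/tmp/t2t_data"    # where the vocab file is written

    # 1) train_generator() constructs the shared PTB object as a side effect.
    train_examples = ptb.train_generator(tmp_dir, data_dir, False)

    # 2) Only after that can valid_generator() reuse the same object.
    valid_examples = ptb.valid_generator()

    print(next(train_examples))
    print(next(valid_examples))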

tensor2tensor/data_generators/snli.py (mode 100755 → 100644, +1)

@@ -130,6 +130,7 @@ def _parse_dataset(file_path, tmp_dir, train):
 
 
 def _get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
+  """Read or create vocabulary."""
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   print('Vocab file written to: ' + vocab_filepath)
 
