This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 097ea5f

Merge pull request #396 from rsepassi/push

v1.2.7

2 parents 9e7d03f + f564d6c

24 files changed: +1430 −469 lines

setup.py (+1 −1)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.2.6',
+    version='1.2.7',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/data_generators/generator_utils.py (+8 −5)

@@ -21,9 +21,9 @@
 
 from collections import defaultdict
 import gzip
-import io
 import os
 import random
+import stat
 import tarfile
 
 # Dependency imports
@@ -190,8 +190,8 @@ def maybe_download(directory, filename, url):
     print()
     tf.gfile.Rename(inprogress_filepath, filepath)
     statinfo = os.stat(filepath)
-    tf.logging.info("Successfully downloaded %s, %s bytes." % (filename,
-                                                               statinfo.st_size))
+    tf.logging.info("Successfully downloaded %s, %s bytes." %
+                    (filename, statinfo.st_size))
   else:
     tf.logging.info("Not downloading, file already found: %s" % filepath)
   return filepath
@@ -243,7 +243,7 @@ def maybe_download_from_drive(directory, filename, url):
   print()
   statinfo = os.stat(filepath)
   tf.logging.info("Successfully downloaded %s, %s bytes." % (filename,
-                                                              statinfo.st_size))
+                                                             statinfo.st_size))
   return filepath
 
 
@@ -258,8 +258,11 @@ def gunzip_file(gz_path, new_path):
     tf.logging.info("File %s already exists, skipping unpacking" % new_path)
     return
   tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
+  # We may be unpacking into a newly created directory, add write mode.
+  mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH
+  os.chmod(os.path.dirname(new_path), mode)
   with gzip.open(gz_path, "rb") as gz_file:
-    with io.open(new_path, "wb") as new_file:
+    with tf.gfile.GFile(new_path, mode="wb") as new_file:
      for line in gz_file:
        new_file.write(line)
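Note one caveat in the new gunzip_file hunk: Python's boolean `or` returns its first truthy operand, so the committed mode expression evaluates to just stat.S_IRWXU (0o700) and silently drops the group/other bits; combining permission flags requires the bitwise `|` operator. A minimal sketch of the difference (owner access is unaffected either way, so the unpack itself still succeeds):

import stat

# Boolean `or` short-circuits on the first truthy operand, so only the
# owner bits survive.
boolean_mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH
assert boolean_mode == stat.S_IRWXU  # 0o700: rwx------

# Bitwise OR actually combines all four permission flags.
bitwise_mode = stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH
assert bitwise_mode == 0o754  # rwxr-xr--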

tensor2tensor/data_generators/image.py (+56)

@@ -24,6 +24,7 @@
 import json
 import os
 import random
+import struct
 import tarfile
 import zipfile
 
@@ -925,3 +926,58 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
   @property
   def targeted_vocab_size(self):
     return 2**15  # 32768
+
+
+@registry.register_problem
+class OcrTest(Image2TextProblem):
+  """OCR test problem."""
+
+  @property
+  def is_small(self):
+    return True
+
+  @property
+  def is_character_level(self):
+    return True
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_CHR
+
+  @property
+  def train_shards(self):
+    return 1
+
+  @property
+  def dev_shards(self):
+    return 1
+
+  def preprocess_example(self, example, mode, _):
+    # Resize from usual size ~1350x60 to 90x4 in this test.
+    img = example["inputs"]
+    example["inputs"] = tf.to_int64(
+        tf.image.resize_images(img, [90, 4], tf.image.ResizeMethod.AREA))
+    return example
+
+  def generator(self, data_dir, tmp_dir, is_training):
+    # In this test problem, we assume that the data is in tmp_dir/ocr/ in
+    # files named 0.png, 0.txt, 1.png, 1.txt and so on until num_examples.
+    num_examples = 2
+    ocr_dir = os.path.join(tmp_dir, "ocr/")
+    tf.logging.info("Looking for OCR data in %s." % ocr_dir)
+    for i in xrange(num_examples):
+      image_filepath = os.path.join(ocr_dir, "%d.png" % i)
+      text_filepath = os.path.join(ocr_dir, "%d.txt" % i)
+      with tf.gfile.Open(text_filepath, "rb") as f:
+        label = f.read()
+      with tf.gfile.Open(image_filepath, "rb") as f:
+        encoded_image_data = f.read()
+      # In PNG files width and height are stored in these bytes.
+      width, height = struct.unpack(">ii", encoded_image_data[16:24])
+      yield {
+          "image/encoded": [encoded_image_data],
+          "image/format": ["png"],
+          "image/class/label": label.strip(),
+          "image/height": [height],
+          "image/width": [width]
+      }
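For readers wondering why bytes 16:24 hold the dimensions: a PNG file opens with an 8-byte signature followed by the IHDR chunk, whose 4-byte length and 4-byte type fields precede the chunk data, and that data begins with width and height as big-endian 32-bit integers. A self-contained sketch of the same parse; it uses the unsigned format ">II", while the committed ">ii" reads signed values, which behaves identically for realistic image sizes:

import struct

def png_dimensions(encoded_png):
  """Return (width, height) parsed from a PNG byte string."""
  # 8-byte signature, then IHDR: 4-byte length + 4-byte type, so the
  # chunk data (width, height, ...) starts at byte offset 16.
  assert encoded_png[:8] == b"\x89PNG\r\n\x1a\n", "not a PNG file"
  assert encoded_png[12:16] == b"IHDR", "IHDR must be the first chunk"
  return struct.unpack(">II", encoded_png[16:24])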

tensor2tensor/data_generators/translate_enfr.py (+72 −34)

@@ -34,50 +34,54 @@
 # End-of-sentence marker.
 EOS = text_encoder.EOS_ID
 
-_ENFR_TRAIN_DATASETS = [
+_ENFR_TRAIN_SMALL_DATA = [
     [
         "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
         ("baseline-1M-enfr/baseline-1M_train.en",
          "baseline-1M-enfr/baseline-1M_train.fr")
     ],
-    # [
-    #     "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
-    #     ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
-    # ],
-    # [
-    #     "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
-    #     ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
-    # ],
-    # [
-    #     "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
-    #     ("training/news-commentary-v9.fr-en.en",
-    #      "training/news-commentary-v9.fr-en.fr")
-    # ],
-    # [
-    #     "http://www.statmt.org/wmt10/training-giga-fren.tar",
-    #     ("giga-fren.release2.fixed.en.gz",
-    #      "giga-fren.release2.fixed.fr.gz")
-    # ],
-    # [
-    #     "http://www.statmt.org/wmt13/training-parallel-un.tgz",
-    #     ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
-    # ],
 ]
-_ENFR_TEST_DATASETS = [
+_ENFR_TEST_SMALL_DATA = [
     [
         "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
         ("baseline-1M-enfr/baseline-1M_valid.en",
          "baseline-1M-enfr/baseline-1M_valid.fr")
     ],
-    # [
-    #     "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-    #     ("dev/newstest2013.en", "dev/newstest2013.fr")
-    # ],
+]
+_ENFR_TRAIN_LARGE_DATA = [
+    [
+        "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
+        ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
+    ],
+    [
+        "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
+        ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
+    ],
+    [
+        "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
+        ("training/news-commentary-v9.fr-en.en",
+         "training/news-commentary-v9.fr-en.fr")
+    ],
+    [
+        "http://www.statmt.org/wmt10/training-giga-fren.tar",
+        ("giga-fren.release2.fixed.en.gz",
+         "giga-fren.release2.fixed.fr.gz")
+    ],
+    [
+        "http://www.statmt.org/wmt13/training-parallel-un.tgz",
+        ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
+    ],
+]
+_ENFR_TEST_LARGE_DATA = [
+    [
+        "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+        ("dev/newstest2013.en", "dev/newstest2013.fr")
+    ],
 ]
 
 
 @registry.register_problem
-class TranslateEnfrWmt8k(translate.TranslateProblem):
+class TranslateEnfrWmtSmall8k(translate.TranslateProblem):
   """Problem spec for WMT En-Fr translation."""
 
   @property
@@ -88,11 +92,18 @@ def targeted_vocab_size(self):
   def vocab_name(self):
     return "vocab.enfr"
 
+  @property
+  def use_small_dataset(self):
+    return True
+
   def generator(self, data_dir, tmp_dir, train):
     symbolizer_vocab = generator_utils.get_or_generate_vocab(
         data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
-        _ENFR_TRAIN_DATASETS)
-    datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
+        _ENFR_TRAIN_SMALL_DATA)
+    if self.use_small_dataset:
+      datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
+    else:
+      datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
     tag = "train" if train else "dev"
     data_path = translate.compile_data(tmp_dir, datasets,
                                        "wmt_enfr_tok_%s" % tag)
@@ -109,15 +120,31 @@ def target_space_id(self):
 
 
 @registry.register_problem
-class TranslateEnfrWmt32k(TranslateEnfrWmt8k):
+class TranslateEnfrWmtSmall32k(TranslateEnfrWmtSmall8k):
 
   @property
   def targeted_vocab_size(self):
     return 2**15  # 32768
 
 
 @registry.register_problem
-class TranslateEnfrWmtCharacters(translate.TranslateProblem):
+class TranslateEnfrWmt8k(TranslateEnfrWmtSmall8k):
+
+  @property
+  def use_small_dataset(self):
+    return False
+
+
+@registry.register_problem
+class TranslateEnfrWmt32k(TranslateEnfrWmtSmall32k):
+
+  @property
+  def use_small_dataset(self):
+    return False
+
+
+@registry.register_problem
+class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem):
   """Problem spec for WMT En-Fr translation."""
 
   @property
@@ -130,7 +157,10 @@ def vocab_name(self):
 
   def generator(self, data_dir, tmp_dir, train):
     character_vocab = text_encoder.ByteTextEncoder()
-    datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
+    if self.use_small_dataset:
+      datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
+    else:
+      datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
     tag = "train" if train else "dev"
     data_path = translate.compile_data(tmp_dir, datasets,
                                        "wmt_enfr_chr_%s" % tag)
@@ -144,3 +174,11 @@ def input_space_id(self):
   @property
   def target_space_id(self):
     return problem.SpaceID.FR_CHR
+
+
+@registry.register_problem
+class TranslateEnfrWmtCharacters(TranslateEnfrWmtSmallCharacters):
+
+  @property
+  def use_small_dataset(self):
+    return False
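The net effect of the renames is that each registered variant differs only in its use_small_dataset property; the generators pick the dataset lists at run time. A minimal sketch of how the variants would be looked up, assuming the usual t2t registry convention of snake_casing class names into problem names:

from tensor2tensor.utils import registry

# Hypothetical lookups; the names follow the registry's naming convention.
small = registry.problem("translate_enfr_wmt_small8k")  # baseline-1M data only
large = registry.problem("translate_enfr_wmt8k")        # full WMT corpora
assert small.use_small_dataset and not large.use_small_dataset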

tensor2tensor/data_generators/translate_enzh.py (+6 −2)

@@ -35,9 +35,13 @@
 
 # End-of-sentence marker.
 EOS = text_encoder.EOS_ID
+
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
+
 # This is far from being the real WMT17 task - only toyset here
-# you need to register to get UN data and CWT data
-# also by convention this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
+# you need to register to get UN data and CWT data. Also, by convention,
+# this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
 _ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
                           "training-parallel-nc-v12.tgz"),
                          ("training/news-commentary-v12.zh-en.en",

tensor2tensor/layers/common_attention.py (+7 −2)

@@ -2958,15 +2958,20 @@ def pad_and_reshape(x):
 
 @expert_utils.add_var_scope()
 def multihead_self_attention_reduced(
-    x, factor, nonlinearity, reduction_type, multihead_params):
+    x,
+    factor,
+    multihead_params,
+    nonlinearity="none",
+    reduction_type="conv",
+):
   """Reduce the length dimension by compressing with conv.
 
   Args:
     x (tf.Tensor): float32 of shape [batch, length, depth]
     factor (int): compression factor for the memory sequence
+    multihead_params (dict): parameters for multihead attention
     nonlinearity (str): Add some non-linearity after the memory block
     reduction_type (str): type of compression
-    multihead_params (dict): parameters for multihead attention
 
   Returns:
     (tf.Tensor): float32 of shape [batch, length, depth]
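Because the reorder makes multihead_params the third positional argument and gives the two string options defaults, call sites written against the old positional order would misbind; keyword arguments avoid that. A sketch of a call under the new signature; the multihead_params keys shown are the usual multihead-attention parameters and are assumptions here, not taken from this diff:

depth = 512  # hypothetical model depth; x is float32 [batch, length, depth]
y = multihead_self_attention_reduced(
    x,
    factor=4,                     # compress the memory sequence 4x
    multihead_params=dict(        # assumed multihead-attention parameters
        total_key_depth=depth,
        total_value_depth=depth,
        output_depth=depth,
        num_heads=8,
        dropout_rate=0.0,
    ),
    nonlinearity="none",          # default
    reduction_type="conv",        # default
)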

tensor2tensor/layers/common_hparams.py (+3)

@@ -116,12 +116,15 @@ def basic_params1():
       # If set to True, drop sequences longer than max_length during eval.
       # This affects the validity of the evaluation metrics.
       eval_drop_long_sequences=int(False),
+      # TODO(lukaszkaiser): these parameters should probably be set elsewhere.
       # in SymbolModality, share the output embeddings and the softmax
       # variables.
       # You can also share the input embeddings with the output embeddings
       # by using a problem_hparams that uses the same modality object for
       # the input_modality and target_modality.
       shared_embedding_and_softmax_weights=int(False),
+      # In SymbolModality, skip the top layer, assume we're providing logits.
+      symbol_modality_skip_top=int(False),
       # For each feature for which you want to override the default input
       # modality, add an entry to this semicolon-separated string. Entries are
       # formatted "feature_name:modality_type:modality_name", e.g.
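Since basic_params1 stores booleans as ints, the new flag is toggled the same way as its neighbors. A minimal sketch, assuming the standard basic_params1 entry point:

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
# Skip SymbolModality's top layer; the model body must then emit logits.
hparams.symbol_modality_skip_top = int(True)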
