Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8bdecbe

Browse files
authored
Merge pull request #646 from rsepassi/push
v1.5.5
2 parents af82068 + 688f4d5 commit 8bdecbe

32 files changed

+883
-280
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ matrix:
1616
- python: "3.6"
1717
env: TF_VERSION="1.4.*"
1818
- python: "3.6"
19-
env: TF_VERSION="1.6.*"
19+
env: TF_VERSION="1.5.*"
2020
before_install:
2121
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
2222
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

.github/ISSUE_TEMPLATE.md ISSUE_TEMPLATE.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
### *TensorFlow* and *tensor2tensor* versions
88

9-
<!-- **Note** Run `pip list | grep tensor` to include TensorFlow and tensor2tensor versions -->
9+
<!-- **Note** Run `pip freeze | grep tensor` to get versions -->
1010

1111
>
1212
@@ -16,7 +16,7 @@
1616
1717
### In case of bug report: Error log
1818

19-
<!-- Please use code markdown to format output messages. -->
19+
<!-- Please use code markdown (```) to format output messages. -->
2020
<!-- See https://help.github.com/articles/creating-and-highlighting-code-blocks/ -->
2121

2222
>

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
1515
of deep learning models and datasets designed to make deep learning more
1616
accessible and [accelerate ML
1717
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
18-
is actively used and maintained by researchers and engineers within the
18+
T2T is actively used and maintained by researchers and engineers within the
1919
[Google Brain team](https://research.google.com/teams/brain/) and a community
2020
of users. We're eager to collaborate with you too, so feel free to
2121
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)

docs/cloud_tpu.md

+11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ See the official tutorial for [running Transformer
1818
on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer)
1919
for some examples and try out your own problems.
2020

21+
Image Transformer:
22+
* `imagetransformer` with `imagetransformer_base_tpu` (or
23+
`imagetransformer_tiny_tpu`)
24+
* `img2img_transformer` with `img2img_transformer_base_tpu` (or
25+
`img2img_transformer_tiny_tpu`)
26+
27+
You can run the `ImageTransformer` model on problems like unconditional or
28+
conditional image generation, and the `Img2ImgTransformer` model on super-resolution.
29+
We run on datasets like CelebA, CIFAR and ImageNet but they should work with any
30+
other image dataset.
31+
2132
Residual networks:
2233
* `resnet` with `resnet_50` (or `resnet_18` or `resnet_34`)
2334
* `revnet` with `revnet_104` (or `revnet_38_cifar`)

docs/walkthrough.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
1515
of deep learning models and datasets designed to make deep learning more
1616
accessible and [accelerate ML
1717
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
18-
is actively used and maintained by researchers and engineers within the
18+
T2T is actively used and maintained by researchers and engineers within the
1919
[Google Brain team](https://research.google.com/teams/brain/) and a community
2020
of users. We're eager to collaborate with you too, so feel free to
2121
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -154,7 +154,7 @@ For all translation problems, we suggest to try the Transformer model:
154154
this should reach a BLEU score of about 28 on the English-German data-set,
155155
which is close to state-of-the-art. If training on a single GPU, try the
156156
`--hparams_set=transformer_base_single_gpu` setting. For very good results
157-
or larger data-sets (e.g., for English-French)m, try the big model
157+
or larger data-sets (e.g., for English-French), try the big model
158158
with `--hparams_set=transformer_big`.
159159

160160
## Basics

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.5.4',
8+
version='1.5.5',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',

tensor2tensor/data_generators/cifar.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def preprocess_example(self, example, mode, unused_hparams):
124124
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
125125
if mode == tf.estimator.ModeKeys.TRAIN:
126126
image = image_utils.cifar_image_augmentation(image)
127-
image = tf.image.per_image_standardization(image)
127+
if not self._was_reversed:
128+
image = tf.image.per_image_standardization(image)
128129
example["inputs"] = image
129130
return example
130131

@@ -151,7 +152,8 @@ class ImageCifar10Plain(ImageCifar10):
151152
def preprocess_example(self, example, mode, unused_hparams):
152153
image = example["inputs"]
153154
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
154-
image = tf.image.per_image_standardization(image)
155+
if not self._was_reversed:
156+
image = tf.image.per_image_standardization(image)
155157
example["inputs"] = image
156158
return example
157159

@@ -179,7 +181,8 @@ def dataset_filename(self):
179181
def preprocess_example(self, example, mode, unused_hparams):
180182
image = example["inputs"]
181183
image = image_utils.resize_by_area(image, 8)
182-
image = tf.image.per_image_standardization(image)
184+
if not self._was_reversed:
185+
image = tf.image.per_image_standardization(image)
183186
example["inputs"] = image
184187
return example
185188

@@ -192,7 +195,6 @@ def dataset_filename(self):
192195
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.
193196

194197
def preprocess_example(self, example, unused_mode, unused_hparams):
195-
196198
inputs = example["inputs"]
197199
# For Img2Img resize input and output images as desired.
198200
example["inputs"] = image_utils.resize_by_area(inputs, 8)
@@ -330,7 +332,8 @@ def preprocess_example(self, example, mode, unused_hparams):
330332
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
331333
if mode == tf.estimator.ModeKeys.TRAIN:
332334
image = image_utils.cifar_image_augmentation(image)
333-
image = tf.image.per_image_standardization(image)
335+
if not self._was_reversed:
336+
image = tf.image.per_image_standardization(image)
334337
example["inputs"] = image
335338
return example
336339

@@ -357,7 +360,8 @@ class ImageCifar100Plain(ImageCifar100):
357360
def preprocess_example(self, example, mode, unused_hparams):
358361
image = example["inputs"]
359362
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
360-
image = tf.image.per_image_standardization(image)
363+
if not self._was_reversed:
364+
image = tf.image.per_image_standardization(image)
361365
example["inputs"] = image
362366
return example
363367

@@ -385,7 +389,8 @@ def dataset_filename(self):
385389
def preprocess_example(self, example, mode, unused_hparams):
386390
image = example["inputs"]
387391
image = image_utils.resize_by_area(image, 8)
388-
image = tf.image.per_image_standardization(image)
392+
if not self._was_reversed:
393+
image = tf.image.per_image_standardization(image)
389394
example["inputs"] = image
390395
return example
391396

@@ -398,7 +403,6 @@ def dataset_filename(self):
398403
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
399404

400405
def preprocess_example(self, example, unused_mode, unused_hparams):
401-
402406
inputs = example["inputs"]
403407
# For Img2Img resize input and output images as desired.
404408
example["inputs"] = image_utils.resize_by_area(inputs, 8)

tensor2tensor/data_generators/gym.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import tensorflow as tf
3636

3737

38+
39+
3840
flags = tf.flags
3941
FLAGS = flags.FLAGS
4042

@@ -157,7 +159,6 @@ def num_steps(self):
157159
return 5000
158160

159161

160-
161162
@registry.register_problem
162163
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
163164
"""Pong game, loaded actions."""
@@ -197,7 +198,7 @@ def generator(self, data_dir, tmp_dir):
197198
model_saver.restore(sess, FLAGS.model_path)
198199
for item in super(GymPongTrajectoriesFromPolicy,
199200
self).generator(data_dir, tmp_dir):
200-
yield item
201+
yield item
201202

202203
# TODO(blazej0): For training of atari agents wrappers are usually used.
203204
# Below we have a hacky solution which is a workaround to be used together

tensor2tensor/data_generators/image_utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensor2tensor.data_generators import generator_utils
2727
from tensor2tensor.data_generators import problem
2828
from tensor2tensor.data_generators import text_encoder
29+
from tensor2tensor.utils import metrics
2930
from tensor2tensor.utils import registry
3031

3132
import tensorflow as tf
@@ -64,9 +65,19 @@ def example_reading_spec(self, label_repr=None):
6465
return data_fields, data_items_to_decoders
6566

6667
def preprocess_example(self, example, mode, hparams):
67-
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
68+
if not self._was_reversed:
69+
example["inputs"] = tf.image.per_image_standardization(example["inputs"])
6870
return example
6971

72+
def eval_metrics(self):
73+
eval_metrics = [
74+
metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5,
75+
metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
76+
]
77+
if self._was_reversed:
78+
eval_metrics += [metrics.Metrics.IMAGE_SUMMARY]
79+
return eval_metrics
80+
7081

7182
class Image2ClassProblem(ImageProblem):
7283
"""Base class for image classification problems."""

tensor2tensor/data_generators/imagenet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ def distorted_bounding_box_crop(image,
334334
Returns:
335335
(cropped image `Tensor`, distorted bbox `Tensor`).
336336
"""
337-
with tf.name_scope(scope, default_name="distorted_bounding_box_crop", values=[image, bbox]):
337+
with tf.name_scope(scope, default_name="distorted_bounding_box_crop",
338+
values=[image, bbox]):
338339
# Each bounding box has shape [1, num_boxes, box coords] and
339340
# the coordinates are ordered [ymin, xmin, ymax, xmax].
340341

tensor2tensor/data_generators/librispeech.py

+65-10
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"train-other-500"
4040
],
4141
]
42-
_LIBRISPEECH_TEST_DATASETS = [
42+
_LIBRISPEECH_DEV_DATASETS = [
4343
[
4444
"http://www.openslr.org/resources/12/dev-clean.tar.gz",
4545
"dev-clean"
@@ -49,6 +49,16 @@
4949
"dev-other"
5050
],
5151
]
52+
_LIBRISPEECH_TEST_DATASETS = [
53+
[
54+
"http://www.openslr.org/resources/12/test-clean.tar.gz",
55+
"test-clean"
56+
],
57+
[
58+
"http://www.openslr.org/resources/12/test-other.tar.gz",
59+
"test-other"
60+
],
61+
]
5262

5363

5464
def _collect_data(directory, input_ext, transcription_ext):
@@ -72,7 +82,7 @@ def _collect_data(directory, input_ext, transcription_ext):
7282
assert key not in data_files
7383
media_name = "%s.%s"%(media_base, input_ext)
7484
media_path = os.path.join(root, media_name)
75-
data_files[key] = (media_path, label)
85+
data_files[key] = (media_base, media_path, label)
7686
return data_files
7787

7888

@@ -82,7 +92,8 @@ class Librispeech(speech_recognition.SpeechRecognitionProblem):
8292

8393
# Use all of the datasets (both clean and other)
8494
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
85-
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS
95+
DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS
96+
TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS
8697

8798
@property
8899
def num_shards(self):
@@ -96,6 +107,10 @@ def use_subword_tokenizer(self):
96107
def num_dev_shards(self):
97108
return 1
98109

110+
@property
111+
def num_test_shards(self):
112+
return 1
113+
99114
@property
100115
def use_train_shards_for_dev(self):
101116
"""If true, we only generate training data and hold out shards for dev."""
@@ -127,20 +142,31 @@ def generator(self, data_dir, tmp_dir, datasets,
127142
audio_encoder = encoders["waveforms"]
128143
text_encoder = encoders["targets"]
129144

130-
for media_file, text_data in sorted(data_pairs)[start_from:]:
145+
for utt_id, media_file, text_data in sorted(data_pairs)[start_from:]:
131146
if how_many > 0 and i == how_many:
132147
return
133148
i += 1
149+
wav_data = audio_encoder.encode(media_file)
150+
spk_id, unused_book_id, _ = utt_id.split("-")
134151
yield {
135-
"waveforms": audio_encoder.encode(media_file),
136-
"targets": text_encoder.encode(text_data)
152+
"waveforms": wav_data,
153+
"waveform_lens": [len(wav_data)],
154+
"targets": text_encoder.encode(text_data),
155+
"raw_transcript": [text_data],
156+
"utt_id": [utt_id],
157+
"spk_id": [spk_id],
137158
}
138159

139160
def generate_data(self, data_dir, tmp_dir, task_id=-1):
140161
train_paths = self.training_filepaths(
141162
data_dir, self.num_shards, shuffled=False)
142163
dev_paths = self.dev_filepaths(
143164
data_dir, self.num_dev_shards, shuffled=False)
165+
test_paths = self.test_filepaths(
166+
data_dir, self.num_test_shards, shuffled=True)
167+
168+
generator_utils.generate_files(
169+
self.generator(data_dir, tmp_dir, self.TEST_DATASETS), test_paths)
144170

145171
if self.use_train_shards_for_dev:
146172
all_paths = train_paths + dev_paths
@@ -153,22 +179,51 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
153179
self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)
154180

155181

182+
@registry.register_problem()
183+
class LibrispeechTrainFullTestClean(Librispeech):
184+
"""Problem to train on full 960h, but evaluate on clean data only."""
185+
186+
def training_filepaths(self, data_dir, num_shards, shuffled):
187+
return Librispeech.training_filepaths(data_dir, num_shards, shuffled)
188+
189+
def dev_filepaths(self, data_dir, num_shards, shuffled):
190+
return LibrispeechClean.dev_filepaths(data_dir, num_shards, shuffled)
191+
192+
def test_filepaths(self, data_dir, num_shards, shuffled):
193+
return LibrispeechClean.test_filepaths(data_dir, num_shards, shuffled)
194+
195+
def generate_data(self, data_dir, tmp_dir, task_id=-1):
196+
raise Exception("Generate librispeech and librispeech_clean data.")
197+
198+
156199
@registry.register_problem()
157200
class LibrispeechCleanSmall(Librispeech):
158-
"""Problem spec for Librispeech using 100h clean train data."""
201+
"""Problem spec for Librispeech using 100h clean train and clean eval data."""
159202

160203
# Select only the clean data
161204
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
162-
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
205+
DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[:1]
206+
TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
163207

164208

165209
@registry.register_problem()
166210
class LibrispeechClean(Librispeech):
167-
"""Problem spec for Librispeech using 460h clean train data."""
211+
"""Problem spec for Librispeech using 460h clean train and clean eval data."""
168212

169213
# Select only the clean data
170214
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
171-
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
215+
DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[:1]
216+
TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
217+
218+
219+
@registry.register_problem()
220+
class LibrispeechNoisy(Librispeech):
221+
"""Problem spec for Librispeech using 400h noisy train and noisy eval data."""
222+
223+
# Select only the noisy ("other") data
224+
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[2:]
225+
DEV_DATASETS = _LIBRISPEECH_DEV_DATASETS[1:]
226+
TEST_DATASETS = _LIBRISPEECH_TEST_DATASETS[1:]
172227

173228

174229
# TODO(lukaszkaiser): clean up hparams or remove from here.

tensor2tensor/data_generators/mnist.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def train_shards(self):
162162
def preprocess_example(self, example, mode, unused_hparams):
163163
image = example["inputs"]
164164
image.set_shape([_MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1])
165-
image = tf.image.per_image_standardization(image)
165+
if not self._was_reversed:
166+
image = tf.image.per_image_standardization(image)
166167
example["inputs"] = image
167168
return example
168169

tensor2tensor/data_generators/ptb.py

+4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def _maybe_download_corpus(tmp_dir, vocab_type):
8282
8383
Args:
8484
tmp_dir: directory containing dataset.
85+
vocab_type: which vocabulary is in use.
86+
87+
Returns:
88+
The list of names of files.
8589
"""
8690
filename = os.path.basename(PTB_URL)
8791
compressed_filepath = generator_utils.maybe_download(

0 commit comments

Comments
 (0)