This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit fd9b315

Merge pull request #633 from rsepassi/push
v1.5.4
2 parents 11f1ae4 + 4dd189e commit fd9b315

30 files changed: +417 -184 lines changed

.travis.yml

+2-2
@@ -10,13 +10,13 @@ env:
   matrix:
     - TF_VERSION="1.4.*"
     - TF_VERSION="1.5.*"
-    - TF_VERSION="1.6.0rc1"
+    - TF_VERSION="1.6.*"
 matrix:
   exclude:
     - python: "3.6"
       env: TF_VERSION="1.4.*"
     - python: "3.6"
-      env: TF_VERSION="1.6.0rc1"
+      env: TF_VERSION="1.6.*"
 before_install:
   - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
   - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

README.md

+6-5
@@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
-
-T2T is actively used and maintained by researchers and engineers within the
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users. We're eager to collaborate with you too, so feel free to
 [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
 * [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797)
 * [Generating Wikipedia by Summarizing Long
 Sequences](https://arxiv.org/abs/1801.10198)
-* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-)
+* [Image Transformer](https://arxiv.org/abs/1802.05751)
+* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf)

 *Note: This is not an official Google product.*

docs/index.md

+3-2
@@ -11,8 +11,9 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).


 ## Basics

docs/new_problem.md

+4
@@ -9,6 +9,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
 [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
 [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

+Another good overview of this part together with training is given in
+[The Cloud ML Poetry Blog
+Post](https://cloud.google.com/blog/big-data/2018/02/cloud-poetry-training-and-hyperparameter-tuning-custom-text-models-on-cloud-ml-engine)
+
 Let's add a new dataset together and train the
 [Transformer](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/transformer.py)
 model on it. We'll give the model a line of poetry, and it will learn to

docs/walkthrough.md

+6-5
@@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

 [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
 [T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
-of deep learning models and datasets designed to [accelerate deep learning
-research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
-
-T2T is actively used and maintained by researchers and engineers within the
+of deep learning models and datasets designed to make deep learning more
+accessible and [accelerate ML
+research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
+T2T is actively used and maintained by researchers and engineers within the
 [Google Brain team](https://research.google.com/teams/brain/) and a community
 of users. We're eager to collaborate with you too, so feel free to
 [open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
@@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
 * [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797)
 * [Generating Wikipedia by Summarizing Long
 Sequences](https://arxiv.org/abs/1801.10198)
-* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-)
+* [Image Transformer](https://arxiv.org/abs/1802.05751)
+* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf)

 *Note: This is not an official Google product.*

setup.py

+1-1
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.5.3',
+    version='1.5.4',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/data_generators/generator_utils.py

+5-1
@@ -147,8 +147,9 @@ def generate_files(generator, output_filenames, max_cases=None):
   if outputs_exist(output_filenames):
     tf.logging.info("Skipping generator because outputs files exist")
     return
+  tmp_filenames = [fname + ".incomplete" for fname in output_filenames]
   num_shards = len(output_filenames)
-  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
+  writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filenames]
   counter, shard = 0, 0
   for case in generator:
     if case is None:
@@ -165,6 +166,9 @@ def generate_files(generator, output_filenames, max_cases=None):
   for writer in writers:
     writer.close()

+  for tmp_name, final_name in zip(tmp_filenames, output_filenames):
+    tf.gfile.Rename(tmp_name, final_name)
+
   tf.logging.info("Generated %s Examples", counter)

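
The generator_utils.py change above writes each shard to a `.incomplete` file and renames it only after every writer is closed, presumably so that an interrupted run cannot leave partial files that `outputs_exist` would later treat as finished. A minimal standalone sketch of the same pattern follows. It assumes plain local text files and `os.replace` in place of `tf.python_io.TFRecordWriter` and `tf.gfile.Rename`; `write_shards`, its arguments, and the round-robin sharding are hypothetical illustrations, not T2T code.

```python
import os
import tempfile


def write_shards(records, output_filenames):
    """Write records round-robin into shards, publishing files only when complete."""
    tmp_filenames = [fname + ".incomplete" for fname in output_filenames]
    writers = [open(fname, "w") for fname in tmp_filenames]
    try:
        for i, record in enumerate(records):
            writers[i % len(writers)].write(record + "\n")
    finally:
        for writer in writers:
            writer.close()
    # Rename only after every writer is closed, so a final name never
    # points at a partially written file. If writing raised, this loop
    # is not reached and only the ".incomplete" files are left behind.
    for tmp_name, final_name in zip(tmp_filenames, output_filenames):
        os.replace(tmp_name, final_name)


# Usage: two shards in a scratch directory (file names are illustrative).
out_dir = tempfile.mkdtemp()
shard_names = [os.path.join(out_dir, "data-%05d-of-00002" % i) for i in range(2)]
write_shards(["a", "b", "c"], shard_names)
```

A check like `outputs_exist` then only ever sees complete shards, because the final names appear at the very end of the run.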

tensor2tensor/data_generators/gym.py

+23-27
@@ -19,39 +19,30 @@
 from __future__ import division
 from __future__ import print_function

-import os
+import functools

 # Dependency imports

-import numpy as np
-import functools
 import gym
+import numpy as np

-from tensor2tensor.rl import rl_trainer_lib
-from tensor2tensor.rl.envs import atari_wrappers
-from tensor2tensor.models.research import rl
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
+from tensor2tensor.models.research import rl
+from tensor2tensor.rl.envs import atari_wrappers
 from tensor2tensor.utils import registry

 import tensorflow as tf


+
+
 flags = tf.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string("model_path", "", "File with model for pong")


-def gym_lib():
-  """Access to gym to allow for import of this file without a gym install."""
-  try:
-    import gym  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    raise ImportError("pip install gym to use gym-based Problems")
-  return gym
-
-
 class GymDiscreteProblem(problem.Problem):
   """Gym environment with discrete actions and rewards."""

@@ -67,7 +58,7 @@ def env_name(self):
   @property
   def env(self):
     if self._env is None:
-      self._env = gym_lib().make(self.env_name)
+      self._env = gym.make(self.env_name)
     return self._env

   @property
@@ -157,8 +148,6 @@ def num_steps(self):
     return 5000


-
-
 @registry.register_problem
 class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
   """Pong game, loaded actions."""
@@ -167,28 +156,34 @@ def __init__(self, event_dir, *args, **kwargs):
     super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
     self._env = None
     self._event_dir = event_dir
-    env_spec = lambda: atari_wrappers.wrap_atari(
-      gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)
+    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
+        gym.make("PongNoFrameskip-v4"),
+        warp=False,
+        frame_skip=4,
+        frame_stack=False)
     hparams = rl.atari_base()
     with tf.variable_scope("train"):
       policy_lambda = hparams.network
       policy_factory = tf.make_template(
-        "network",
-        functools.partial(policy_lambda, env_spec().action_space, hparams))
-      self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
-      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
+          "network",
+          functools.partial(policy_lambda, env_spec().action_space, hparams))
+      self._max_frame_pl = tf.placeholder(
+          tf.float32, self.env.observation_space.shape)
+      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
+          self._max_frame_pl, 0), 0))
       policy = actor_critic.policy
       self._last_policy_op = policy.mode()
     self._last_action = self.env.action_space.sample()
     self._skip = 4
     self._skip_step = 0
-    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, dtype=np.uint8)
+    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                dtype=np.uint8)
     self._sess = tf.Session()
     model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
     model_saver.restore(self._sess, FLAGS.model_path)

   # TODO(blazej0): For training of atari agents wrappers are usually used.
-  # Below we have a hacky solution which is a temporary workaround to be used together
+  # Below we have a hacky solution which is a workaround to be used together
   # with atari_wrappers.MaxAndSkipEnv.
   def get_action(self, observation=None):
     if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
@@ -197,7 +192,8 @@ def get_action(self, observation=None):
     if self._skip_step == 0:
       max_frame = self._obs_buffer.max(axis=0)
       self._last_action = int(self._sess.run(
-        self._last_policy_op, feed_dict={self._max_frame_pl: max_frame})[0, 0])
+          self._last_policy_op,
+          feed_dict={self._max_frame_pl: max_frame})[0, 0])
     return self._last_action

   @property

tensor2tensor/data_generators/imagenet.py

+99
@@ -19,19 +19,84 @@
 from __future__ import division
 from __future__ import print_function

+import os
 # Dependency imports

+from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image_utils
 from tensor2tensor.utils import registry

 import tensorflow as tf

+# URLs and filenames for IMAGENET 32x32 data from
+# https://arxiv.org/abs/1601.06759.
+_IMAGENET_SMALL_ROOT_URL = "http://image-net.org/small/"
+_IMAGENET_SMALL_URLS = [
+    "train_32x32.tar", "valid_32x32.tar"]
+_IMAGENET_SMALL_TRAIN_PREFIX = "train_32x32"
+_IMAGENET_SMALL_EVAL_PREFIX = "valid_32x32"
+_IMAGENET_SMALL_IMAGE_SIZE = 32
+
+
+# URLs and filenames for IMAGENET 64x64 data.
+_IMAGENET_MEDIUM_ROOT_URL = "http://image-net.org/small/"
+_IMAGENET_MEDIUM_URLS = [
+    "train_64x64.tar", "valid_64x64.tar"]
+_IMAGENET_MEDIUM_TRAIN_PREFIX = "train_64x64"
+_IMAGENET_MEDIUM_EVAL_PREFIX = "valid_64x64"
+_IMAGENET_MEDIUM_IMAGE_SIZE = 64
+

 # Derived from ImageNet data
 MEAN_RGB = [0.485, 0.456, 0.406]
 STDDEV_RGB = [0.229, 0.224, 0.225]


+def imagenet_pixelrnn_generator(tmp_dir,
+                                training,
+                                size=_IMAGENET_SMALL_IMAGE_SIZE):
+  """Image generator for Imagenet 64x64 downsampled images.
+
+  It assumes that the data has been downloaded from
+  http://image-net.org/small/*_32x32.tar or
+  http://image-net.org/small/*_64x64.tar into tmp_dir.
+  Args:
+    tmp_dir: path to temporary storage directory.
+    training: a Boolean; if true, we use the train set, otherwise the test set.
+    size: image size (assumes height and width are same)
+
+  Yields:
+    A dictionary representing the images with the following fields:
+    * image/encoded: the string encoding the image as JPEG,
+    * image/format: the string "jpeg" representing image format,
+    * image/height: an integer representing the height,
+    * image/width: an integer representing the width.
+    Every field is actually a list of the corresponding type.
+  """
+  if size == _IMAGENET_SMALL_IMAGE_SIZE:
+    train_prefix = _IMAGENET_SMALL_TRAIN_PREFIX
+    eval_prefix = _IMAGENET_SMALL_EVAL_PREFIX
+  else:
+    train_prefix = _IMAGENET_MEDIUM_TRAIN_PREFIX
+    eval_prefix = _IMAGENET_MEDIUM_EVAL_PREFIX
+  prefix = train_prefix if training else eval_prefix
+  images_filepath = os.path.join(tmp_dir, prefix)
+  image_files = tf.gfile.Glob(images_filepath + "/*")
+  height = size
+  width = size
+  const_label = 0
+  for filename in image_files:
+    with tf.gfile.Open(filename, "r") as f:
+      encoded_image = f.read()
+      yield {
+          "image/encoded": [encoded_image],
+          "image/format": ["png"],
+          "image/class/label": [const_label],
+          "image/height": [height],
+          "image/width": [width]
+      }
+
+
 def imagenet_preprocess_example(example, mode, resize_size=None):
   """Preprocessing used for Imagenet and similar problems."""
   resize_size = resize_size or [299, 299]
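
The `imagenet_pixelrnn_generator` added in the hunk above globs an extracted archive directory and yields one feature dictionary per image file. A TensorFlow-free sketch of that shape follows, assuming local files and `glob.glob` in place of `tf.gfile.Glob`. It reads files in binary mode (`"rb"`), whereas the diff uses mode `"r"`; that substitution is an assumption made here because encoded image bytes are not text under Python 3. `image_dir_generator` and the placeholder file contents are hypothetical.

```python
import glob
import os
import tempfile


def image_dir_generator(tmp_dir, prefix, size):
    """Yield one feature dict per file found under tmp_dir/prefix."""
    for filename in sorted(glob.glob(os.path.join(tmp_dir, prefix, "*"))):
        with open(filename, "rb") as f:
            encoded_image = f.read()
        yield {
            "image/encoded": [encoded_image],
            "image/format": ["png"],
            "image/class/label": [0],  # constant label, as in the diff
            "image/height": [size],
            "image/width": [size],
        }


# Usage with placeholder bytes standing in for real PNG files.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "valid_64x64"))
for name in ("a.png", "b.png"):
    with open(os.path.join(root, "valid_64x64", name), "wb") as f:
        f.write(b"not-a-real-png-" + name.encode("ascii"))

examples = list(image_dir_generator(root, "valid_64x64", 64))
```

Every field is a single-element list, matching the "list of the corresponding type" convention described in the docstring.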
@@ -123,6 +188,40 @@ def preprocess_example(self, example, mode, _):
     return example


+@registry.register_problem
+class ImageImagenet64Gen(ImageImagenet):
+  """Cifar-10 Tune."""
+
+  @property
+  def train_shards(self):
+    return 1024
+
+  @property
+  def dev_shards(self):
+    return 10
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    generator_utils.generate_dataset_and_shuffle(
+        self.generator(data_dir, tmp_dir, True),
+        self.training_filepaths(data_dir, self.train_shards, shuffled=True),
+        self.generator(data_dir, tmp_dir, False),
+        self.dev_filepaths(data_dir, self.dev_shards, shuffled=True))
+
+  def generator(self, data_dir, tmp_dir, is_training):
+    if is_training:
+      return imagenet_pixelrnn_generator(
+          tmp_dir, int(True), size=_IMAGENET_MEDIUM_IMAGE_SIZE)
+    else:
+      return imagenet_pixelrnn_generator(
+          tmp_dir, int(False), size=_IMAGENET_MEDIUM_IMAGE_SIZE)
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_IMAGENET_MEDIUM_IMAGE_SIZE,
+                                 _IMAGENET_MEDIUM_IMAGE_SIZE, 3])
+    example["inputs"] = tf.to_int64(example["inputs"])
+    return example
+
+
 @registry.register_problem
 class ImageImagenet64(ImageImagenet32):
   """Imagenet rescaled to 64x64."""

tensor2tensor/data_generators/inspect.py

+6
@@ -60,6 +60,8 @@ def main(_):
   total_sequences = 0
   total_input_tokens = 0
   total_target_tokens = 0
+  nonpadding_input_tokens = 0
+  nonpadding_target_tokens = 0
   max_input_length = 0
   max_target_length = 0
   for record in reader:
@@ -71,6 +73,8 @@ def main(_):
       print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs)
     if FLAGS.print_targets:
       print("TARGETS:\n" + encoder.decode(targets) if encoder else targets)
+    nonpadding_input_tokens += len(inputs) - inputs.count(0)
+    nonpadding_target_tokens += len(targets) - targets.count(0)
     total_input_tokens += len(inputs)
     total_target_tokens += len(targets)
     total_sequences += 1
@@ -83,6 +87,8 @@ def main(_):
   print("total_sequences: %d" % total_sequences)
   print("total_input_tokens: %d" % total_input_tokens)
   print("total_target_tokens: %d" % total_target_tokens)
+  print("nonpadding_input_tokens: %d" % nonpadding_input_tokens)
+  print("nonpadding_target_tokens: %d" % nonpadding_target_tokens)
   print("max_input_length: %d" % max_input_length)
   print("max_target_length: %d" % max_target_length)
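
The inspect.py change counts non-padding tokens as the sequence length minus the number of zeros, which treats id 0 as the padding id. A small sketch of that arithmetic on a plain Python list follows; `count_tokens` and its `pad_id` parameter are hypothetical names introduced for illustration.

```python
def count_tokens(ids, pad_id=0):
    """Return (total, nonpadding) token counts for one sequence of ids."""
    total = len(ids)
    nonpadding = total - ids.count(pad_id)
    return total, nonpadding


# Usage: a sequence of four real tokens followed by two padding ids.
inputs = [5, 17, 3, 1, 0, 0]
total, nonpadding = count_tokens(inputs)
print("total_input_tokens: %d" % total)
print("nonpadding_input_tokens: %d" % nonpadding)
```

Comparing the two totals across a dataset gives a rough idea of how much of each record is padding.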
