This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d029c45
Merge pull request #34 from rsepassi/push

v1.0.7

Parents: 3410bea, d578f52

13 files changed: +285 −12 lines

.gitignore (+4)

@@ -3,3 +3,7 @@
 
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info
+
+# PyPI distribution artifacts
+build/
+dist/

setup.py (+1 −1)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.6',
+    version='1.0.7',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/data_generators/generator_utils.py (+1)

@@ -22,6 +22,7 @@
 import io
 import os
 import tarfile
+import urllib
 
 # Dependency imports
 

tensor2tensor/data_generators/image.py (+1 −6)

@@ -29,10 +29,9 @@
 # Dependency imports
 
 import numpy as np
+from six.moves import cPickle
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
-from six.moves import cPickle
-
 from tensor2tensor.data_generators import generator_utils
 
 import tensorflow as tf
@@ -201,10 +200,6 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
     ])
     labels = data["labels"]
     all_labels.extend([labels[j] for j in xrange(num_images)])
-    # Shuffle the data to make sure classes are well distributed.
-    data = zip(all_images, all_labels)
-    random.shuffle(data)
-    all_images, all_labels = zip(*data)
   return image_generator(all_images[start_from:start_from + how_many],
                          all_labels[start_from:start_from + how_many])
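One note on the removed hunk: under Python 3, six.moves.zip returns an iterator, so the removed random.shuffle(data) call would raise a TypeError. If the class-distribution shuffle were ever restored, it would need to materialize the pairs first; a minimal sketch, not part of this commit:

import random

# zip() yields an iterator on Python 3; materialize it before shuffling.
paired = list(zip(all_images, all_labels))
random.shuffle(paired)
all_images, all_labels = zip(*paired)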

tensor2tensor/data_generators/text_encoder.py (mode 100755 → 100644, +5 −3)

@@ -23,11 +23,12 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import defaultdict
+
 # Dependency imports
 
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
-from collections import defaultdict
 from tensor2tensor.data_generators import tokenizer
 
 import tensorflow as tf
@@ -41,6 +42,7 @@
 else:
   RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
 
+
 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""
 
@@ -95,7 +97,7 @@ def encode(self, s):
     if six.PY2:
       return [ord(c) + numres for c in s]
     # Python3: explicitly convert to UTF-8
-    return [c + numres for c in s.encode("utf-8")]
+    return [c + numres for c in s.encode('utf-8')]
 
   def decode(self, ids):
     numres = self._num_reserved_ids
@@ -109,7 +111,7 @@ def decode(self, ids):
     if six.PY2:
       return ''.join(decoded_ids)
     # Python3: join byte arrays and then decode string
-    return b''.join(decoded_ids).decode("utf-8")
+    return b''.join(decoded_ids).decode('utf-8')
 
   @property
   def vocab_size(self):
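The quote-style edits above sit on the Python 3 byte-level encode/decode path. A standalone round-trip sketch of that logic, simplified so the two reserved ids (PAD, EOS) occupy positions 0 and 1 and never appear in the encoded text; byte_encode and byte_decode are hypothetical names, not part of this commit:

NUM_RESERVED_IDS = 2  # assumed: PAD = 0, EOS = 1

def byte_encode(s):
  # Python 3: iterating over s.encode('utf-8') yields ints; shift past reserved ids.
  return [c + NUM_RESERVED_IDS for c in s.encode('utf-8')]

def byte_decode(ids):
  # Rebuild the byte string, then decode it as UTF-8 in one pass.
  return b''.join(bytes([i - NUM_RESERVED_IDS]) for i in ids).decode('utf-8')

assert byte_decode(byte_encode('héllo')) == 'héllo'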

tensor2tensor/data_generators/tokenizer.py (mode 100755 → 100644, +2 −1)

@@ -45,12 +45,13 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import defaultdict
 import string
 
 # Dependency imports
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
-from collections import defaultdict
+
 
 class Tokenizer(object):
   """Vocab for breaking words into wordpieces.

tensor2tensor/models/bluenet.py (new file, +150)

# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet: an out-of-the-blue network to experiment with shake-shake."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def residual_module(x, hparams, train, n, sep):
  """A stack of convolution blocks with residual connection."""
  k = (hparams.kernel_height, hparams.kernel_width)
  dilations_and_kernels = [((1, 1), k) for _ in xrange(n)]
  with tf.variable_scope("residual_module%d_sep%d" % (n, sep)):
    y = common_layers.subseparable_conv_block(
        x,
        hparams.hidden_size,
        dilations_and_kernels,
        padding="SAME",
        separability=sep,
        name="block")
    x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
  return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))


def residual_module1(x, hparams, train):
  return residual_module(x, hparams, train, 1, 1)


def residual_module1_sep(x, hparams, train):
  return residual_module(x, hparams, train, 1, 0)


def residual_module2(x, hparams, train):
  return residual_module(x, hparams, train, 2, 1)


def residual_module2_sep(x, hparams, train):
  return residual_module(x, hparams, train, 2, 0)


def residual_module3(x, hparams, train):
  return residual_module(x, hparams, train, 3, 1)


def residual_module3_sep(x, hparams, train):
  return residual_module(x, hparams, train, 3, 0)


def norm_module(x, hparams, train):
  del train  # Unused.
  return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module")


def identity_module(x, hparams, train):
  del hparams, train  # Unused.
  return x


def run_modules(blocks, cur, hparams, train, dp):
  """Run blocks in parallel using dp as data_parallelism."""
  assert len(blocks) % dp.n == 0
  res = []
  for i in xrange(len(blocks) // dp.n):
    res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train))
  return res


@registry.register_model
class BlueNet(t2t_model.T2TModel):

  def model_fn_body_sharded(self, sharded_features, train):
    dp = self._data_parallelism
    dp._reuse = False  # pylint:disable=protected-access
    hparams = self._hparams
    blocks = [identity_module, norm_module,
              residual_module1, residual_module1_sep,
              residual_module2, residual_module2_sep,
              residual_module3, residual_module3_sep]
    inputs = sharded_features["inputs"]

    cur = tf.concat(inputs, axis=0)
    cur_shape = cur.get_shape()
    for i in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % i):
        processed = run_modules(blocks, cur, hparams, train, dp)
        cur = common_layers.shakeshake(processed)
        cur.set_shape(cur_shape)

    return list(tf.split(cur, len(inputs), axis=0)), 0.0


@registry.register_hparams
def bluenet_base():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.batch_size = 4096
  hparams.hidden_size = 768
  hparams.dropout = 0.2
  hparams.symbol_dropout = 0.2
  hparams.label_smoothing = 0.1
  hparams.clip_grad_norm = 2.0
  hparams.num_hidden_layers = 8
  hparams.kernel_height = 3
  hparams.kernel_width = 3
  hparams.learning_rate_decay_scheme = "exp50k"
  hparams.learning_rate = 0.05
  hparams.learning_rate_warmup_steps = 3000
  hparams.initializer_gain = 1.0
  hparams.weight_decay = 3.0
  hparams.num_sampled_classes = 0
  hparams.sampling_method = "argmax"
  hparams.optimizer_adam_epsilon = 1e-6
  hparams.optimizer_adam_beta1 = 0.85
  hparams.optimizer_adam_beta2 = 0.997
  hparams.add_hparam("imagenet_use_2d", True)
  return hparams


@registry.register_hparams
def bluenet_tiny():
  hparams = bluenet_base()
  hparams.batch_size = 1024
  hparams.hidden_size = 128
  hparams.num_hidden_layers = 4
  hparams.learning_rate_decay_scheme = "none"
  return hparams
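Each layer feeds the same tensor to all eight candidate modules and lets shake-shake mix their outputs back into a single tensor of unchanged shape. A de-parallelized sketch of that per-layer dataflow, dropping dp (which only shards the module calls across devices); bluenet_layer is a hypothetical name:

def bluenet_layer(cur, hparams, train, blocks):
  # Every module sees the same input; shake-shake draws a random convex mix.
  processed = [block(cur, hparams, train) for block in blocks]
  out = common_layers.shakeshake(processed)
  out.set_shape(cur.get_shape())  # shake-shake preserves the input shape
  return out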

tensor2tensor/models/bluenet_test.py (new file, +54)

# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import bluenet

import tensorflow as tf


class BlueNetTest(tf.test.TestCase):

  def testBlueNet(self):
    vocab_size = 9
    x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
    y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1))
    hparams = bluenet.bluenet_tiny()
    p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
                                                     vocab_size)
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
      }
      model = bluenet.BlueNet(hparams, p_hparams)
      sharded_logits, _, _ = model.model_fn(features, True)
      logits = tf.concat(sharded_logits, 0)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))


if __name__ == "__main__":
  tf.test.main()

tensor2tensor/models/common_layers.py (+46)

@@ -58,6 +58,52 @@ def inverse_exp_decay(max_step, min_value=0.01):
   return inv_base**tf.maximum(float(max_step) - step, 0.0)
 
 
+def shakeshake2_py(x, y, equal=False):
+  """The shake-shake sum of 2 tensors, python version."""
+  alpha = 0.5 if equal else tf.random_uniform([])
+  return alpha * x + (1.0 - alpha) * y
+
+
+@function.Defun()
+def shakeshake2_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
+@function.Defun()
+def shakeshake2_equal_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2, equal=True)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
+@function.Defun(grad_func=shakeshake2_grad)
+def shakeshake2(x1, x2):
+  """The shake-shake function with a different alpha for forward/backward."""
+  return shakeshake2_py(x1, x2)
+
+
+@function.Defun(grad_func=shakeshake2_equal_grad)
+def shakeshake2_eqgrad(x1, x2):
+  """The shake-shake function with equal alpha (0.5) in the backward pass."""
+  return shakeshake2_py(x1, x2)
+
+
+def shakeshake(xs, equal_grad=False):
+  """Multi-argument shake-shake, currently approximated by sums of 2."""
+  if len(xs) == 1:
+    return xs[0]
+  div = (len(xs) + 1) // 2
+  arg1 = shakeshake(xs[:div], equal_grad=equal_grad)
+  arg2 = shakeshake(xs[div:], equal_grad=equal_grad)
+  if equal_grad:
+    return shakeshake2_eqgrad(arg1, arg2)
+  return shakeshake2(arg1, arg2)
+
+
 def standardize_images(x):
   """Image standardization on batches (tf.image.per_image_standardization)."""
   with tf.name_scope("standardize_images", [x]):
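Forward, shakeshake2 computes alpha * x1 + (1 - alpha) * x2 with alpha drawn uniformly from [0, 1]; because the Defun gradient override recomputes the mix, the backward pass sees an independently drawn alpha (or a fixed 0.5 via shakeshake2_eqgrad), which is the shake-shake regularization trick. A minimal TF1-style usage sketch, assuming tensor2tensor is importable:

import tensorflow as tf

from tensor2tensor.models import common_layers

x1 = tf.ones([2, 3])
x2 = tf.zeros([2, 3])
# alpha * x1 + (1 - alpha) * x2, so every entry equals this step's alpha.
y = common_layers.shakeshake([x1, x2])

with tf.Session() as sess:
  print(sess.run(y))  # a 2x3 tensor filled with one random value in [0, 1]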

tensor2tensor/models/common_layers_test.py (+9)

@@ -65,6 +65,15 @@ def testEmbedding(self):
       res = session.run(y)
     self.assertEqual(res.shape, (3, 5, 16))
 
+  def testShakeShake(self):
+    x = np.random.rand(5, 7)
+    with self.test_session() as session:
+      x = tf.constant(x, dtype=tf.float32)
+      y = common_layers.shakeshake([x, x, x, x, x])
+      session.run(tf.global_variables_initializer())
+      inp, res = session.run([x, y])
+    self.assertAllClose(res, inp)
+
   def testConv(self):
     x = np.random.rand(5, 7, 1, 11)
     with self.test_session() as session:
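The identical-input trick makes this test deterministic: every pairwise mix alpha * x + (1 - alpha) * x collapses to x regardless of the sampled alpha, so assertAllClose holds for any random draw.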

tensor2tensor/models/models.py (+1)

@@ -24,6 +24,7 @@
 
 from tensor2tensor.models import attention_lm
 from tensor2tensor.models import attention_lm_moe
+from tensor2tensor.models import bluenet
 from tensor2tensor.models import bytenet
 from tensor2tensor.models import lstm
 from tensor2tensor.models import modalities

tensor2tensor/models/xception.py (+10)

@@ -87,3 +87,13 @@ def xception_base():
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
   return hparams
+
+
+@registry.register_hparams
+def xception_tiny():
+  hparams = xception_base()
+  hparams.batch_size = 1024
+  hparams.hidden_size = 128
+  hparams.num_hidden_layers = 4
+  hparams.learning_rate_decay_scheme = "none"
+  return hparams
