Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 204b359

Browse files
authored
Merge pull request #29 from rsepassi/push
Push 1.0.6
2 parents 9d04261 + 2f4d5b7 commit 204b359

File tree

6 files changed

+39
-11
lines changed

6 files changed

+39
-11
lines changed

Diff for: README.md

+14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ issues](https://github.com/tensorflow/tensor2tensor/issues).
2626
And chat with us and other users on
2727
[Gitter](https://gitter.im/tensor2tensor/Lobby).
2828

29+
### Contents
30+
31+
* [Walkthrough](#walkthrough)
32+
* [Installation](#installation)
33+
* [Features](#features)
34+
* [T2T Overview](#t2t-overview)
35+
* [Datasets](#datasets)
36+
* [Problems and Modalities](#problems-and-modalities)
37+
* [Models](#models)
38+
* [Hyperparameter Sets](#hyperparameter-sets)
39+
* [Trainer](#trainer)
40+
* [Adding your own components](#adding-your-own-components)
41+
* [Adding a dataset](#adding-a-dataset)
42+
2943
---
3044

3145
## Walkthrough

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.0.5',
8+
version='1.0.6',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='[email protected]',

Diff for: tensor2tensor/data_generators/generator_utils.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,15 @@ def generate_files(generator,
127127

128128

129129
def download_report_hook(count, block_size, total_size):
130-
"""Report hook for download progress
130+
"""Report hook for download progress.
131131
132132
Args:
133133
count: current block number
134134
block_size: block size
135135
total_size: total size
136136
"""
137-
percent = int(count*block_size*100/total_size)
138-
print("\r%d%%" % percent + ' completed', end='\r')
137+
percent = int(count * block_size * 100 / total_size)
138+
print("\r%d%%" % percent + " completed", end="\r")
139139

140140

141141
def maybe_download(directory, filename, url):
@@ -155,11 +155,12 @@ def maybe_download(directory, filename, url):
155155
filepath = os.path.join(directory, filename)
156156
if not tf.gfile.Exists(filepath):
157157
tf.logging.info("Downloading %s to %s" % (url, filepath))
158-
filepath, _ = urllib.urlretrieve(url, filepath,
159-
reporthook=download_report_hook)
160-
158+
inprogress_filepath = filepath + ".incomplete"
159+
inprogress_filepath, _ = urllib.urlretrieve(url, inprogress_filepath,
160+
reporthook=download_report_hook)
161161
# Print newline to clear the carriage return from the download progress
162162
print()
163+
tf.gfile.Rename(inprogress_filepath, filepath)
163164
statinfo = os.stat(filepath)
164165
tf.logging.info("Succesfully downloaded %s, %s bytes." % (filename,
165166
statinfo.st_size))

Diff for: tensor2tensor/models/common_layers.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,7 @@ def conv_hidden_relu(inputs,
10791079
hidden_size,
10801080
output_size,
10811081
kernel_size=(1, 1),
1082+
second_kernel_size=(1, 1),
10821083
summaries=True,
10831084
dropout=0.0,
10841085
**kwargs):
@@ -1090,7 +1091,8 @@ def conv_hidden_relu(inputs,
10901091
inputs = tf.expand_dims(inputs, 2)
10911092
else:
10921093
is_3d = False
1093-
h = conv(
1094+
conv_f1 = conv if kernel_size == (1, 1) else separable_conv
1095+
h = conv_f1(
10941096
inputs,
10951097
hidden_size,
10961098
kernel_size,
@@ -1103,7 +1105,8 @@ def conv_hidden_relu(inputs,
11031105
tf.summary.histogram("hidden_density_logit",
11041106
relu_density_logit(
11051107
h, list(range(inputs.shape.ndims - 1))))
1106-
ret = conv(h, output_size, (1, 1), name="conv2", **kwargs)
1108+
conv_f2 = conv if second_kernel_size == (1, 1) else separable_conv
1109+
ret = conv_f2(h, output_size, second_kernel_size, name="conv2", **kwargs)
11071110
if is_3d:
11081111
ret = tf.squeeze(ret, 2)
11091112
return ret

Diff for: tensor2tensor/models/transformer.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,15 @@ def transformer_ffn_layer(x, hparams):
255255
hparams.filter_size,
256256
hparams.num_heads,
257257
hparams.attention_dropout)
258+
elif hparams.ffn_layer == "conv_hidden_relu_with_sepconv":
259+
return common_layers.conv_hidden_relu(
260+
x,
261+
hparams.filter_size,
262+
hparams.hidden_size,
263+
kernel_size=(3, 1),
264+
second_kernel_size=(31, 1),
265+
padding="LEFT",
266+
dropout=hparams.relu_dropout)
258267
else:
259268
assert hparams.ffn_layer == "none"
260269
return x
@@ -342,7 +351,6 @@ def transformer_parsing_base():
342351
hparams.learning_rate_warmup_steps = 16000
343352
hparams.hidden_size = 1024
344353
hparams.learning_rate = 0.05
345-
hparams.residual_dropout = 0.1
346354
hparams.shared_embedding_and_softmax_weights = int(False)
347355
return hparams
348356

Diff for: tensor2tensor/utils/data_reader.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,16 @@ def examples_queue(data_sources,
100100
with tf.name_scope("examples_queue"):
101101
# Read serialized examples using slim parallel_reader.
102102
num_epochs = None if training else 1
103+
data_files = tf.contrib.slim.parallel_reader.get_data_files(data_sources)
104+
num_readers = min(4 if training else 1, len(data_files))
103105
_, example_serialized = tf.contrib.slim.parallel_reader.parallel_read(
104106
data_sources,
105107
tf.TFRecordReader,
106108
num_epochs=num_epochs,
107109
shuffle=training,
108110
capacity=2 * capacity,
109111
min_after_dequeue=capacity,
110-
num_readers=4 if training else 1)
112+
num_readers=num_readers)
111113

112114
if data_items_to_decoders is None:
113115
data_items_to_decoders = {

0 commit comments

Comments
 (0)