This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c1cd875

Merge pull request #484 from rsepassi/push
v1.4.1
2 parents: 758991d + 02da1be

17 files changed: +973 -252 lines

.travis.yml (+1 -1)

@@ -14,7 +14,7 @@ env:
   - T2T_DATA_DIR=/tmp/t2t-data
   - T2T_TRAIN_DIR=/tmp/t2t-train
 script:
-  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
+  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
   - pytest tensor2tensor/utils/registry_test.py
   - pytest tensor2tensor/tpu/tpu_trainer_lib_test.py
   - t2t-datagen 2>&1 | grep translate && echo passed

setup.py (+2 -1)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.4.0',
+    version='1.4.1',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
@@ -30,6 +30,7 @@
         'gym',
         'numpy',
         'requests',
+        'scipy',
         'sympy',
         'six',
     ],

tensor2tensor/bin/t2t-trainer (+8 -6)

@@ -45,6 +45,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
+flags.DEFINE_integer("random_seed", 1234, "Random seed.")
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
@@ -61,7 +62,11 @@ try:
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass
 
@@ -77,9 +82,6 @@ def create_hparams():
 
 
 def create_experiment_fn():
-  use_validation_monitor = (FLAGS.schedule in
-                            ["train_and_evaluate", "continuous_train_and_eval"]
-                            and FLAGS.local_eval_frequency)
   return tpu_trainer_lib.create_experiment_fn(
       model_name=FLAGS.model,
       problem_name=get_problem_name(),
@@ -92,9 +94,9 @@ def create_experiment_fn():
       decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
       use_tfdbg=FLAGS.tfdbg,
       use_dbgprofile=FLAGS.dbgprofile,
-      use_validation_monitor=use_validation_monitor,
       eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
       eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
+      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
       eval_early_stopping_metric_minimize=FLAGS.
       eval_early_stopping_metric_minimize,
       use_tpu=FLAGS.use_tpu)
@@ -170,7 +172,7 @@ def execute_schedule(exp):
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.set_random_seed(123)
+  tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
   log_registry()
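Note on the eval_steps change above (mirrored in t2t_trainer.py below): per the new help text, evaluation stops after eval_steps batches or after one full pass over the eval dataset, whichever comes first, so the larger default of 10000 effectively means "evaluate the whole dev set once". A minimal illustrative sketch of that behavior in terms of the tf.estimator API (estimator and eval_input_fn are assumed names; the real evaluation loop lives in tpu_trainer_lib, not in this file):

  # Estimator.evaluate stops at `steps` batches or when eval_input_fn
  # signals end-of-data, whichever comes first.
  metrics = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)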

tensor2tensor/bin/t2t_trainer.py (+8 -6)

@@ -44,6 +44,7 @@
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
+flags.DEFINE_integer("random_seed", 1234, "Random seed.")
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
@@ -60,7 +61,11 @@
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass
 
@@ -76,9 +81,6 @@ def create_hparams():
 
 
 def create_experiment_fn():
-  use_validation_monitor = (FLAGS.schedule in
-                            ["train_and_evaluate", "continuous_train_and_eval"]
-                            and FLAGS.local_eval_frequency)
   return tpu_trainer_lib.create_experiment_fn(
       model_name=FLAGS.model,
       problem_name=get_problem_name(),
@@ -91,9 +93,9 @@ def create_experiment_fn():
       decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
       use_tfdbg=FLAGS.tfdbg,
       use_dbgprofile=FLAGS.dbgprofile,
-      use_validation_monitor=use_validation_monitor,
       eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
       eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
+      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
       eval_early_stopping_metric_minimize=FLAGS.
       eval_early_stopping_metric_minimize,
       use_tpu=FLAGS.use_tpu)
@@ -169,7 +171,7 @@ def execute_schedule(exp):
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.set_random_seed(123)
+  tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
   log_registry()
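Both trainer entry points now seed from the new --random_seed flag via tpu_trainer_lib.set_random_seed instead of the hard-coded tf.set_random_seed(123). The helper's body is not shown on this page; a minimal sketch of what such a helper plausibly does, assuming it seeds the Python, NumPy, and TensorFlow generators together:

  import random

  import numpy as np
  import tensorflow as tf


  def set_random_seed(seed):
    # Seed every generator the trainer touches so runs are reproducible.
    random.seed(seed)
    tf.set_random_seed(seed)
    np.random.seed(seed)

With the flag in place, a different seed can be chosen per run, e.g. t2t-trainer --random_seed=42 alongside the usual model and problem flags.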

tensor2tensor/data_generators/algorithmic_math_test.py (+1 -0)

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 """Tests for tensor2tensor.data_generators.algorithmic_math."""
+# TODO(rsepassi): This test is flaky. Disable, remove, or update.
 
 from __future__ import absolute_import
 from __future__ import division

tensor2tensor/data_generators/librispeech.py (+34 -169)

@@ -16,23 +16,14 @@
 """Librispeech dataset."""
 
 import os
-from subprocess import call
 import tarfile
-import wave
 
 # Dependency imports
 
-import numpy as np
-
 from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import problem
-from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.layers import common_layers
-from tensor2tensor.utils import modality
+from tensor2tensor.data_generators import speech_recognition
 from tensor2tensor.utils import registry
 
-import tensorflow as tf
-
 
 _LIBRISPEECH_TRAIN_DATASETS = [
     [
@@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext):
   return data_files
 
 
-def _get_audio_data(filepath):
-  # Construct a true .wav file.
-  out_filepath = filepath.strip(".flac") + ".wav"
-  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
-  call(["sox", filepath, out_filepath])
-  wav_file = wave.open(open(out_filepath))
-  frame_count = wav_file.getnframes()
-  byte_array = wav_file.readframes(frame_count)
-
-  data = np.fromstring(byte_array, np.uint8).tolist()
-  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
-
-
-class LibrispeechTextEncoder(text_encoder.TextEncoder):
-
-  def encode(self, s):
-    return [self._num_reserved_ids + ord(c) for c in s]
-
-  def decode(self, ids):
-    """Transform a sequence of int ids into a human-readable string.
-
-    EOS is not expected in ids.
-
-    Args:
-      ids: list of integers to be converted.
-    Returns:
-      s: human-readable string.
-    """
-    decoded_ids = []
-    for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
-      else:
-        decoded_ids.append(id_ - self._num_reserved_ids)
-    return "".join([chr(d) for d in decoded_ids])
-
-
-@registry.register_audio_modality
-class LibrispeechModality(modality.Modality):
-  """Performs strided conv compressions for audio spectral data."""
-
-  def bottom(self, inputs):
-    """Transform input from data space to model space.
-
-    Args:
-      inputs: A Tensor with shape [batch, ...]
-    Returns:
-      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
-    """
-    with tf.variable_scope(self.name):
-      # TODO(aidangomez): Will need to sort out a better audio pipeline
-      def xnet_resblock(x, filters, res_relu, name):
-        with tf.variable_scope(name):
-          # We only stride along the length dimension to preserve the spectral
-          # bins (which are tiny in dimensionality relative to length)
-          y = common_layers.separable_conv_block(
-              x,
-              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
-              first_relu=True,
-              padding="SAME",
-              force2d=True,
-              name="sep_conv_block")
-          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
-          return y + common_layers.conv_block(
-              x,
-              filters, [((1, 1), (1, 1))],
-              padding="SAME",
-              strides=(2, 1),
-              first_relu=res_relu,
-              force2d=True,
-              name="res_conv0")
-
-      # Rescale from UINT8 to floats in [-1,-1]
-      signals = (tf.to_float(inputs)-127)/128.
-      signals = tf.squeeze(signals, [2, 3])
-
-      # `stfts` is a complex64 Tensor representing the short-time Fourier
-      # Transform of each signal in `signals`. Its shape is
-      # [batch_size, ?, fft_unique_bins]
-      # where fft_unique_bins = fft_length // 2 + 1 = 513.
-      stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
-                                     fft_length=1024)
-
-      # An energy spectrogram is the magnitude of the complex-valued STFT.
-      # A float32 Tensor of shape [batch_size, ?, 513].
-      magnitude_spectrograms = tf.abs(stfts)
-
-      # Warp the linear-scale, magnitude spectrograms into the mel-scale.
-      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
-      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
-      sample_rate = 16000
-      linear_to_mel_weight_matrix = (
-          tf.contrib.signal.linear_to_mel_weight_matrix(
-              num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
-              upper_edge_hertz))
-      mel_spectrograms = tf.tensordot(
-          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
-      # Note: Shape inference for tensordot does not currently handle this case.
-      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
-          linear_to_mel_weight_matrix.shape[-1:]))
-
-      x = tf.expand_dims(mel_spectrograms, 2)
-      x.set_shape([None, None, None, num_mel_bins])
-      for i in xrange(self._model_hparams.audio_compression):
-        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
-      return xnet_resblock(x, self._body_input_depth, False,
-                           "compress_block_final")
-
-
 @registry.register_problem()
-class Librispeech(problem.Problem):
-  """Problem spec for English word to dictionary definition."""
+class Librispeech(speech_recognition.SpeechRecognitionProblem):
+  """Problem spec for Librispeech using clean and noisy data."""
 
-  @property
-  def is_character_level(self):
-    return True
-
-  @property
-  def input_space_id(self):
-    return problem.SpaceID.AUDIO_SPECTRAL
-
-  @property
-  def target_space_id(self):
-    return problem.SpaceID.EN_CHR
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS
 
   @property
   def num_shards(self):
@@ -228,26 +102,8 @@ def use_train_shards_for_dev(self):
     """If true, we only generate training data and hold out shards for dev."""
     return False
 
-  def feature_encoders(self, _):
-    return {
-        "inputs": text_encoder.TextEncoder(),
-        "targets": LibrispeechTextEncoder(),
-    }
-
-  def example_reading_spec(self):
-    data_fields = {
-        "inputs": tf.VarLenFeature(tf.int64),
-        "targets": tf.VarLenFeature(tf.int64),
-    }
-    data_items_to_decoders = None
-    return (data_fields, data_items_to_decoders)
-
-  def generator(self, data_dir, tmp_dir, training,
+  def generator(self, data_dir, tmp_dir, datasets,
                 eos_list=None, start_from=0, how_many=0):
-    eos_list = [1] if eos_list is None else eos_list
-    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
-                else _LIBRISPEECH_TEST_DATASETS)
-    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
     i = 0
     for url, subdir in datasets:
       filename = os.path.basename(url)
@@ -267,44 +123,53 @@ def generator(self, data_dir, tmp_dir, training,
       data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
       data_files = _collect_data(data_dir, "flac", "txt")
      data_pairs = data_files.values()
+
+      encoders = self.feature_encoders(None)
+      audio_encoder = encoders["waveforms"]
+      text_encoder = encoders["targets"]
+
       for media_file, text_data in sorted(data_pairs)[start_from:]:
         if how_many > 0 and i == how_many:
           return
         i += 1
-        audio_data, sample_count, sample_width, num_channels = _get_audio_data(
-            media_file)
-        label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
         yield {
-            "inputs": audio_data,
-            "audio/channel_count": [num_channels],
-            "audio/sample_count": [sample_count],
-            "audio/sample_width": [sample_width],
-            "targets": label
+            "waveforms": audio_encoder.encode(media_file),
+            "targets": text_encoder.encode(text_data)
         }
 
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     train_paths = self.training_filepaths(
         data_dir, self.num_shards, shuffled=False)
     dev_paths = self.dev_filepaths(
         data_dir, self.num_dev_shards, shuffled=False)
+
     if self.use_train_shards_for_dev:
       all_paths = train_paths + dev_paths
       generator_utils.generate_files(
-          self.generator(data_dir, tmp_dir, True), all_paths)
+          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
       generator_utils.shuffle_dataset(all_paths)
     else:
       generator_utils.generate_dataset_and_shuffle(
-          self.generator(data_dir, tmp_dir, True), train_paths,
-          self.generator(data_dir, tmp_dir, False), dev_paths)
+          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
+          self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)
 
-  def hparams(self, defaults, unused_model_hparams):
-    p = defaults
-    p.stop_at_eos = int(False)
-    p.input_modality = {"inputs": ("audio:librispeech_modality", None)}
-    p.target_modality = (registry.Modalities.SYMBOL, 256)
 
-  def preprocess_example(self, example, mode, hparams):
-    return example
+@registry.register_problem()
+class LibrispeechCleanSmall(Librispeech):
+  """Problem spec for Librispeech using 100h clean train data."""
+
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
+
+
+@registry.register_problem()
+class LibrispeechClean(Librispeech):
+  """Problem spec for Librispeech using 460h clean train data."""
+
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
 
 
 # TODO(lukaszkaiser): clean up hparams or remove from here.
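With this refactor, dataset selection is driven entirely by the TRAIN_DATASETS and DEV_DATASETS class attributes, and the registered problems behave like any other T2T problem. A minimal usage sketch (hypothetical driver code, not part of this commit; it assumes the /tmp directories used elsewhere in this repo and that problem names are the snake_case form of the registered class names):

  # Importing the module registers the Librispeech problems defined above.
  from tensor2tensor.data_generators import librispeech  # pylint: disable=unused-import
  from tensor2tensor.utils import registry

  problem = registry.problem("librispeech_clean_small")
  problem.generate_data("/tmp/t2t-data", "/tmp/t2t-datagen")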
