This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a836d66

Merge pull request #370 from vince62s/translate
Rework the translate problem
2 parents 0ecbef8 + b43f833 commit a836d66

15 files changed: +947 -768 lines

Diff for: tensor2tensor/bin/t2t-datagen

File mode changed from 100644 to 100755; +1 -1

@@ -43,7 +43,7 @@ from tensor2tensor.data_generators import all_problems  # pylint: disable=unused
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import snli
-from tensor2tensor.data_generators import wmt
+from tensor2tensor.data_generators import translate
 from tensor2tensor.data_generators import wsj_parsing
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import usr_dir

Diff for: tensor2tensor/bin/t2t-decoder

File mode changed from 100644 to 100755.

Diff for: tensor2tensor/bin/t2t-make-tf-configs

File mode changed from 100644 to 100755.

Diff for: tensor2tensor/bin/t2t-trainer

File mode changed from 100644 to 100755.

Diff for: tensor2tensor/data_generators/all_problems.py

+6 -1

@@ -33,7 +33,12 @@
 from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wiki
-from tensor2tensor.data_generators import wmt
+from tensor2tensor.data_generators import translate
+from tensor2tensor.data_generators import translate_enfr
+from tensor2tensor.data_generators import translate_ende
+from tensor2tensor.data_generators import translate_encs
+from tensor2tensor.data_generators import translate_enzh
+from tensor2tensor.data_generators import translate_enmk
 from tensor2tensor.data_generators import wsj_parsing
 
 

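Importing all_problems is what makes the problems in the new per-language-pair modules visible to the rest of the toolkit: each translate_* module registers its problem classes with the central registry as a side effect of being imported, just as the old monolithic wmt module did. A minimal sketch of checking that registration happened, assuming nothing about the exact problem names at this revision:

    # Sketch: the import below triggers the @registry.register_problem
    # decorators in each translate_* module as a side effect.
    from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
    from tensor2tensor.utils import registry

    # Print every registered problem whose name mentions "translate";
    # the exact registered names at this commit are not assumed here.
    for name in registry.list_problems():
      if "translate" in name:
        print(name)
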
Diff for: tensor2tensor/data_generators/generator_utils.py

+14 -45

@@ -263,42 +263,6 @@ def gunzip_file(gz_path, new_path):
       for line in gz_file:
         new_file.write(line)
 
-
-# TODO(aidangomez): en-fr tasks are significantly over-represented below
-_DATA_FILE_URLS = [
-    # German-English
-    [
-        "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz",  # pylint: disable=line-too-long
-        [
-            "training-parallel-nc-v11/news-commentary-v11.de-en.en",
-            "training-parallel-nc-v11/news-commentary-v11.de-en.de"
-        ]
-    ],
-    # German-English & French-English
-    [
-        "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", [
-            "commoncrawl.de-en.en", "commoncrawl.de-en.de",
-            "commoncrawl.fr-en.en", "commoncrawl.fr-en.fr"
-        ]
-    ],
-    [
-        "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", [
-            "training/europarl-v7.de-en.en", "training/europarl-v7.de-en.de",
-            "training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr"
-        ]
-    ],
-    # French-English
-    [
-        "http://www.statmt.org/wmt10/training-giga-fren.tar",
-        ["giga-fren.release2.fixed.en.gz", "giga-fren.release2.fixed.fr.gz"]
-    ],
-    [
-        "http://www.statmt.org/wmt13/training-parallel-un.tgz",
-        ["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"]
-    ],
-]
-
-
 def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
                                 generator):
   """Inner implementation for vocab generators.

@@ -341,9 +305,8 @@ def get_or_generate_vocab(data_dir,
                           tmp_dir,
                           vocab_filename,
                           vocab_size,
-                          sources=None):
-  """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
-  sources = sources or _DATA_FILE_URLS
+                          sources):
+  """Generate a vocabulary from the datasets in sources."""
 
   def generate():
     tf.logging.info("Generating vocab from: %s", str(sources))

@@ -375,13 +338,19 @@ def generate():
 
         # Use Tokenizer to count the word occurrences.
         with tf.gfile.GFile(filepath, mode="r") as source_file:
-          file_byte_budget = 3.5e5 if filepath.endswith("en") else 7e5
+          file_byte_budget = 1e6
+          counter = 0
+          countermax = int(source_file.size() / file_byte_budget / 2)
           for line in source_file:
-            if file_byte_budget <= 0:
-              break
-            line = line.strip()
-            file_byte_budget -= len(line)
-            yield line
+            if counter < countermax:
+              counter += 1
+            else:
+              if file_byte_budget <= 0:
+                break
+              line = line.strip()
+              file_byte_budget -= len(line)
+              counter = 0
+              yield line
 
   return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
                                      generate())

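The reworked vocabulary generation replaces the old per-language byte budgets with a single file_byte_budget of 1e6 bytes spread across the whole file: between every line it keeps, the loop skips countermax lines, where countermax is the file size divided by twice the budget, so the sampled text comes from the entire corpus rather than only its head. A standalone sketch of the same sampling idea, using plain Python file I/O instead of tf.gfile (the function name and default budget are illustrative only):

    import os

    def sample_lines(filepath, file_byte_budget=1e6):
      """Yield lines spread over the whole file until the byte budget is spent."""
      countermax = int(os.path.getsize(filepath) / file_byte_budget / 2)
      counter = 0
      with open(filepath, encoding="utf-8") as source_file:
        for line in source_file:
          if counter < countermax:
            counter += 1  # skip this line, but keep walking through the file
          else:
            if file_byte_budget <= 0:
              break  # budget exhausted
            line = line.strip()
            file_byte_budget -= len(line)
            counter = 0
            yield line
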
Diff for: tensor2tensor/data_generators/ice_parsing.py

+1 -1

@@ -32,7 +32,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators.wmt import tabbed_generator
+from tensor2tensor.data_generators.translate import tabbed_generator
 from tensor2tensor.utils import registry
 
 

Diff for: tensor2tensor/data_generators/translate.py

New file: +262 lines

@@ -0,0 +1,262 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data generators for translation data-sets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+
+# Dependency imports
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+
+class TranslateProblem(problem.Text2TextProblem):
+  """Base class for translation problems."""
+
+  @property
+  def is_character_level(self):
+    return False
+
+  @property
+  def num_shards(self):
+    return 100
+
+  @property
+  def use_subword_tokenizer(self):
+    return True
+
+
+# Generic generators used later for multiple problems.
+
+
+def character_generator(source_path, target_path, character_vocab, eos=None):
+  """Generator for sequence-to-sequence tasks that just uses characters.
+
+  This generator assumes the files at source_path and target_path have
+  the same number of lines and yields dictionaries of "inputs" and "targets"
+  where inputs are characters from the source lines converted to integers,
+  and targets are characters from the target lines, also converted to integers.
+
+  Args:
+    source_path: path to the file with source sentences.
+    target_path: path to the file with target sentences.
+    character_vocab: a TextEncoder to encode the characters.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from characters in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      while source and target:
+        source_ints = character_vocab.encode(source.strip()) + eos_list
+        target_ints = character_vocab.encode(target.strip()) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+        source, target = source_file.readline(), target_file.readline()
+
+
+def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
+  r"""Generator for sequence-to-sequence tasks using tabbed files.
+
+  Tokens are derived from text files where each line contains both
+  a source and a target string. The two strings are separated by a tab
+  character ('\t'). It yields dictionaries of "inputs" and "targets" where
+  inputs are characters from the source lines converted to integers, and
+  targets are characters from the target lines, also converted to integers.
+
+  Args:
+    source_path: path to the file with source and target sentences.
+    source_vocab: a SubwordTextEncoder to encode the source string.
+    target_vocab: a SubwordTextEncoder to encode the target string.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from characters in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    for line in source_file:
+      if line and "\t" in line:
+        parts = line.split("\t", 1)
+        source, target = parts[0].strip(), parts[1].strip()
+        source_ints = source_vocab.encode(source) + eos_list
+        target_ints = target_vocab.encode(target) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+
+
+def token_generator(source_path, target_path, token_vocab, eos=None):
+  """Generator for sequence-to-sequence tasks that uses tokens.
+
+  This generator assumes the files at source_path and target_path have
+  the same number of lines and yields dictionaries of "inputs" and "targets"
+  where inputs are token ids from the " "-split source (and target, resp.) lines
+  converted to integers using the token_map.
+
+  Args:
+    source_path: path to the file with source sentences.
+    target_path: path to the file with target sentences.
+    token_vocab: text_encoder.TextEncoder object.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from tokens in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      while source and target:
+        source_ints = token_vocab.encode(source.strip()) + eos_list
+        target_ints = token_vocab.encode(target.strip()) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+        source, target = source_file.readline(), target_file.readline()
+
+
+def bi_vocabs_token_generator(source_path,
+                              target_path,
+                              source_token_vocab,
+                              target_token_vocab,
+                              eos=None):
+  """Generator for sequence-to-sequence tasks that uses tokens.
+
+  This generator assumes the files at source_path and target_path have
+  the same number of lines and yields dictionaries of "inputs" and "targets"
+  where inputs are token ids from the " "-split source (and target, resp.) lines
+  converted to integers using the token_map.
+
+  Args:
+    source_path: path to the file with source sentences.
+    target_path: path to the file with target sentences.
+    source_token_vocab: text_encoder.TextEncoder object.
+    target_token_vocab: text_encoder.TextEncoder object.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from tokens in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      while source and target:
+        source_ints = source_token_vocab.encode(source.strip()) + eos_list
+        target_ints = target_token_vocab.encode(target.strip()) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+        source, target = source_file.readline(), target_file.readline()
+
+def _preprocess_sgm(line, is_sgm):
+  """Preprocessing to strip tags in SGM files."""
+  if not is_sgm:
+    return line
+  # In SGM files, remove <srcset ...>, <p>, <doc ...> lines.
+  if line.startswith("<srcset") or line.startswith("</srcset"):
+    return ""
+  if line.startswith("<doc") or line.startswith("</doc"):
+    return ""
+  if line.startswith("<p>") or line.startswith("</p>"):
+    return ""
+  # Strip <seg> tags.
+  line = line.strip()
+  if line.startswith("<seg") and line.endswith("</seg>"):
+    i = line.index(">")
+    return line[i + 1:-6]  # Strip first <seg ...> and last </seg>.
+
+def _compile_data(tmp_dir, datasets, filename):
+  """Concatenate all `datasets` and save to `filename`."""
+  filename = os.path.join(tmp_dir, filename)
+  with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile:
+    with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_resfile:
+      for dataset in datasets:
+        url = dataset[0]
+        compressed_filename = os.path.basename(url)
+        compressed_filepath = os.path.join(tmp_dir, compressed_filename)
+
+        generator_utils.maybe_download(tmp_dir, compressed_filename, url)
+
+        if dataset[1][0] == "tsv":
+          _, src_column, trg_column, glob_pattern = dataset[1]
+          filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
+          if not filenames:
+            # Capture *.tgz and *.tar.gz too.
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+            filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
+          for tsv_filename in filenames:
+            if tsv_filename.endswith(".gz"):
+              new_filename = tsv_filename.strip(".gz")
+              generator_utils.gunzip_file(tsv_filename, new_filename)
+              tsv_filename = new_filename
+            with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
+              for line in tsv_file:
+                if line and "\t" in line:
+                  parts = line.split("\t")
+                  source, target = parts[src_column], parts[trg_column]
+                  lang1_resfile.write(source.strip() + "\n")
+                  lang2_resfile.write(target.strip() + "\n")
+        else:
+          lang1_filename, lang2_filename = dataset[1]
+          lang1_filepath = os.path.join(tmp_dir, lang1_filename)
+          lang2_filepath = os.path.join(tmp_dir, lang2_filename)
+          is_sgm = (lang1_filename.endswith("sgm") and
+                    lang2_filename.endswith("sgm"))
+
+          if not (os.path.exists(lang1_filepath) and
+                  os.path.exists(lang2_filepath)):
+            # For .tar.gz and .tgz files, we read compressed.
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+          if lang1_filepath.endswith(".gz"):
+            new_filepath = lang1_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang1_filepath, new_filepath)
+            lang1_filepath = new_filepath
+          if lang2_filepath.endswith(".gz"):
+            new_filepath = lang2_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang2_filepath, new_filepath)
+            lang2_filepath = new_filepath
+          with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
+            with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
+              line1, line2 = lang1_file.readline(), lang2_file.readline()
+              while line1 or line2:
+                line1res = _preprocess_sgm(line1, is_sgm)
+                line2res = _preprocess_sgm(line2, is_sgm)
+                if line1res or line2res:
+                  lang1_resfile.write(line1res.strip() + "\n")
+                  lang2_resfile.write(line2res.strip() + "\n")
+                line1, line2 = lang1_file.readline(), lang2_file.readline()
+
+  return filename
+
+

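Taken together, the helpers in the new module are meant to be chained inside a problem's data generator: _compile_data downloads and flattens a dataset spec into two aligned text files, get_or_generate_vocab builds a subword vocabulary from the same sources, and token_generator streams encoded {"inputs", "targets"} dictionaries. A hedged sketch of that flow, reusing a corpus URL that appears in this commit; the vocabulary filename, size, and wrapper function are illustrative and not part of the change:

    from tensor2tensor.data_generators import generator_utils
    from tensor2tensor.data_generators import text_encoder
    from tensor2tensor.data_generators import translate

    # One dataset spec in the [url, [lang1_file, lang2_file]] form expected
    # by _compile_data and get_or_generate_vocab.
    _EXAMPLE_DATASETS = [[
        "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz",
        ["training-parallel-nc-v11/news-commentary-v11.de-en.en",
         "training-parallel-nc-v11/news-commentary-v11.de-en.de"]
    ]]

    def example_generator(data_dir, tmp_dir):
      # Build (or reload) a shared subword vocabulary from the raw corpora.
      vocab = generator_utils.get_or_generate_vocab(
          data_dir, tmp_dir, "vocab.example.32768", 2**15, _EXAMPLE_DATASETS)
      # Download, extract, and concatenate the corpora into two aligned files.
      data_path = translate._compile_data(tmp_dir, _EXAMPLE_DATASETS,
                                          "example_train")
      # Stream encoded {"inputs", "targets"} dictionaries, EOS-terminated.
      return translate.token_generator(data_path + ".lang1",
                                       data_path + ".lang2",
                                       vocab, text_encoder.EOS_ID)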