|
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data generators for translation data-sets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile

# Dependency imports

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.utils import registry

import tensorflow as tf

FLAGS = tf.flags.FLAGS

class TranslateProblem(problem.Text2TextProblem):
  """Base class for translation problems."""

  @property
  def is_character_level(self):
    return False

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return True

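# A minimal sketch of how a concrete translation problem might build on the
# base class above by overriding the defaults it defines. The class name and
# the overridden values are hypothetical illustrations, not part of this
# module; real problems also supply a data generator and vocabulary settings.
#
#   class TranslateExampleCharacters(TranslateProblem):
#
#     @property
#     def is_character_level(self):
#       return True
#
#     @property
#     def use_subword_tokenizer(self):
#       return False
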

# Generic generators used later for multiple problems.


def character_generator(source_path, target_path, character_vocab, eos=None):
  """Generator for sequence-to-sequence tasks that just uses characters.

  This generator assumes the files at source_path and target_path have
  the same number of lines and yields dictionaries of "inputs" and "targets"
  where inputs are characters from the source lines converted to integers,
  and targets are characters from the target lines, also converted to integers.

  Args:
    source_path: path to the file with source sentences.
    target_path: path to the file with target sentences.
    character_vocab: a TextEncoder to encode the characters.
    eos: integer to append at the end of each sequence (default: None).

  Yields:
    A dictionary {"inputs": source-line, "targets": target-line} where
    the lines are integer lists converted from characters in the file lines.
  """
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        source_ints = character_vocab.encode(source.strip()) + eos_list
        target_ints = character_vocab.encode(target.strip()) + eos_list
        yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()

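# Illustrative usage of character_generator. This is a sketch only: the file
# names below are placeholders, and ByteTextEncoder / EOS_ID are assumed to
# come from text_encoder rather than being defined in this module.
#
#   char_vocab = text_encoder.ByteTextEncoder()
#   examples = character_generator("train.en", "train.de", char_vocab,
#                                  eos=text_encoder.EOS_ID)
#   for example in examples:
#     # example is {"inputs": [...character ids...],
#     #             "targets": [...character ids...]}
#     pass
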
def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
  r"""Generator for sequence-to-sequence tasks using tabbed files.

  Tokens are derived from text files where each line contains both
  a source and a target string. The two strings are separated by a tab
  character ('\t'). It yields dictionaries of "inputs" and "targets" where
  inputs are the source strings encoded to integers with source_vocab, and
  targets are the target strings encoded to integers with target_vocab.

  Args:
    source_path: path to the file with source and target sentences.
    source_vocab: a SubwordTextEncoder to encode the source string.
    target_vocab: a SubwordTextEncoder to encode the target string.
    eos: integer to append at the end of each sequence (default: None).

  Yields:
    A dictionary {"inputs": source-line, "targets": target-line} where
    the lines are integer lists produced by the vocab encoders.
  """
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    for line in source_file:
      if line and "\t" in line:
        parts = line.split("\t", 1)
        source, target = parts[0].strip(), parts[1].strip()
        source_ints = source_vocab.encode(source) + eos_list
        target_ints = target_vocab.encode(target) + eos_list
        yield {"inputs": source_ints, "targets": target_ints}

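# Illustrative usage of tabbed_generator. A sketch only: "corpus.tsv" and the
# vocab file paths are placeholders. Each line of the input file is expected
# to look like "source sentence<TAB>target sentence".
#
#   source_vocab = text_encoder.SubwordTextEncoder("vocab.src")  # assumed path
#   target_vocab = text_encoder.SubwordTextEncoder("vocab.trg")  # assumed path
#   for example in tabbed_generator("corpus.tsv", source_vocab, target_vocab,
#                                   eos=text_encoder.EOS_ID):
#     pass  # {"inputs": ..., "targets": ...} as above.
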
def token_generator(source_path, target_path, token_vocab, eos=None):
  """Generator for sequence-to-sequence tasks that uses tokens.

  This generator assumes the files at source_path and target_path have
  the same number of lines and yields dictionaries of "inputs" and "targets"
  where inputs are token ids from the " "-split source (and target, resp.)
  lines converted to integers using token_vocab.

  Args:
    source_path: path to the file with source sentences.
    target_path: path to the file with target sentences.
    token_vocab: text_encoder.TextEncoder object.
    eos: integer to append at the end of each sequence (default: None).

  Yields:
    A dictionary {"inputs": source-line, "targets": target-line} where
    the lines are integer lists converted from tokens in the file lines.
  """
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        source_ints = token_vocab.encode(source.strip()) + eos_list
        target_ints = token_vocab.encode(target.strip()) + eos_list
        yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()

def bi_vocabs_token_generator(source_path,
                              target_path,
                              source_token_vocab,
                              target_token_vocab,
                              eos=None):
  """Generator for sequence-to-sequence tasks that uses two token vocabs.

  This generator assumes the files at source_path and target_path have
  the same number of lines and yields dictionaries of "inputs" and "targets"
  where inputs are token ids from the " "-split source (and target, resp.)
  lines converted to integers using source_token_vocab and target_token_vocab.

  Args:
    source_path: path to the file with source sentences.
    target_path: path to the file with target sentences.
    source_token_vocab: text_encoder.TextEncoder object.
    target_token_vocab: text_encoder.TextEncoder object.
    eos: integer to append at the end of each sequence (default: None).

  Yields:
    A dictionary {"inputs": source-line, "targets": target-line} where
    the lines are integer lists converted from tokens in the file lines.
  """
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        source_ints = source_token_vocab.encode(source.strip()) + eos_list
        target_ints = target_token_vocab.encode(target.strip()) + eos_list
        yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()

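# Illustrative usage of bi_vocabs_token_generator, for the case where source
# and target languages keep separate vocabularies. A sketch only: the file and
# vocab paths below are placeholders.
#
#   src_vocab = text_encoder.SubwordTextEncoder("vocab.en")  # assumed path
#   trg_vocab = text_encoder.SubwordTextEncoder("vocab.zh")  # assumed path
#   for example in bi_vocabs_token_generator("train.en", "train.zh",
#                                            src_vocab, trg_vocab,
#                                            eos=text_encoder.EOS_ID):
#     pass
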
def _preprocess_sgm(line, is_sgm):
  """Preprocessing to strip tags in SGM files."""
  if not is_sgm:
    return line
  # In SGM files, remove <srcset ...>, <p>, <doc ...> lines.
  if line.startswith("<srcset") or line.startswith("</srcset"):
    return ""
  if line.startswith("<doc") or line.startswith("</doc"):
    return ""
  if line.startswith("<p>") or line.startswith("</p>"):
    return ""
  # Strip <seg> tags.
  line = line.strip()
  if line.startswith("<seg") and line.endswith("</seg>"):
    i = line.index(">")
    return line[i + 1:-6]  # Strip first <seg ...> and last </seg>.
  # Fall back to the stripped line so callers never receive None.
  return line

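# Illustrative behavior of _preprocess_sgm (the inputs below are made-up
# examples, not lines from a real corpus):
#
#   _preprocess_sgm("plain text", is_sgm=False)               -> "plain text"
#   _preprocess_sgm("<doc docid=1>", is_sgm=True)             -> ""
#   _preprocess_sgm('<seg id="1">Hello.</seg>', is_sgm=True)  -> "Hello."
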
def _compile_data(tmp_dir, datasets, filename):
  """Concatenate all `datasets` and save to `filename`."""
  filename = os.path.join(tmp_dir, filename)
  with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile:
    with tf.gfile.GFile(filename + ".lang2", mode="w") as lang2_resfile:
      for dataset in datasets:
        url = dataset[0]
        compressed_filename = os.path.basename(url)
        compressed_filepath = os.path.join(tmp_dir, compressed_filename)

        generator_utils.maybe_download(tmp_dir, compressed_filename, url)

        if dataset[1][0] == "tsv":
          _, src_column, trg_column, glob_pattern = dataset[1]
          filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
          if not filenames:
            # Capture *.tgz and *.tar.gz too.
            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
            with tarfile.open(compressed_filepath, mode) as corpus_tar:
              corpus_tar.extractall(tmp_dir)
            filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
          for tsv_filename in filenames:
            if tsv_filename.endswith(".gz"):
              # Drop the ".gz" suffix; str.strip(".gz") would also remove
              # leading/trailing 'g', 'z' and '.' characters from the name.
              new_filename = tsv_filename[:-3]
              generator_utils.gunzip_file(tsv_filename, new_filename)
              tsv_filename = new_filename
            with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
              for line in tsv_file:
                if line and "\t" in line:
                  parts = line.split("\t")
                  source, target = parts[src_column], parts[trg_column]
                  lang1_resfile.write(source.strip() + "\n")
                  lang2_resfile.write(target.strip() + "\n")
        else:
          lang1_filename, lang2_filename = dataset[1]
          lang1_filepath = os.path.join(tmp_dir, lang1_filename)
          lang2_filepath = os.path.join(tmp_dir, lang2_filename)
          is_sgm = (lang1_filename.endswith("sgm") and
                    lang2_filename.endswith("sgm"))

          if not (os.path.exists(lang1_filepath) and
                  os.path.exists(lang2_filepath)):
            # For .tar.gz and .tgz files, we read compressed.
            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
            with tarfile.open(compressed_filepath, mode) as corpus_tar:
              corpus_tar.extractall(tmp_dir)
          if lang1_filepath.endswith(".gz"):
            new_filepath = lang1_filepath[:-3]  # Drop the ".gz" suffix.
            generator_utils.gunzip_file(lang1_filepath, new_filepath)
            lang1_filepath = new_filepath
          if lang2_filepath.endswith(".gz"):
            new_filepath = lang2_filepath[:-3]  # Drop the ".gz" suffix.
            generator_utils.gunzip_file(lang2_filepath, new_filepath)
            lang2_filepath = new_filepath
          with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
            with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
              line1, line2 = lang1_file.readline(), lang2_file.readline()
              while line1 or line2:
                line1res = _preprocess_sgm(line1, is_sgm)
                line2res = _preprocess_sgm(line2, is_sgm)
                if line1res or line2res:
                  lang1_resfile.write(line1res.strip() + "\n")
                  lang2_resfile.write(line2res.strip() + "\n")
                line1, line2 = lang1_file.readline(), lang2_file.readline()

  return filename

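# A note on the `datasets` structure consumed by _compile_data, as inferred
# from the code above: each entry is a pair [url, file_spec], where file_spec
# is either ("tsv", source_column, target_column, glob_pattern) for
# tab-separated data, or (lang1_filename, lang2_filename) for parallel text
# files inside the archive downloaded from `url`. A hypothetical example (the
# URL and file names are placeholders, not a real dataset):
#
#   _EXAMPLE_DATASETS = [[
#       "http://example.com/training-parallel.tgz",
#       ("training/corpus.en", "training/corpus.de"),
#   ]]
#   # _compile_data(tmp_dir, _EXAMPLE_DATASETS, "translate_example") would
#   # then write "translate_example.lang1" and "translate_example.lang2"
#   # in tmp_dir.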