 from __future__ import division
 from __future__ import print_function

-import glob
 import os
-import stat
 import tarfile

 # Dependency imports
@@ -115,7 +113,7 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
   with tf.gfile.GFile(source_path, mode="r") as source_file:
     for line in source_file:
       if line and "\t" in line:
-        parts = line.split("\t", maxsplit=1)
+        parts = line.split("\t", 1)
         source, target = parts[0].strip(), parts[1].strip()
         source_ints = source_vocab.encode(source) + eos_list
         target_ints = target_vocab.encode(target) + eos_list
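Aside (not part of the patch): Python 2's `str.split` does not accept the `maxsplit` keyword, so passing the limit positionally keeps `tabbed_generator` working under both Python 2 and 3. A minimal sketch with a made-up tab-separated line:

```python
# Minimal sketch, assuming an arbitrary sample line (not from the dataset).
line = "hello world\tahoj svete\textra\tcolumns"
parts = line.split("\t", 1)  # positional maxsplit: valid on Python 2 and 3
source, target = parts[0].strip(), parts[1].strip()
assert source == "hello world"
assert target == "ahoj svete\textra\tcolumns"
```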
@@ -267,8 +265,9 @@ def bi_vocabs_token_generator(source_path,
 # English-Czech datasets
 _ENCS_TRAIN_DATASETS = [
     [
-        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
-        ('tsv', 3, 2, 'data.plaintext-format/*train.gz')
+        ("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
+         "11234/1-1458/data-plaintext-format.tar"),
+        ("tsv", 3, 2, "data.plaintext-format/*train.gz")
     ],
     [
         "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",  # pylint: disable=line-too-long
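Aside (illustrative only): wrapping the long CzEng URL in parentheses relies on Python's compile-time concatenation of adjacent string literals, so the value is unchanged while each source line stays under the length limit:

```python
# The two literals below are joined at compile time into a single URL string.
url = ("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
       "11234/1-1458/data-plaintext-format.tar")
assert url.endswith("11234/1-1458/data-plaintext-format.tar")
```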
@@ -375,25 +374,22 @@ def _compile_data(tmp_dir, datasets, filename):
         url = dataset[0]
         compressed_filename = os.path.basename(url)
         compressed_filepath = os.path.join(tmp_dir, compressed_filename)
+
         generator_utils.maybe_download(tmp_dir, compressed_filename, url)

-        if dataset[1][0] == 'tsv':
+        if dataset[1][0] == "tsv":
           _, src_column, trg_column, glob_pattern = dataset[1]
-          filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+          filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
           if not filenames:
-            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"  # *.tgz *.tar.gz
+            # Capture *.tgz and *.tar.gz too.
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
             with tarfile.open(compressed_filepath, mode) as corpus_tar:
               corpus_tar.extractall(tmp_dir)
-            filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+            filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
           for tsv_filename in filenames:
             if tsv_filename.endswith(".gz"):
               new_filename = tsv_filename.strip(".gz")
-              try:
-                generator_utils.gunzip_file(tsv_filename, new_filename)
-              except PermissionError:
-                tsvdir = os.path.dirname(tsv_filename)
-                os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
-                generator_utils.gunzip_file(tsv_filename, new_filename)
+              generator_utils.gunzip_file(tsv_filename, new_filename)
               tsv_filename = new_filename
             with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
               for line in tsv_file:
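Aside: a condensed sketch of the "tsv" branch above, showing the glob-extract-reglob pattern with `tf.gfile.Glob` in place of `glob.glob`. The helper name `_find_tsv_files` is hypothetical and not part of the patch; it assumes TensorFlow 1.x's `tf.gfile` API.

```python
import os
import tarfile

import tensorflow as tf


def _find_tsv_files(tmp_dir, compressed_filepath, glob_pattern):
  """Hypothetical helper mirroring the "tsv" branch of _compile_data."""
  filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
  if not filenames:
    # "r:gz" covers *.tgz and *.tar.gz archives; plain *.tar falls back to "r".
    mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
    with tarfile.open(compressed_filepath, mode) as corpus_tar:
      corpus_tar.extractall(tmp_dir)
    filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
  return filenames
```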
@@ -663,17 +659,19 @@ def vocab_name(self):
   def generator(self, data_dir, tmp_dir, train):
     datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS
     tag = "train" if train else "dev"
-    data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
     vocab_datasets = []
+    data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
     # CzEng contains 100 gz files with tab-separated columns, so let's expect
-    # it is the first dataset in datasets and use the newly created *.lang{1,2} files instead.
+    # it is the first dataset in datasets and use the newly created *.lang{1,2}
+    # files for vocab construction.
     if datasets[0][0].endswith("data-plaintext-format.tar"):
-      vocab_datasets.append([datasets[0][0],
-                             ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]])
+      vocab_datasets.append([datasets[0][0], ["wmt_encs_tok_%s.lang1" % tag,
+                                              "wmt_encs_tok_%s.lang2" % tag]])
       datasets = datasets[1:]
     vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets]
     symbolizer_vocab = generator_utils.get_or_generate_vocab(
-        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets)
+        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
+        vocab_datasets)
     return token_generator(data_path + ".lang1", data_path + ".lang2",
                            symbolizer_vocab, EOS)

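Aside (illustrative only): the resulting `vocab_datasets` list pairs each download URL with the plain-text files the vocabulary is built from; when CzEng is the first dataset, the first entry points at the freshly compiled `*.lang{1,2}` files:

```python
# Illustrative structure only (values shown for tag = "train").
tag = "train"
vocab_datasets = [
    [("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
      "11234/1-1458/data-plaintext-format.tar"),
     ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]],
    # ...each remaining dataset contributes an [url, [file1, file2]] entry.
]
```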