Commit 3c5823f
Authored Oct 16, 2017
Merge pull request #361 from rsepassi/push
v1.2.5
2 parents: 3a9c950 + fa9ad63
30 files changed: 1992 additions, 541 deletions

‎.travis.yml
1 addition, 1 deletion

@@ -24,6 +24,6 @@ script:
   - mkdir $T2T_TRAIN_DIR
   - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
   - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
-  - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR
+  - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10'
 git:
   depth: 3

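For reference, --decode_hparams takes a comma-separated key=value string that overrides decoding defaults. Below is a minimal, illustrative sketch of parsing such a string with TF 1.x's tf.contrib.training.HParams; the default values are made up for the example and are not necessarily t2t's actual decode defaults.

# Illustrative only: parse a decode_hparams-style override string (TF 1.x API).
import tensorflow as tf

decode_hp = tf.contrib.training.HParams(num_samples=-1, beam_size=4)  # assumed defaults
decode_hp.parse("num_samples=10")  # same format as the flag value in the CI command above
assert decode_hp.num_samples == 10
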
‎setup.py
1 addition, 1 deletion

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.2.4',
+    version='1.2.5',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

‎tensor2tensor/data_generators/cnn_dailymail.py
1 addition, 1 deletion

@@ -74,7 +74,7 @@ def story_generator(tmp_dir):
   for path in paths:
     for story_file in tf.gfile.Glob(path + "*"):
       story = u""
-      for line in tf.gfile.Open(story_file, 'rb'):
+      for line in tf.gfile.Open(story_file, "rb"):
         line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
         story += line
       yield story

‎tensor2tensor/data_generators/generator_utils.py
3 additions, 1 deletion

@@ -355,6 +355,8 @@ def generate():
       for lang_file in source[1]:
         tf.logging.info("Reading file: %s" % lang_file)
         filepath = os.path.join(tmp_dir, lang_file)
+
+        # Extract from tar if needed.
         if not tf.gfile.Exists(filepath):
           read_type = "r:gz" if filename.endswith("tgz") else "r"
           with tarfile.open(compressed_file, read_type) as corpus_tar:

@@ -411,7 +413,7 @@ def generate():
         for line in source_file:
           line = line.strip()
           if line and "\t" in line:
-            parts = line.split("\t", maxsplit=1)
+            parts = line.split("\t", 1)
             part = parts[index].strip()
             yield part
 

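A note on the split() change above: in Python 2, str.split accepts only positional arguments, so the keyword form maxsplit=1 raises a TypeError there, while the positional form behaves identically under both Python 2 and 3. For example:

# Split on the first tab only; works on both Python 2 and 3.
line = "source sentence\ttarget sentence\twith an extra tab"
parts = line.split("\t", 1)
# parts == ["source sentence", "target sentence\twith an extra tab"]
# line.split("\t", maxsplit=1) would raise TypeError under Python 2.
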
‎tensor2tensor/data_generators/image.py
47 additions, 12 deletions

@@ -42,6 +42,12 @@
 import tensorflow as tf
 
 
+def resize_by_area(img, size):
+  """image resize function used by quite a few image problems."""
+  return tf.to_int64(
+      tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
+
+
 class ImageProblem(problem.Problem):
 
   def example_reading_spec(self, label_key=None):

@@ -93,16 +99,12 @@ class ImageCeleba(ImageProblem):
 
   def preprocess_example(self, example, unused_mode, unused_hparams):
 
-    def resize(img, size):
-      return tf.to_int64(
-          tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
-
     inputs = example["inputs"]
     # Remove boundaries in CelebA images. Remove 40 pixels each side
     # vertically and 20 pixels each side horizontally.
     inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
-    example["inputs"] = resize(inputs, 8)
-    example["targets"] = resize(inputs, 32)
+    example["inputs"] = resize_by_area(inputs, 8)
+    example["targets"] = resize_by_area(inputs, 32)
     return example
 
   def hparams(self, defaults, unused_model_hparams):

@@ -388,14 +390,10 @@ def dataset_filename(self):
 
   def preprocess_example(self, example, unused_mode, unused_hparams):
 
-    def resize(img, size):
-      return tf.to_int64(
-          tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
-
     inputs = example["inputs"]
     # For Img2Img resize input and output images as desired.
-    example["inputs"] = resize(inputs, 8)
-    example["targets"] = resize(inputs, 32)
+    example["inputs"] = resize_by_area(inputs, 8)
+    example["targets"] = resize_by_area(inputs, 32)
     return example
 
   def hparams(self, defaults, unused_model_hparams):

@@ -654,6 +652,43 @@ def preprocess_example(self, example, mode, unused_hparams):
     return example
 
 
+@registry.register_problem
+class ImageCifar10Plain8(ImageCifar10):
+  """CIFAR-10 rescaled to 8x8 for output: Conditional image generation."""
+
+  def dataset_filename(self):
+    return "image_cifar10_plain"  # Reuse CIFAR-10 plain data.
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"] = resize_by_area(example["inputs"], 8)
+    return example
+
+
+@registry.register_problem
+class Img2imgCifar10(ImageCifar10):
+  """CIFAR-10 rescaled to 8x8 for input and 32x32 for output."""
+
+  def dataset_filename(self):
+    return "image_cifar10_plain"  # Reuse CIFAR-10 plain data.
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+
+    inputs = example["inputs"]
+    # For Img2Img resize input and output images as desired.
+    example["inputs"] = resize_by_area(inputs, 8)
+    example["targets"] = resize_by_area(inputs, 32)
+    return example
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": ("image:identity_no_pad", None)}
+    p.target_modality = ("image:identity_no_pad", None)
+    p.batch_size_multiplier = 256
+    p.max_expected_batch_size_per_shard = 4
+    p.input_space_id = 1
+    p.target_space_id = 1
+
+
 # URLs and filenames for MSCOCO data.
 _MSCOCO_ROOT_URL = "http://msvocds.blob.core.windows.net/"
 _MSCOCO_URLS = [

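Both new problems above are decorated with registry.register_problem, so once the module is imported they should be retrievable by name. A minimal sketch, assuming t2t's usual CamelCase-to-snake_case naming for registered problems:

# Sketch only: the snake_case problem names are assumed from the registry's
# usual class-name conversion; importing the module runs the decorators.
from tensor2tensor.data_generators import image  # noqa: F401 (registers the problems)
from tensor2tensor.utils import registry

plain8_problem = registry.problem("image_cifar10_plain8")
img2img_problem = registry.problem("img2img_cifar10")

Note that both classes point dataset_filename at "image_cifar10_plain", so they reuse the plain CIFAR-10 data files rather than requiring a new datagen run.
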
‎tensor2tensor/data_generators/wmt.py
18 additions, 20 deletions

@@ -19,9 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
-import glob
 import os
-import stat
 import tarfile
 
 # Dependency imports

@@ -115,7 +113,7 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
   with tf.gfile.GFile(source_path, mode="r") as source_file:
     for line in source_file:
       if line and "\t" in line:
-        parts = line.split("\t", maxsplit=1)
+        parts = line.split("\t", 1)
         source, target = parts[0].strip(), parts[1].strip()
         source_ints = source_vocab.encode(source) + eos_list
         target_ints = target_vocab.encode(target) + eos_list

@@ -267,8 +265,9 @@ def bi_vocabs_token_generator(source_path,
 # English-Czech datasets
 _ENCS_TRAIN_DATASETS = [
     [
-        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
-        ('tsv', 3, 2, 'data.plaintext-format/*train.gz')
+        ("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
+         "11234/1-1458/data-plaintext-format.tar"),
+        ("tsv", 3, 2, "data.plaintext-format/*train.gz")
     ],
     [
         "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",  # pylint: disable=line-too-long

@@ -375,25 +374,22 @@ def _compile_data(tmp_dir, datasets, filename):
     url = dataset[0]
     compressed_filename = os.path.basename(url)
     compressed_filepath = os.path.join(tmp_dir, compressed_filename)
+
     generator_utils.maybe_download(tmp_dir, compressed_filename, url)
 
-    if dataset[1][0] == 'tsv':
+    if dataset[1][0] == "tsv":
       _, src_column, trg_column, glob_pattern = dataset[1]
-      filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+      filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
       if not filenames:
-        mode = "r:gz" if compressed_filepath.endswith("gz") else "r"  # *.tgz *.tar.gz
+        # Capture *.tgz and *.tar.gz too.
+        mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
         with tarfile.open(compressed_filepath, mode) as corpus_tar:
           corpus_tar.extractall(tmp_dir)
-        filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+        filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
       for tsv_filename in filenames:
         if tsv_filename.endswith(".gz"):
           new_filename = tsv_filename.strip(".gz")
-          try:
-            generator_utils.gunzip_file(tsv_filename, new_filename)
-          except PermissionError:
-            tsvdir = os.path.dirname(tsv_filename)
-            os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
-            generator_utils.gunzip_file(tsv_filename, new_filename)
+          generator_utils.gunzip_file(tsv_filename, new_filename)
           tsv_filename = new_filename
         with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
           for line in tsv_file:

@@ -663,17 +659,19 @@ def vocab_name(self):
   def generator(self, data_dir, tmp_dir, train):
     datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS
     tag = "train" if train else "dev"
-    data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
     vocab_datasets = []
+    data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
     # CzEng contains 100 gz files with tab-separated columns, so let's expect
-    # it is the first dataset in datasets and use the newly created *.lang{1,2} files instead.
+    # it is the first dataset in datasets and use the newly created *.lang{1,2}
+    # files for vocab construction.
     if datasets[0][0].endswith("data-plaintext-format.tar"):
-      vocab_datasets.append([datasets[0][0],
-          ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]])
+      vocab_datasets.append([datasets[0][0], ["wmt_encs_tok_%s.lang1" % tag,
+                                              "wmt_encs_tok_%s.lang2" % tag]])
       datasets = datasets[1:]
     vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets]
     symbolizer_vocab = generator_utils.get_or_generate_vocab(
-        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets)
+        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
+        vocab_datasets)
     return token_generator(data_path + ".lang1", data_path + ".lang2",
                            symbolizer_vocab, EOS)
 

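Two observations on the _compile_data changes: the dropped try/except referenced PermissionError, a built-in that exists only in Python 3, so under Python 2 it would have raised NameError instead of handling the error, which may be part of why it (and the stat import) was removed; and tf.gfile.Glob resolves the same kind of wildcard pattern as glob.glob while also working on filesystems TensorFlow supports, such as gs:// paths. A small, illustrative sketch of the new Glob call in isolation (the tmp_dir value is assumed for the example):

# Illustrative: tf.gfile.Glob handles the wildcard pattern used here,
# and also works on TensorFlow-supported filesystems (e.g. gs:// paths).
import os
import tensorflow as tf

tmp_dir = "/tmp/t2t_datagen"  # assumed path, for illustration only
pattern = os.path.join(tmp_dir, "data.plaintext-format/*train.gz")
filenames = tf.gfile.Glob(pattern)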