This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 120315c

Merge pull request #708 from rsepassi/push
v1.5.7
2 parents c4ca5a4 + 95aeb11 commit 120315c


54 files changed, +2401 -718 lines

setup.py

+1 -1
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.5.6',
+    version='1.5.7',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/bin/t2t_datagen.py

+1 -1
@@ -38,8 +38,8 @@

 import numpy as np

+from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
 from tensor2tensor.data_generators import algorithmic_math
-from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import snli

tensor2tensor/bin/t2t_decoder.py

+2 -1
@@ -82,7 +82,8 @@ def create_decode_hparams():

 def decode(estimator, hparams, decode_hp):
   if FLAGS.decode_interactive:
-    decoding.decode_interactively(estimator, hparams, decode_hp, checkpoint_path=FLAGS.checkpoint_path)
+    decoding.decode_interactively(estimator, hparams, decode_hp,
+                                  checkpoint_path=FLAGS.checkpoint_path)
   elif FLAGS.decode_from_file:
     decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                               decode_hp, FLAGS.decode_to_file,

tensor2tensor/bin/t2t_distill.py

+95
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Perform distillation for a teacher to student.
+
+This script is intended to be used with --model=distillation. See the model for
+example hyperparameters and usage.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Dependency imports
+
+from tensor2tensor import models  # pylint: disable=unused-import
+from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
+from tensor2tensor.bin import t2t_trainer
+from tensor2tensor.utils import cloud_mlengine
+from tensor2tensor.utils import flags as t2t_flags  # pylint: disable=unused-import
+from tensor2tensor.utils import trainer_lib
+from tensor2tensor.utils import usr_dir
+
+import tensorflow as tf
+
+flags = tf.flags
+FLAGS = flags.FLAGS
+
+
+def main(argv):
+  tf.logging.set_verbosity(tf.logging.INFO)
+  trainer_lib.set_random_seed(FLAGS.random_seed)
+  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
+  t2t_trainer.log_registry()
+
+  if FLAGS.cloud_mlengine:
+    return cloud_mlengine.launch()
+
+  if FLAGS.generate_data:
+    t2t_trainer.generate_data()
+
+  if cloud_mlengine.job_dir():
+    FLAGS.output_dir = cloud_mlengine.job_dir()
+
+  if argv:
+    t2t_trainer.set_hparams_from_args(argv[1:])
+
+  with t2t_trainer.maybe_cloud_tpu():
+    root_output_dir = FLAGS.output_dir
+
+    # Train Teacher ============
+    hparams = t2t_trainer.create_hparams()
+    hparams.distill_phase = "train"
+    teacher_dir = os.path.join(root_output_dir, "teacher")
+    FLAGS.output_dir = teacher_dir
+
+    exp_fn = t2t_trainer.create_experiment_fn()
+    run_config = t2t_trainer.create_run_config(hparams)
+    exp = exp_fn(run_config, hparams)
+    if t2t_trainer.is_chief():
+      t2t_trainer.save_metadata(hparams)
+    t2t_trainer.execute_schedule(exp)
+    # ==========================
+    # Train Student ============
+    hparams = t2t_trainer.create_hparams()
+    hparams.add_hparam("teacher_dir", teacher_dir)
+    hparams.distill_phase = "distill"
+    student_dir = os.path.join(root_output_dir, "student")
+    FLAGS.output_dir = student_dir
+
+    exp_fn = t2t_trainer.create_experiment_fn()
+    run_config = t2t_trainer.create_run_config(hparams)
+    exp = exp_fn(run_config, hparams)
+
+    if t2t_trainer.is_chief():
+      t2t_trainer.save_metadata(hparams)
+    t2t_trainer.execute_schedule(exp)
+    # ==========================
+
+
+if __name__ == "__main__":
+  tf.app.run()
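
For orientation, the new script boils down to running two experiments back to back under one output directory: a teacher phase, then a student phase that reads the teacher's checkpoints. The sketch below restates that flow with a hypothetical run_phase helper; it reuses only the t2t_trainer calls already imported in the file above and is not part of the commit.

def run_phase(phase, output_dir, extra_hparams=None):
  """Hypothetical helper mirroring the teacher/student phases above."""
  hparams = t2t_trainer.create_hparams()
  hparams.distill_phase = phase  # "train" for the teacher, "distill" for the student
  for name, value in (extra_hparams or {}).items():
    hparams.add_hparam(name, value)
  FLAGS.output_dir = output_dir
  run_config = t2t_trainer.create_run_config(hparams)
  exp = t2t_trainer.create_experiment_fn()(run_config, hparams)
  if t2t_trainer.is_chief():
    t2t_trainer.save_metadata(hparams)
  t2t_trainer.execute_schedule(exp)

# Teacher first, then the student pointed at the teacher's directory.
root = FLAGS.output_dir
run_phase("train", os.path.join(root, "teacher"))
run_phase("distill", os.path.join(root, "student"),
          {"teacher_dir": os.path.join(root, "teacher")})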

tensor2tensor/data_generators/algorithmic_math.py

+2 -2
@@ -355,8 +355,8 @@ def generate_calculus_integrate_sample(vlist, ops, min_depth, max_depth,
 # functions: Dict of special function names. Maps human readable string names to
 #     single char names used in flist.
 # ops: Dict mapping op symbols (chars) to ExprOp instances.
-# solve_ops: Encodes rules for how to algebraically cancel out each operation. See
-#     doc-string for `algebra_inverse_solve`.
+# solve_ops: Encodes rules for how to algebraically cancel out each operation.
+#     See doc-string for `algebra_inverse_solve`.
 # int_encoder: Function that maps a string to a list of tokens. Use this to
 #     encode an expression to feed into a model.
 # int_decoder: Function that maps a list of tokens to a string. Use this to

tensor2tensor/data_generators/all_problems.py

+43 -42
@@ -18,47 +18,48 @@
 from __future__ import division
 from __future__ import print_function

-# pylint: disable=unused-import
-from tensor2tensor.data_generators import algorithmic
-from tensor2tensor.data_generators import algorithmic_math
-from tensor2tensor.data_generators import audio
-from tensor2tensor.data_generators import celeba
-from tensor2tensor.data_generators import cifar
-from tensor2tensor.data_generators import cipher
-from tensor2tensor.data_generators import cnn_dailymail
-from tensor2tensor.data_generators import desc2code
-from tensor2tensor.data_generators import fsns
-from tensor2tensor.data_generators import gym
-from tensor2tensor.data_generators import ice_parsing
-from tensor2tensor.data_generators import imagenet
-from tensor2tensor.data_generators import imdb
-from tensor2tensor.data_generators import librispeech
-from tensor2tensor.data_generators import lm1b
-from tensor2tensor.data_generators import mnist
-from tensor2tensor.data_generators import mscoco
-from tensor2tensor.data_generators import multinli
-from tensor2tensor.data_generators import ocr
-from tensor2tensor.data_generators import problem_hparams
-from tensor2tensor.data_generators import ptb
-from tensor2tensor.data_generators import snli
-from tensor2tensor.data_generators import squad
-from tensor2tensor.data_generators import translate_encs
-from tensor2tensor.data_generators import translate_ende
-from tensor2tensor.data_generators import translate_enfr
-from tensor2tensor.data_generators import translate_enmk
-from tensor2tensor.data_generators import translate_envi
-from tensor2tensor.data_generators import translate_enzh
-from tensor2tensor.data_generators import twentybn
-from tensor2tensor.data_generators import wiki
-from tensor2tensor.data_generators import wsj_parsing
+import importlib


-# Problem modules that require optional dependencies
-# pylint: disable=g-import-not-at-top
-try:
-  # Requires h5py
-  from tensor2tensor.data_generators import gene_expression
-except ImportError:
-  pass
-# pylint: enable=g-import-not-at-top
-# pylint: enable=unused-import
+modules = [
+    "tensor2tensor.data_generators.algorithmic",
+    "tensor2tensor.data_generators.algorithmic_math",
+    "tensor2tensor.data_generators.audio",
+    "tensor2tensor.data_generators.celeba",
+    "tensor2tensor.data_generators.cifar",
+    "tensor2tensor.data_generators.cipher",
+    "tensor2tensor.data_generators.cnn_dailymail",
+    "tensor2tensor.data_generators.desc2code",
+    "tensor2tensor.data_generators.fsns",
+    "tensor2tensor.data_generators.gene_expression",
+    "tensor2tensor.data_generators.gym",
+    "tensor2tensor.data_generators.ice_parsing",
+    "tensor2tensor.data_generators.imagenet",
+    "tensor2tensor.data_generators.imdb",
+    "tensor2tensor.data_generators.librispeech",
+    "tensor2tensor.data_generators.lm1b",
+    "tensor2tensor.data_generators.mnist",
+    "tensor2tensor.data_generators.mscoco",
+    "tensor2tensor.data_generators.multinli",
+    "tensor2tensor.data_generators.ocr",
+    "tensor2tensor.data_generators.problem_hparams",
+    "tensor2tensor.data_generators.ptb",
+    "tensor2tensor.data_generators.snli",
+    "tensor2tensor.data_generators.squad",
+    "tensor2tensor.data_generators.translate_encs",
+    "tensor2tensor.data_generators.translate_ende",
+    "tensor2tensor.data_generators.translate_enfr",
+    "tensor2tensor.data_generators.translate_enmk",
+    "tensor2tensor.data_generators.translate_envi",
+    "tensor2tensor.data_generators.translate_enzh",
+    "tensor2tensor.data_generators.twentybn",
+    "tensor2tensor.data_generators.wiki",
+    "tensor2tensor.data_generators.wsj_parsing",
+]
+
+
+for module in modules:
+  try:
+    importlib.import_module(module)
+  except ImportError as error:
+    print("Did not import module: %s; Cause: %s" % (module, str(error)))

tensor2tensor/data_generators/cifar.py

+112 -2
@@ -77,12 +77,17 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
     test_files = _CIFAR10_TEST_FILES
     prefix = _CIFAR10_PREFIX
     image_size = _CIFAR10_IMAGE_SIZE
-  elif cifar_version == "cifar100":
+    label_key = "labels"
+  elif cifar_version == "cifar100" or cifar_version == "cifar20":
     url = _CIFAR100_URL
     train_files = _CIFAR100_TRAIN_FILES
     test_files = _CIFAR100_TEST_FILES
     prefix = _CIFAR100_PREFIX
     image_size = _CIFAR100_IMAGE_SIZE
+    if cifar_version == "cifar100":
+      label_key = "fine_labels"
+    else:
+      label_key = "coarse_labels"

   _get_cifar(tmp_dir, url)
   data_files = train_files if training else test_files
@@ -97,7 +102,7 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
     all_images.extend([
         np.squeeze(images[j]).transpose((1, 2, 0)) for j in xrange(num_images)
     ])
-    labels = data["labels" if cifar_version == "cifar10" else "fine_labels"]
+    labels = data[label_key]
     all_labels.extend([labels[j] for j in xrange(num_images)])
   return image_utils.image_generator(
       all_images[start_from:start_from + how_many],
@@ -417,3 +422,108 @@ def hparams(self, defaults, unused_model_hparams):
     p.max_expected_batch_size_per_shard = 4
     p.input_space_id = 1
     p.target_space_id = 1
+
+
+@registry.register_problem
+class ImageCifar20Tune(mnist.ImageMnistTune):
+  """Cifar-20 Tune."""
+
+  @property
+  def num_classes(self):
+    return 20
+
+  @property
+  def num_channels(self):
+    return 3
+
+  @property
+  def class_labels(self):
+    return [
+        "aquatic mammals",
+        "fish",
+        "flowers",
+        "food containers",
+        "fruit and vegetables",
+        "household electrical devices",
+        "household furniture",
+        "insects",
+        "large carnivores",
+        "large man-made outdoor things",
+        "large natural outdoor scenes",
+        "large omnivores and herbivores",
+        "medium-sized mammals",
+        "non-insect invertebrates",
+        "people",
+        "reptiles",
+        "small mammals",
+        "trees",
+        "vehicles 1",
+        "vehicles 2",
+    ]
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    image = example["inputs"]
+    image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      image = image_utils.cifar_image_augmentation(image)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
+    example["inputs"] = image
+    return example
+
+  def generator(self, data_dir, tmp_dir, is_training):
+    if is_training:
+      return cifar_generator("cifar20", tmp_dir, True, 48000)
+    else:
+      return cifar_generator("cifar20", tmp_dir, True, 2000, 48000)
+
+
+@registry.register_problem
+class ImageCifar20(ImageCifar20Tune):
+
+  def generator(self, data_dir, tmp_dir, is_training):
+    if is_training:
+      return cifar_generator("cifar20", tmp_dir, True, 50000)
+    else:
+      return cifar_generator("cifar20", tmp_dir, False, 10000)
+
+
+@registry.register_problem
+class ImageCifar20Plain(ImageCifar20):
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    image = example["inputs"]
+    image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
+    example["inputs"] = image
+    return example
+
+
+@registry.register_problem
+class ImageCifar20PlainGen(ImageCifar20Plain):
+  """CIFAR-20 32x32 for image generation without standardization preprep."""
+
+  def dataset_filename(self):
+    return "image_cifar20_plain"  # Reuse CIFAR-20 plain data.
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
+    example["inputs"] = tf.to_int64(example["inputs"])
+    return example
+
+
+@registry.register_problem
+class ImageCifar20Plain8(ImageCifar20):
+  """CIFAR-20 rescaled to 8x8 for output: Conditional image generation."""
+
+  def dataset_filename(self):
+    return "image_cifar20_plain"  # Reuse CIFAR-20 plain data.
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    image = example["inputs"]
+    image = image_utils.resize_by_area(image, 8)
+    if not self._was_reversed:
+      image = tf.image.per_image_standardization(image)
+    example["inputs"] = image
+    return example