|
19 | 19 | from __future__ import division
|
20 | 20 | from __future__ import print_function
|
21 | 21 |
|
| 22 | +import os |
22 | 23 | # Dependency imports
|
23 | 24 |
|
| 25 | +from tensor2tensor.data_generators import generator_utils |
24 | 26 | from tensor2tensor.data_generators import image_utils
|
25 | 27 | from tensor2tensor.utils import registry
|
26 | 28 |
|
27 | 29 | import tensorflow as tf
|
28 | 30 |
|
| 31 | +# URLs and filenames for IMAGENET 32x32 data from |
| 32 | +# https://arxiv.org/abs/1601.06759. |
| 33 | +_IMAGENET_SMALL_ROOT_URL = "http://image-net.org/small/" |
| 34 | +_IMAGENET_SMALL_URLS = [ |
| 35 | + "train_32x32.tar", "valid_32x32.tar"] |
| 36 | +_IMAGENET_SMALL_TRAIN_PREFIX = "train_32x32" |
| 37 | +_IMAGENET_SMALL_EVAL_PREFIX = "valid_32x32" |
| 38 | +_IMAGENET_SMALL_IMAGE_SIZE = 32 |
| 39 | + |
| 40 | + |
| 41 | +# URLs and filenames for IMAGENET 64x64 data. |
| 42 | +_IMAGENET_MEDIUM_ROOT_URL = "http://image-net.org/small/" |
| 43 | +_IMAGENET_MEDIUM_URLS = [ |
| 44 | + "train_64x64.tar", "valid_64x64.tar"] |
| 45 | +_IMAGENET_MEDIUM_TRAIN_PREFIX = "train_64x64" |
| 46 | +_IMAGENET_MEDIUM_EVAL_PREFIX = "valid_64x64" |
| 47 | +_IMAGENET_MEDIUM_IMAGE_SIZE = 64 |
| 48 | + |
29 | 49 |
|
30 | 50 | # Derived from ImageNet data
|
31 | 51 | MEAN_RGB = [0.485, 0.456, 0.406]
|
32 | 52 | STDDEV_RGB = [0.229, 0.224, 0.225]
|
33 | 53 |
|
34 | 54 |
|
| 55 | +def imagenet_pixelrnn_generator(tmp_dir, |
| 56 | + training, |
| 57 | + size=_IMAGENET_SMALL_IMAGE_SIZE): |
| 58 | + """Image generator for Imagenet 64x64 downsampled images. |
| 59 | +
|
| 60 | + It assumes that the data has been downloaded from |
| 61 | + http://image-net.org/small/*_32x32.tar or |
| 62 | + http://image-net.org/small/*_64x64.tar into tmp_dir. |
| 63 | + Args: |
| 64 | + tmp_dir: path to temporary storage directory. |
| 65 | + training: a Boolean; if true, we use the train set, otherwise the test set. |
| 66 | + size: image size (assumes height and width are same) |
| 67 | +
|
| 68 | + Yields: |
| 69 | + A dictionary representing the images with the following fields: |
| 70 | + * image/encoded: the string encoding the image as JPEG, |
| 71 | + * image/format: the string "jpeg" representing image format, |
| 72 | + * image/height: an integer representing the height, |
| 73 | + * image/width: an integer representing the width. |
| 74 | + Every field is actually a list of the corresponding type. |
| 75 | + """ |
| 76 | + if size == _IMAGENET_SMALL_IMAGE_SIZE: |
| 77 | + train_prefix = _IMAGENET_SMALL_TRAIN_PREFIX |
| 78 | + eval_prefix = _IMAGENET_SMALL_EVAL_PREFIX |
| 79 | + else: |
| 80 | + train_prefix = _IMAGENET_MEDIUM_TRAIN_PREFIX |
| 81 | + eval_prefix = _IMAGENET_MEDIUM_EVAL_PREFIX |
| 82 | + prefix = train_prefix if training else eval_prefix |
| 83 | + images_filepath = os.path.join(tmp_dir, prefix) |
| 84 | + image_files = tf.gfile.Glob(images_filepath + "/*") |
| 85 | + height = size |
| 86 | + width = size |
| 87 | + const_label = 0 |
| 88 | + for filename in image_files: |
| 89 | + with tf.gfile.Open(filename, "r") as f: |
| 90 | + encoded_image = f.read() |
| 91 | + yield { |
| 92 | + "image/encoded": [encoded_image], |
| 93 | + "image/format": ["png"], |
| 94 | + "image/class/label": [const_label], |
| 95 | + "image/height": [height], |
| 96 | + "image/width": [width] |
| 97 | + } |
| 98 | + |
| 99 | + |
35 | 100 | def imagenet_preprocess_example(example, mode, resize_size=None):
|
36 | 101 | """Preprocessing used for Imagenet and similar problems."""
|
37 | 102 | resize_size = resize_size or [299, 299]
|
@@ -123,6 +188,40 @@ def preprocess_example(self, example, mode, _):
|
123 | 188 | return example
|
124 | 189 |
|
125 | 190 |
|
| 191 | +@registry.register_problem |
| 192 | +class ImageImagenet64Gen(ImageImagenet): |
| 193 | + """Cifar-10 Tune.""" |
| 194 | + |
| 195 | + @property |
| 196 | + def train_shards(self): |
| 197 | + return 1024 |
| 198 | + |
| 199 | + @property |
| 200 | + def dev_shards(self): |
| 201 | + return 10 |
| 202 | + |
| 203 | + def generate_data(self, data_dir, tmp_dir, task_id=-1): |
| 204 | + generator_utils.generate_dataset_and_shuffle( |
| 205 | + self.generator(data_dir, tmp_dir, True), |
| 206 | + self.training_filepaths(data_dir, self.train_shards, shuffled=True), |
| 207 | + self.generator(data_dir, tmp_dir, False), |
| 208 | + self.dev_filepaths(data_dir, self.dev_shards, shuffled=True)) |
| 209 | + |
| 210 | + def generator(self, data_dir, tmp_dir, is_training): |
| 211 | + if is_training: |
| 212 | + return imagenet_pixelrnn_generator( |
| 213 | + tmp_dir, int(True), size=_IMAGENET_MEDIUM_IMAGE_SIZE) |
| 214 | + else: |
| 215 | + return imagenet_pixelrnn_generator( |
| 216 | + tmp_dir, int(False), size=_IMAGENET_MEDIUM_IMAGE_SIZE) |
| 217 | + |
| 218 | + def preprocess_example(self, example, mode, unused_hparams): |
| 219 | + example["inputs"].set_shape([_IMAGENET_MEDIUM_IMAGE_SIZE, |
| 220 | + _IMAGENET_MEDIUM_IMAGE_SIZE, 3]) |
| 221 | + example["inputs"] = tf.to_int64(example["inputs"]) |
| 222 | + return example |
| 223 | + |
| 224 | + |
126 | 225 | @registry.register_problem
|
127 | 226 | class ImageImagenet64(ImageImagenet32):
|
128 | 227 | """Imagenet rescaled to 64x64."""
|
|
0 commit comments