
Commit ceabef4

Added AnnotatedObjectsCOCO 🌆

1 parent 9d17ea6

14 files changed: +953 -5 lines

Diff for: configs/coco_cond_stage.yaml (+2 -2)

@@ -30,7 +30,7 @@ model:
       codebook_weight: 1.0
 
 data:
-  target: cutlit.DataModuleFromConfig
+  target: main.DataModuleFromConfig
   params:
     batch_size: 12
     train:
@@ -41,7 +41,7 @@ data:
        onehot_segmentation: true
        use_stuffthing: true
    validation:
-      target: taming.data.coco.CocoImagesAndCaptionsTrain
+      target: taming.data.coco.CocoImagesAndCaptionsValidation
      params:
        size: 256
        crop_size: 256
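
Note: the target strings in these configs are resolved at runtime by the instantiate_from_config helper in main.py, which is why the stale cutlit.DataModuleFromConfig reference (a leftover module name) fails to import, and why the validation split must point at CocoImagesAndCaptionsValidation rather than the train dataset class. A minimal sketch of that resolution pattern (not part of this diff; reconstructed from the codebase's usual convention):

import importlib

def get_obj_from_str(string: str):
    # "taming.data.coco.CocoImagesAndCaptionsValidation" -> class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # Every config node with a `target` key becomes an object;
    # its `params` are passed through as keyword arguments.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))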

Diff for: configs/coco_scene_images_transformer.yaml (+77, new file)

@@ -0,0 +1,77 @@
+model:
+  base_learning_rate: 4.5e-06
+  target: taming.models.cond_transformer.Net2NetTransformer
+  params:
+    cond_stage_key: objects_bbox
+    transformer_config:
+      target: taming.modules.transformer.mingpt.GPT
+      params:
+        vocab_size: 8192
+        block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
+        n_layer: 32
+        n_head: 16
+        n_embd: 912
+    first_stage_config:
+      target: taming.models.vqgan.VQModel
+      params:
+        ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
+        embed_dim: 256
+        n_embed: 8192
+        ddconfig:
+          double_z: false
+          z_channels: 256
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 1
+          - 2
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions:
+          - 16
+          dropout: 0.0
+        lossconfig:
+          target: taming.modules.losses.DummyLoss
+    cond_stage_config:
+      target: taming.models.dummy_cond_stage.DummyCondStage
+      params:
+        conditional_key: objects_bbox
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 24
+    train:
+      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+      params:
+        data_path: data/coco
+        split: train
+        keys: [image, objects_bbox, file_name]
+        no_tokens: 8192
+        target_image_size: 256
+        min_object_area: 0.00001
+        min_objects_per_image: 2
+        max_objects_per_image: 30
+        crop_method: random-1d
+        random_flip: true
+        use_group_parameter: true
+        encode_crop: true
+    validation:
+      target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+      params:
+        data_path: data/coco
+        split: validation
+        keys: [image, objects_bbox, file_name]
+        no_tokens: 8192
+        target_image_size: 256
+        min_object_area: 0.00001
+        min_objects_per_image: 2
+        max_objects_per_image: 30
+        crop_method: random-1d
+        random_flip: true
+        use_group_parameter: true
+        encode_crop: true
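
The block_size of 348 budgets 256 tokens for the 16x16 VQGAN latent grid plus 92 conditioning tokens. Given max_objects_per_image: 30, the 92 plausibly decomposes as 30 objects x 3 tokens each (category plus two corner coordinates) plus 2 crop tokens from encode_crop, though that breakdown is an inference from this config, not stated in the diff. A sketch of how such a config is typically consumed (usage assumed from the repo's conventions, not part of this commit; requires OmegaConf and an instantiate_from_config helper like the one sketched above):

from omegaconf import OmegaConf
from main import instantiate_from_config  # assumed helper, see sketch above

# Build the Net2NetTransformer and the data module straight from the YAML.
config = OmegaConf.load("configs/coco_scene_images_transformer.yaml")
model = instantiate_from_config(config.model)
data = instantiate_from_config(config.data)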

Diff for: environment.yaml (+1)

@@ -20,5 +20,6 @@ dependencies:
     - test-tube>=0.7.5
     - streamlit>=0.73.1
     - einops==0.3.0
+    - more-itertools>=8.0.0
     - transformers==4.3.1
     - -e .

Diff for: main.py (+1 -1)

@@ -278,7 +278,7 @@ def log_img(self, pl_module, batch, batch_idx, split="train"):
         pl_module.eval()
 
         with torch.no_grad():
-            images = pl_module.log_images(batch, split=split)
+            images = pl_module.log_images(batch, split=split, pl_module=pl_module)
 
         for k in images:
             N = min(images[k].shape[0], self.max_images)
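
The image logger now threads the LightningModule itself into log_images, presumably so that object-conditioned models can render their bounding-box conditioning with access to the model. A minimal sketch of why this call-site change stays backward compatible (the models' actual signatures are an assumption, not shown in this diff): any log_images implementation that accepts **kwargs simply ignores the new argument.

import torch

class SomeModel:  # hypothetical model, for illustration only
    @torch.no_grad()
    def log_images(self, batch, split="train", **kwargs):
        # Models that need the module (e.g. to draw conditioning info)
        # can read kwargs["pl_module"]; all others silently ignore it.
        pl_module = kwargs.get("pl_module")
        return {"inputs": batch["image"]}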

Diff for: taming/data/annotated_objects_coco.py (+141, new file)

@@ -0,0 +1,141 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+    'train': {
+        'top_level': '',
+        'person_annotations': 'annotations/person_keypoints_train2017.json',
+        'instances_annotations': 'annotations/instances_train2017.json',
+        'stuff_annotations': 'annotations/stuff_train2017.json',
+        'files': 'train2017'
+    },
+    'validation': {
+        'top_level': '',
+        'person_annotations': 'annotations/person_keypoints_val2017.json',
+        'instances_annotations': 'annotations/instances_val2017.json',
+        'stuff_annotations': 'annotations/stuff_val2017.json',
+        'files': 'val2017'
+    }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+    return {
+        str(img['id']): ImageDescription(
+            id=img['id'],
+            license=img.get('license'),
+            file_name=img['file_name'],
+            coco_url=img['coco_url'],
+            original_size=(img['width'], img['height']),
+            date_captured=img.get('date_captured'),
+            flickr_url=img.get('flickr_url')
+        )
+        for img in description_json
+    }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+    return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+            for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+                     category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+    annotations = defaultdict(list)
+    total = sum(len(a) for a in annotations_json)
+    for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+        image_id = str(ann['image_id'])
+        if image_id not in image_descriptions:
+            raise ValueError(f'image_id [{image_id}] has no image description.')
+        category_id = ann['category_id']
+        try:
+            category_no = category_no_for_id(str(category_id))
+        except KeyError:
+            continue
+
+        width, height = image_descriptions[image_id].original_size
+        bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
+
+        annotations[image_id].append(
+            Annotation(
+                id=ann['id'],
+                area=bbox[2]*bbox[3],  # use bbox area
+                is_group_of=ann['iscrowd'],
+                image_id=ann['image_id'],
+                bbox=bbox,
+                category_id=str(category_id),
+                category_no=category_no
+            )
+        )
+    return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+    def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+        """
+        @param data_path: is the path to the following folder structure:
+                          coco/
+                          ├── annotations
+                          │   ├── instances_train2017.json
+                          │   ├── instances_val2017.json
+                          │   ├── stuff_train2017.json
+                          │   └── stuff_val2017.json
+                          ├── train2017
+                          │   ├── 000000000009.jpg
+                          │   ├── 000000000025.jpg
+                          │   └── ...
+                          ├── val2017
+                          │   ├── 000000000139.jpg
+                          │   ├── 000000000285.jpg
+                          │   └── ...
+        @param split: one of 'train' or 'validation'
+        @param target_image_size: desired image size (yields square images)
+        """
+        super().__init__(**kwargs)
+        self.use_things = use_things
+        self.use_stuff = use_stuff
+
+        with open(self.paths['instances_annotations']) as f:
+            inst_data_json = json.load(f)
+        with open(self.paths['stuff_annotations']) as f:
+            stuff_data_json = json.load(f)
+
+        category_jsons = []
+        annotation_jsons = []
+        if self.use_things:
+            category_jsons.append(inst_data_json['categories'])
+            annotation_jsons.append(inst_data_json['annotations'])
+        if self.use_stuff:
+            category_jsons.append(stuff_data_json['categories'])
+            annotation_jsons.append(stuff_data_json['annotations'])
+
+        self.categories = load_categories(chain(*category_jsons))
+        self.filter_categories()
+        self.setup_category_id_and_number()
+
+        self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+        annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+        self.annotations = self.filter_object_number(annotations, self.min_object_area,
+                                                     self.min_objects_per_image, self.max_objects_per_image)
+        self.image_ids = list(self.annotations.keys())
+        self.clean_up_annotations_and_image_descriptions()
+
+    def get_path_structure(self) -> Dict[str, str]:
+        if self.split not in COCO_PATH_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
+        return COCO_PATH_STRUCTURE[self.split]
+
+    def get_image_path(self, image_id: str) -> Path:
+        return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+    def get_image_description(self, image_id: str) -> Dict[str, Any]:
+        # noinspection PyProtectedMember
+        return self.image_descriptions[image_id]._asdict()
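
A hypothetical usage sketch of the new dataset class (the keyword arguments are forwarded to AnnotatedObjectsDataset, whose full signature is not part of this commit; the values below mirror the validation section of the transformer config above):

from taming.data.annotated_objects_coco import AnnotatedObjectsCoco

dataset = AnnotatedObjectsCoco(
    data_path="data/coco",  # folder structure as in the class docstring
    split="validation",
    keys=["image", "objects_bbox", "file_name"],
    no_tokens=8192,
    target_image_size=256,
    min_object_area=0.00001,
    min_objects_per_image=2,
    max_objects_per_image=30,
    crop_method="random-1d",
    random_flip=True,
    use_group_parameter=True,
    encode_crop=True,
)
example = dataset[0]  # dict containing the requested keys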
