
Commit 78bab5e

BIGWangYuDong authored and ZwwWayne committed
[Refactor] Fully refactor yolact
1 parent 5a2ef66 commit 78bab5e

18 files changed (+963, -769 lines)

configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py (+1 -2)

@@ -30,11 +30,10 @@
         file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
     dict(
-        # TODO: Update after mmcv.RandomChoiceResize finish refactor
         type='RandomChoiceResize',
         scales=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384),
                 (852, 352)],
-        resize_cfg=dict(type='Resize', keep_ratio=True)),
+        keep_ratio=True),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PackDetInputs')
 ]
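The dropped TODO refers to MMCV 2.x's RandomChoiceResize refactor, after which keyword arguments such as keep_ratio are forwarded to the underlying Resize transform instead of being wrapped in a nested resize_cfg. A minimal sketch of using the refactored transform directly (a hedged illustration assuming mmcv >= 2.0; not part of this commit):

import numpy as np
from mmcv.transforms import RandomChoiceResize

# One target scale is picked at random per call; keep_ratio is passed through
# to the wrapped Resize rather than via resize_cfg=dict(...).
resize = RandomChoiceResize(
    scales=[(852, 512), (852, 480), (852, 448)],
    keep_ratio=True)
out = resize(dict(img=np.zeros((480, 640, 3), dtype=np.uint8)))
print(out['img_shape'])  # resized (h, w); aspect ratio preserved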

configs/solo/solo_r50_fpn_1x_coco.py (+1)

@@ -10,6 +10,7 @@
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         bgr_to_rgb=True,
+        pad_mask=True,
         pad_size_divisor=32),
     backbone=dict(
         type='ResNet',
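The added pad_mask flag makes DetDataPreprocessor pad ground-truth masks to the same padded shape as the batched images, which SOLO's mask targets expect. For reference, the full preprocessor block this one-line change produces (values as in the surrounding config; a sketch, not additional commit content):

data_preprocessor = dict(
    type='DetDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_mask=True,        # pad gt_masks together with the image batch
    pad_size_divisor=32)  # pad images (and masks) to multiples of 32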

configs/solo/solo_r50_fpn_3x_coco.py (+1 -2)

@@ -6,11 +6,10 @@
         file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
     dict(
-        # TODO: Update after mmcv.RandomChoiceResize finish refactor
        type='RandomChoiceResize',
         scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                 (1333, 672), (1333, 640)],
-        resize_cfg=dict(type='Resize', keep_ratio=True)),
+        keep_ratio=True),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PackDetInputs')
 ]

configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py (+1 -2)

@@ -15,11 +15,10 @@
         file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
     dict(
-        # TODO: Update after mmcv.RandomChoiceResize finish refactor
         type='RandomChoiceResize',
         scales=[(768, 512), (768, 480), (768, 448), (768, 416), (768, 384),
                 (768, 352)],
-        resize_cfg=dict(type='Resize', keep_ratio=True)),
+        keep_ratio=True),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PackDetInputs')
 ]

configs/solov2/solov2_r50_fpn_1x_coco.py (+1)

@@ -11,6 +11,7 @@
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         bgr_to_rgb=True,
+        pad_mask=True,
         pad_size_divisor=32),
     backbone=dict(
         type='ResNet',

configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py (+1 -2)

@@ -6,11 +6,10 @@
         file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
     dict(
-        # TODO: Update after mmcv.RandomChoiceResize finish refactor
         type='RandomChoiceResize',
         scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
                 (1333, 672), (1333, 640)],
-        resize_cfg=dict(type='Resize', keep_ratio=True)),
+        keep_ratio=True),
     dict(type='RandomFlip', prob=0.5),
     dict(type='PackDetInputs')
 ]

configs/yolact/yolact_r50_1x8_coco.py (+64 -58)

@@ -1,9 +1,18 @@
-_base_ = '../_base_/default_runtime.py'
-
+_base_ = [
+    '../_base_/datasets/coco_instance.py', '../_base_/default_runtime.py'
+]
+img_norm_cfg = dict(
+    mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True)
 # model settings
-img_size = 550
+input_size = 550
 model = dict(
     type='YOLACT',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        mean=img_norm_cfg['mean'],
+        std=img_norm_cfg['std'],
+        bgr_to_rgb=img_norm_cfg['to_rgb'],
+        pad_mask=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -56,11 +65,8 @@
         num_protos=32,
         num_classes=80,
         max_masks_to_train=100,
-        loss_mask_weight=6.125),
-    segm_head=dict(
-        type='YOLACTSegmHead',
-        num_classes=80,
-        in_channels=256,
+        loss_mask_weight=6.125,
+        with_seg_branch=True,
         loss_segm=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
     # training and testing settings
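Reassembled from the hunk above for readability: the standalone YOLACTSegmHead is folded into YOLACTProtonet as an optional branch, so the merged mask head reads roughly as follows (in_channels carried over from the removed segm_head; treat this as an approximation, not commit content):

mask_head = dict(
    type='YOLACTProtonet',
    in_channels=256,
    num_protos=32,
    num_classes=80,
    max_masks_to_train=100,
    loss_mask_weight=6.125,
    # the former YOLACTSegmHead now lives here as an auxiliary branch
    with_seg_branch=True,
    loss_segm=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))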
@@ -72,6 +78,7 @@
             min_pos_iou=0.,
             ignore_iof_thr=-1,
             gt_max_assign_all=False),
+        sampler=dict(type='PseudoSampler'),  # YOLACT should use PseudoSampler
         # smoothl1_beta=1.,
         allowed_border=-1,
         pos_weight=-1,
@@ -81,16 +88,16 @@
         nms_pre=1000,
         min_bbox_size=0,
         score_thr=0.05,
+        mask_thr=0.5,
         iou_thr=0.5,
         top_k=200,
-        max_per_img=100))
+        max_per_img=100,
+        mask_thr_binary=0.5))
 # dataset settings
-dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True)
 train_pipeline = [
-    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadImageFromFile',
+        file_client_args={{_base_.file_client_args}}),
     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
     dict(type='FilterAnnotations', min_gt_bbox_wh=(4.0, 4.0)),
     dict(
@@ -102,62 +109,61 @@
         type='MinIoURandomCrop',
         min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
         min_crop_size=0.3),
-    dict(type='Resize', img_scale=(img_size, img_size), keep_ratio=False),
-    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
+    dict(type='RandomFlip', prob=0.5),
     dict(
         type='PhotoMetricDistortion',
         brightness_delta=32,
         contrast_range=(0.5, 1.5),
         saturation_range=(0.5, 1.5),
         hue_delta=18),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='DefaultFormatBundle'),
-    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+    dict(type='PackDetInputs')
 ]
 test_pipeline = [
     dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    batch_sampler=None,
+    dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+max_epochs = 55
+# training schedule for 55e
+train_cfg = dict(
+    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# learning rate
+param_scheduler = [
+    dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=500),
     dict(
-        type='MultiScaleFlipAug',
-        img_scale=(img_size, img_size),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=False),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[20, 42, 49, 52],
+        gamma=0.1)
 ]
-data = dict(
-    samples_per_gpu=8,
-    workers_per_gpu=4,
-    train=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_train2017.json',
-        img_prefix=data_root + 'train2017/',
-        pipeline=train_pipeline),
-    val=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline),
-    test=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline))
+
 # optimizer
-optimizer = dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4)
-optimizer_config = dict()
-# learning policy
-lr_config = dict(
-    policy='step',
-    warmup='linear',
-    warmup_iters=500,
-    warmup_ratio=0.1,
-    step=[20, 42, 49, 52])
-runner = dict(type='EpochBasedRunner', max_epochs=55)
-cudnn_benchmark = True
-evaluation = dict(metric=['bbox', 'segm'])
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4))
+
+custom_hooks = [
+    dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW')
+]
+
+env_cfg = dict(cudnn_benchmark=True)
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR,
 # USER SHOULD NOT CHANGE ITS VALUES.
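The new-style fields above (train_cfg, val_cfg, test_cfg, param_scheduler, optim_wrapper, custom_hooks, env_cfg) replace the removed runner/lr_config/optimizer_config blocks and are consumed by MMEngine's Runner. A hedged launch sketch, assuming an MMDetection 3.x checkout with this config in place:

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/yolact/yolact_r50_1x8_coco.py')
cfg.work_dir = './work_dirs/yolact_r50_1x8_coco'  # Runner requires a work_dir
runner = Runner.from_cfg(cfg)  # builds loops, optim wrapper and hooks from cfg
runner.train()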

configs/yolact/yolact_r50_8x8_coco.py (+17 -10)

@@ -1,15 +1,22 @@
 _base_ = 'yolact_r50_1x8_coco.py'
 
-optimizer = dict(type='SGD', lr=8e-3, momentum=0.9, weight_decay=5e-4)
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
-# learning policy
-lr_config = dict(
-    policy='step',
-    warmup='linear',
-    warmup_iters=1000,
-    warmup_ratio=0.1,
-    step=[20, 42, 49, 52])
-
+# optimizer
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(lr=8e-3),
+    clip_grad=dict(max_norm=35, norm_type=2))
+# learning rate
+max_epochs = 55
+param_scheduler = [
+    dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[20, 42, 49, 52],
+        gamma=0.1)
+]
 # NOTE: `auto_scale_lr` is for automatically scaling LR,
 # USER SHOULD NOT CHANGE ITS VALUES.
 # base_batch_size = (8 GPUs) x (8 samples per GPU)
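Since config inheritance merges dictionaries key by key, optimizer=dict(lr=8e-3) only overrides the learning rate; type, momentum and weight_decay still come from the 1x8 base config. A quick check (a sketch assuming an MMDetection 3.x checkout):

from mmengine.config import Config

cfg = Config.fromfile('configs/yolact/yolact_r50_8x8_coco.py')
# lr should read 0.008 here, while type/momentum/weight_decay keep the
# values inherited from yolact_r50_1x8_coco.py shown above.
print(cfg.optim_wrapper.optimizer)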

mmdet/datasets/transforms/loading.py (+19 -3)

@@ -659,9 +659,25 @@ def transform(self, results: dict) -> Union[dict, None]:
         Returns:
             dict: Updated result dict.
         """
-        assert 'gt_bboxes' in results
-        gt_bboxes = results['gt_bboxes']
-        if gt_bboxes.shape[0] == 0:
+        # gt_masks may not match with gt_bboxes, because gt_masks
+        # will not add into instances if ignore is True
+        if 'gt_ignore_flags' in results and 'gt_masks' in results:
+            vaild_idx = np.where(results['gt_ignore_flags'] == 0)[0]
+            keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_ignore_flags')
+            for key in keys:
+                if key in results:
+                    results[key] = results[key][vaild_idx]
+
+        if self.by_box:
+            assert 'gt_bboxes' in results
+            gt_bboxes = results['gt_bboxes']
+            instance_num = gt_bboxes.shape[0]
+        if self.by_mask:
+            assert 'gt_masks' in results
+            gt_masks = results['gt_masks']
+            instance_num = len(gt_masks)
+
+        if instance_num == 0:
             return results
 
         tests = []
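The rewritten branch lets FilterAnnotations count instances from boxes, masks, or both, and first drops instances flagged as ignored so gt_masks stays aligned with gt_bboxes. A hedged pipeline sketch (argument names as exposed by mmdet 3.x's FilterAnnotations; by_box defaults to True):

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='FilterAnnotations',
        min_gt_bbox_wh=(4.0, 4.0),  # drop boxes narrower/shorter than 4 px
        by_box=True,                # count instances from gt_bboxes
        by_mask=False),             # set True to count from gt_masks instead
    dict(type='PackDetInputs')
]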

mmdet/models/dense_heads/__init__.py (+6 -6)

@@ -38,7 +38,7 @@
 from .ssd_head import SSDHead
 from .tood_head import TOODHead
 from .vfnet_head import VFNetHead
-from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
+from .yolact_head import YOLACTHead, YOLACTProtonet
 from .yolo_head import YOLOV3Head
 from .yolof_head import YOLOFHead
 from .yolox_head import YOLOXHead
@@ -49,11 +49,11 @@
     'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
     'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
     'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
-    'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
-    'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
-    'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
-    'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
-    'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
+    'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', 'SABLRetinaHead',
+    'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead', 'CascadeRPNHead',
+    'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead',
+    'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead',
+    'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
     'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
     'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'CenterNetUpdateHead'
 ]

mmdet/models/dense_heads/anchor_head.py (+3)

@@ -383,6 +383,9 @@ def get_targets(self,
         # `avg_factor` is usually equal to the number of positive priors.
         avg_factor = sum(
             [results.avg_factor for results in sampling_results_list])
+        # update `_raw_positive_infos`, which will be used when calling
+        # `get_positive_infos`.
+        self._raw_positive_infos.update(sampling_results=sampling_results_list)
         # split targets to a list w.r.t. multiple levels
         labels_list = images_to_levels(all_labels, num_level_anchors)
         label_weights_list = images_to_levels(all_label_weights,

mmdet/models/dense_heads/base_dense_head.py (+29)

@@ -58,6 +58,9 @@ class BaseDenseHead(BaseModule, metaclass=ABCMeta):
 
     def __init__(self, init_cfg: OptMultiConfig = None) -> None:
         super().__init__(init_cfg=init_cfg)
+        # `_raw_positive_infos` will be used in `get_positive_infos`, which
+        # can get positive information.
+        self._raw_positive_infos = dict()
 
     def init_weights(self) -> None:
         """Initialize the weights."""
@@ -68,6 +71,32 @@ def init_weights(self) -> None:
             if hasattr(m, 'conv_offset'):
                 constant_init(m.conv_offset, 0)
 
+    def get_positive_infos(self) -> InstanceList:
+        """Get positive information from sampling results.
+
+        Returns:
+            list[:obj:`InstanceData`]: Positive information of each image,
+            usually including positive bboxes, positive labels, positive
+            priors, etc.
+        """
+        if len(self._raw_positive_infos) == 0:
+            return None
+
+        sampling_results = self._raw_positive_infos.get(
+            'sampling_results', None)
+        assert sampling_results is not None
+        positive_infos = []
+        for sampling_result in enumerate(sampling_results):
+            pos_info = InstanceData()
+            pos_info.bboxes = sampling_result.pos_gt_bboxes
+            pos_info.labels = sampling_result.pos_gt_labels
+            pos_info.priors = sampling_result.pos_priors
+            pos_info.pos_assigned_gt_inds = \
+                sampling_result.pos_assigned_gt_inds
+            pos_info.pos_inds = sampling_result.pos_inds
+            positive_infos.append(pos_info)
+        return positive_infos
+
     def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
         """Perform forward propagation and loss calculation of the detection
         head on the features of the upstream network.
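Together with the anchor_head.py change above, this introduces a cache-and-query pattern: get_targets() stashes the per-image sampling results in _raw_positive_infos, and a second head (e.g. YOLACT's mask branch) later reads them back through get_positive_infos(). A self-contained toy sketch of the pattern (simplified stand-ins, not mmdet's real SamplingResult or InstanceData):

from types import SimpleNamespace

class ToyDenseHead:
    def __init__(self):
        # filled during target assignment, queried later by another head
        self._raw_positive_infos = dict()

    def get_targets(self, sampling_results):
        self._raw_positive_infos.update(sampling_results=sampling_results)

    def get_positive_infos(self):
        if not self._raw_positive_infos:
            return None
        infos = []
        for res in self._raw_positive_infos['sampling_results']:
            infos.append(SimpleNamespace(
                bboxes=res.pos_gt_bboxes, labels=res.pos_gt_labels))
        return infos

head = ToyDenseHead()
head.get_targets([SimpleNamespace(pos_gt_bboxes=[[0, 0, 10, 10]],
                                  pos_gt_labels=[3])])
print(head.get_positive_infos()[0].labels)  # -> [3]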
