Skip to content

Commit e32ed8d

Browse files
authored
[Feature]: support Feature Pyramid Grids (open-mmlab#4645)
* add cfg * add cfg * fix typo * fix pad * fix test pad * add readme * modify by comments * fix typo
1 parent db85ba2 commit e32ed8d

11 files changed

+744
-1
lines changed

configs/fpg/README.md

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Feature Pyramid Grids
2+
3+
## Introduction
4+
5+
```latex
6+
@article{chen2020feature,
7+
title={Feature pyramid grids},
8+
author={Chen, Kai and Cao, Yuhang and Loy, Chen Change and Lin, Dahua and Feichtenhofer, Christoph},
9+
journal={arXiv preprint arXiv:2004.03580},
10+
year={2020}
11+
}
12+
```
13+
14+
## Results and Models
15+
16+
We benchmark the new training schedule (crop training, large batch, unfrozen BN, 50 epochs) introduced in NAS-FPN.
17+
All backbones are Resnet-50 in pytorch style.
18+
19+
| Method | Neck | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Config | Download |
20+
|:------------:|:-----------:|:-------:|:--------:|:--------------:|:------:|:-------:|:-------:|:--------:|
21+
| Faster R-CNN | FPG | 50e | 20.0 | - | 42.2 | - |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/faster_rcnn_r50_fpg_crop640_50e_coco.py) |
22+
| Faster R-CNN | FPG-chn128 | 50e | 11.9 | - | 41.2 | - |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/faster_rcnn_r50_fpg-chn128_crop640_50e_coco.py) |
23+
| Mask R-CNN | FPG | 50e | 23.2 | - | 42.7 | 37.8 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/mask_rcnn_r50_fpg_crop640_50e_coco.py) |
24+
| Mask R-CNN | FPG-chn128 | 50e | 15.3 | - | 41.7 | 36.9 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/mask_rcnn_r50_fpg-chn128_crop640_50e_coco.py) |
25+
| RetinaNet | FPG | 50e | 20.8 | - | 40.5 | - |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/retinanet_r50_fpg_crop640_50e_coco.py) |
26+
| RetinaNet | FPG-chn128 | 50e | | - | | - |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fpg/retinanet_r50_fpg-chn128_crop640_50e_coco.py) |
27+
28+
**Note**: Chn128 means to decrease the number of channels of features and convs from 256 (default) to 128 in
29+
Neck and BBox Head, which can greatly decrease memory consumption without sacrificing much precision.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_base_ = 'faster_rcnn_r50_fpg_crop640_50e_coco.py'
2+
3+
norm_cfg = dict(type='BN', requires_grad=True)
4+
model = dict(
5+
neck=dict(out_channels=128, inter_channels=128),
6+
rpn_head=dict(in_channels=128),
7+
roi_head=dict(
8+
bbox_roi_extractor=dict(out_channels=128),
9+
bbox_head=dict(in_channels=128)))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
_base_ = 'faster_rcnn_r50_fpn_crop640_50e_coco.py'
2+
3+
norm_cfg = dict(type='BN', requires_grad=True)
4+
model = dict(
5+
neck=dict(
6+
type='FPG',
7+
in_channels=[256, 512, 1024, 2048],
8+
out_channels=256,
9+
inter_channels=256,
10+
num_outs=5,
11+
stack_times=9,
12+
paths=['bu'] * 9,
13+
same_down_trans=None,
14+
same_up_trans=dict(
15+
type='conv',
16+
kernel_size=3,
17+
stride=2,
18+
padding=1,
19+
norm_cfg=norm_cfg,
20+
inplace=False,
21+
order=('act', 'conv', 'norm')),
22+
across_lateral_trans=dict(
23+
type='conv',
24+
kernel_size=1,
25+
norm_cfg=norm_cfg,
26+
inplace=False,
27+
order=('act', 'conv', 'norm')),
28+
across_down_trans=dict(
29+
type='interpolation_conv',
30+
mode='nearest',
31+
kernel_size=3,
32+
norm_cfg=norm_cfg,
33+
order=('act', 'conv', 'norm'),
34+
inplace=False),
35+
across_up_trans=None,
36+
across_skip_trans=dict(
37+
type='conv',
38+
kernel_size=1,
39+
norm_cfg=norm_cfg,
40+
inplace=False,
41+
order=('act', 'conv', 'norm')),
42+
output_trans=dict(
43+
type='last_conv',
44+
kernel_size=3,
45+
order=('act', 'conv', 'norm'),
46+
inplace=False),
47+
norm_cfg=norm_cfg,
48+
skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()]))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
_base_ = [
2+
'../_base_/models/faster_rcnn_r50_fpn.py',
3+
'../_base_/datasets/coco_detection.py',
4+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5+
]
6+
norm_cfg = dict(type='BN', requires_grad=True)
7+
model = dict(
8+
backbone=dict(norm_cfg=norm_cfg, norm_eval=False),
9+
neck=dict(norm_cfg=norm_cfg),
10+
roi_head=dict(bbox_head=dict(norm_cfg=norm_cfg)))
11+
dataset_type = 'CocoDataset'
12+
data_root = 'data/coco/'
13+
img_norm_cfg = dict(
14+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
15+
train_pipeline = [
16+
dict(type='LoadImageFromFile'),
17+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
18+
dict(
19+
type='Resize',
20+
img_scale=(640, 640),
21+
ratio_range=(0.8, 1.2),
22+
keep_ratio=True),
23+
dict(type='RandomCrop', crop_size=(640, 640)),
24+
dict(type='RandomFlip', flip_ratio=0.5),
25+
dict(type='Normalize', **img_norm_cfg),
26+
dict(type='Pad', size=(640, 640)),
27+
dict(type='DefaultFormatBundle'),
28+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
29+
]
30+
test_pipeline = [
31+
dict(type='LoadImageFromFile'),
32+
dict(
33+
type='MultiScaleFlipAug',
34+
img_scale=(640, 640),
35+
flip=False,
36+
transforms=[
37+
dict(type='Resize', keep_ratio=True),
38+
dict(type='RandomFlip'),
39+
dict(type='Normalize', **img_norm_cfg),
40+
dict(type='Pad', size_divisor=64),
41+
dict(type='ImageToTensor', keys=['img']),
42+
dict(type='Collect', keys=['img']),
43+
])
44+
]
45+
data = dict(
46+
samples_per_gpu=8,
47+
workers_per_gpu=4,
48+
train=dict(pipeline=train_pipeline),
49+
val=dict(pipeline=test_pipeline),
50+
test=dict(pipeline=test_pipeline))
51+
# learning policy
52+
optimizer = dict(
53+
type='SGD',
54+
lr=0.08,
55+
momentum=0.9,
56+
weight_decay=0.0001,
57+
paramwise_cfg=dict(norm_decay_mult=0, bypass_duplicate=True))
58+
optimizer_config = dict(grad_clip=None)
59+
# learning policy
60+
lr_config = dict(
61+
policy='step',
62+
warmup='linear',
63+
warmup_iters=1000,
64+
warmup_ratio=0.1,
65+
step=[30, 40])
66+
# runtime settings
67+
runner = dict(max_epochs=50)
68+
evaluation = dict(interval=2)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
_base_ = 'mask_rcnn_r50_fpg_crop640_50e_coco.py'
2+
3+
model = dict(
4+
neck=dict(out_channels=128, inter_channels=128),
5+
rpn_head=dict(in_channels=128),
6+
roi_head=dict(
7+
bbox_roi_extractor=dict(out_channels=128),
8+
bbox_head=dict(in_channels=128),
9+
mask_roi_extractor=dict(out_channels=128),
10+
mask_head=dict(in_channels=128)))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
_base_ = 'mask_rcnn_r50_fpn_crop640_50e_coco.py'
2+
3+
norm_cfg = dict(type='BN', requires_grad=True)
4+
model = dict(
5+
neck=dict(
6+
type='FPG',
7+
in_channels=[256, 512, 1024, 2048],
8+
out_channels=256,
9+
inter_channels=256,
10+
num_outs=5,
11+
stack_times=9,
12+
paths=['bu'] * 9,
13+
same_down_trans=None,
14+
same_up_trans=dict(
15+
type='conv',
16+
kernel_size=3,
17+
stride=2,
18+
padding=1,
19+
norm_cfg=norm_cfg,
20+
inplace=False,
21+
order=('act', 'conv', 'norm')),
22+
across_lateral_trans=dict(
23+
type='conv',
24+
kernel_size=1,
25+
norm_cfg=norm_cfg,
26+
inplace=False,
27+
order=('act', 'conv', 'norm')),
28+
across_down_trans=dict(
29+
type='interpolation_conv',
30+
mode='nearest',
31+
kernel_size=3,
32+
norm_cfg=norm_cfg,
33+
order=('act', 'conv', 'norm'),
34+
inplace=False),
35+
across_up_trans=None,
36+
across_skip_trans=dict(
37+
type='conv',
38+
kernel_size=1,
39+
norm_cfg=norm_cfg,
40+
inplace=False,
41+
order=('act', 'conv', 'norm')),
42+
output_trans=dict(
43+
type='last_conv',
44+
kernel_size=3,
45+
order=('act', 'conv', 'norm'),
46+
inplace=False),
47+
norm_cfg=norm_cfg,
48+
skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()]))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
_base_ = [
2+
'../_base_/models/mask_rcnn_r50_fpn.py',
3+
'../_base_/datasets/coco_instance.py',
4+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5+
]
6+
norm_cfg = dict(type='BN', requires_grad=True)
7+
model = dict(
8+
backbone=dict(norm_cfg=norm_cfg, norm_eval=False),
9+
neck=dict(
10+
type='FPN',
11+
in_channels=[256, 512, 1024, 2048],
12+
out_channels=256,
13+
norm_cfg=norm_cfg,
14+
num_outs=5),
15+
roi_head=dict(
16+
bbox_head=dict(norm_cfg=norm_cfg), mask_head=dict(norm_cfg=norm_cfg)))
17+
dataset_type = 'CocoDataset'
18+
data_root = 'data/coco/'
19+
img_norm_cfg = dict(
20+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
21+
train_pipeline = [
22+
dict(type='LoadImageFromFile'),
23+
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
24+
dict(
25+
type='Resize',
26+
img_scale=(640, 640),
27+
ratio_range=(0.8, 1.2),
28+
keep_ratio=True),
29+
dict(type='RandomCrop', crop_size=(640, 640)),
30+
dict(type='RandomFlip', flip_ratio=0.5),
31+
dict(type='Normalize', **img_norm_cfg),
32+
dict(type='Pad', size=(640, 640)),
33+
dict(type='DefaultFormatBundle'),
34+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
35+
]
36+
test_pipeline = [
37+
dict(type='LoadImageFromFile'),
38+
dict(
39+
type='MultiScaleFlipAug',
40+
img_scale=(640, 640),
41+
flip=False,
42+
transforms=[
43+
dict(type='Resize', keep_ratio=True),
44+
dict(type='RandomFlip'),
45+
dict(type='Normalize', **img_norm_cfg),
46+
dict(type='Pad', size_divisor=64),
47+
dict(type='ImageToTensor', keys=['img']),
48+
dict(type='Collect', keys=['img']),
49+
])
50+
]
51+
data = dict(
52+
samples_per_gpu=8,
53+
workers_per_gpu=4,
54+
train=dict(pipeline=train_pipeline),
55+
val=dict(pipeline=test_pipeline),
56+
test=dict(pipeline=test_pipeline))
57+
# learning policy
58+
optimizer = dict(
59+
type='SGD',
60+
lr=0.08,
61+
momentum=0.9,
62+
weight_decay=0.0001,
63+
paramwise_cfg=dict(norm_decay_mult=0, bypass_duplicate=True))
64+
optimizer_config = dict(grad_clip=None)
65+
# learning policy
66+
lr_config = dict(
67+
policy='step',
68+
warmup='linear',
69+
warmup_iters=1000,
70+
warmup_ratio=0.1,
71+
step=[30, 40])
72+
# runtime settings
73+
runner = dict(max_epochs=50)
74+
evaluation = dict(interval=2)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = 'retinanet_r50_fpg_crop640_50e_coco.py'
2+
3+
model = dict(
4+
neck=dict(out_channels=128, inter_channels=128),
5+
bbox_head=dict(in_channels=128))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
_base_ = '../nas_fpn/retinanet_r50_nasfpn_crop640_50e_coco.py'
2+
3+
norm_cfg = dict(type='BN', requires_grad=True)
4+
model = dict(
5+
neck=dict(
6+
_delete_=True,
7+
type='FPG',
8+
in_channels=[256, 512, 1024, 2048],
9+
out_channels=256,
10+
inter_channels=256,
11+
num_outs=5,
12+
add_extra_convs=True,
13+
start_level=1,
14+
stack_times=9,
15+
paths=['bu'] * 9,
16+
same_down_trans=None,
17+
same_up_trans=dict(
18+
type='conv',
19+
kernel_size=3,
20+
stride=2,
21+
padding=1,
22+
norm_cfg=norm_cfg,
23+
inplace=False,
24+
order=('act', 'conv', 'norm')),
25+
across_lateral_trans=dict(
26+
type='conv',
27+
kernel_size=1,
28+
norm_cfg=norm_cfg,
29+
inplace=False,
30+
order=('act', 'conv', 'norm')),
31+
across_down_trans=dict(
32+
type='interpolation_conv',
33+
mode='nearest',
34+
kernel_size=3,
35+
norm_cfg=norm_cfg,
36+
order=('act', 'conv', 'norm'),
37+
inplace=False),
38+
across_up_trans=None,
39+
across_skip_trans=dict(
40+
type='conv',
41+
kernel_size=1,
42+
norm_cfg=norm_cfg,
43+
inplace=False,
44+
order=('act', 'conv', 'norm')),
45+
output_trans=dict(
46+
type='last_conv',
47+
kernel_size=3,
48+
order=('act', 'conv', 'norm'),
49+
inplace=False),
50+
norm_cfg=norm_cfg,
51+
skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()]))
52+
53+
evaluation = dict(interval=2)

mmdet/models/necks/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .bfp import BFP
22
from .channel_mapper import ChannelMapper
3+
from .fpg import FPG
34
from .fpn import FPN
45
from .fpn_carafe import FPN_CARAFE
56
from .hrfpn import HRFPN
@@ -11,5 +12,5 @@
1112

1213
__all__ = [
1314
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
14-
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck'
15+
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG'
1516
]

0 commit comments

Comments
 (0)