Commit 90f7aaf

[Feature] Support ViTDet in projects (#9812)

Authored May 12, 2023
1 parent c78202f · commit 90f7aaf

9 files changed: +1005 −2 lines changed

.readthedocs.yml
+7 −2

@@ -1,9 +1,14 @@
 version: 2

-formats: all
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.8"
+
+formats:
+  - epub

 python:
-  version: 3.7
   install:
     - requirements: requirements/docs.txt
     - requirements: requirements/readthedocs.txt

projects/ViTDet/README.md
+110 (new file)
# ViTDet

## Description

This is an implementation of [ViTDet](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet) based on [MMDetection](https://github.com/open-mmlab/mmdetection), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine).
## Usage

### Training commands

Following the original [settings](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet), this project is trained with a total batch size of 64 (16 GPUs with 4 images per GPU).

In MMDetection's root directory, run the following command to train the model:

```bash
GPUS=${GPUS} ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
```

Below is an example of using 16 GPUs to train ViTDet on a Slurm partition named _dev_, setting the work-dir to a shared file system.

```shell
GPUS=16 ./tools/slurm_train.sh dev vitdet_mask_b projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_lsj-100e.py /nfs/xxxx/vitdet_mask-rcnn_vit-b-mae_lsj-100e
```
### Testing commands

In MMDetection's root directory, run the following command to test the model:

```bash
python tools/test.py projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_lsj-100e.py ${CHECKPOINT_PATH}
```
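
The config pulls the project's custom modules in through `custom_imports`, so the standard MMDetection tools work unchanged. As a quick hedged sketch (assuming MMEngine is installed and you run from MMDetection's root), the config can also be loaded and inspected programmatically:

```python
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_lsj-100e.py')
print(cfg.model.backbone.type)     # 'ViT'
print(cfg.custom_imports.imports)  # ['projects.ViTDet.vitdet']
```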
## Results

Based on MMDetection, this project almost aligns with the test and training accuracy of the official [ViTDet](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet).

| Method | Backbone | Pretrained Model | Training set | Test set | Epoch | Val Box AP | Val Mask AP | Download |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [ViTDet](./configs/vitdet_mask-rcnn_vit-b-mae_lsj-100e.py) | ViT-B | MAE | COCO2017 Train | COCO2017 Val | 100 | 51.6 | 45.7 | [model](https://download.openmmlab.com/mmdetection/v3.0/vitdet/vitdet_mask-rcnn_vit-b-mae_lsj-100e/vitdet_mask-rcnn_vit-b-mae_lsj-100e_20230328_153519-e15fe294.pth) / [log](https://download.openmmlab.com/mmdetection/v3.0/vitdet/vitdet_mask-rcnn_vit-b-mae_lsj-100e/vitdet_mask-rcnn_vit-b-mae_lsj-100e_20230328_153519.log.json) |
**Note**:

1. The mask AP is slightly lower than that of the official [repo](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet).
2. Code and weights for other model versions will be released in the future.
## Citation

```latex
@article{li2022exploring,
  title={Exploring plain vision transformer backbones for object detection},
  author={Li, Yanghao and Mao, Hanzi and Girshick, Ross and He, Kaiming},
  journal={arXiv preprint arXiv:2203.16527},
  year={2022}
}
```
## Checklist

<!-- Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR.

OpenMMLab's maintainer will review the code to ensure the project's quality. Reaching the first milestone means that this project suffices the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone.

Note that keeping this section up-to-date is crucial not only for this project's developers but the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate time and effort on further code polishing, if needed.

A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR. -->

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

  - [x] Finish the code

    <!-- The code's design shall follow existing interfaces and convention. For example, each model component should be registered into `mmdet.registry.MODELS` and configurable via a config file. -->

  - [x] Basic docstrings & proper citation

    <!-- Each major object should contain a docstring, describing its functionality and arguments. If you have adapted the code from other open-source projects, don't forget to cite the source project in docstring and make sure your behavior is not against its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->

  - [x] Test-time correctness

    <!-- If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights usually could be obtained by simply renaming the keys in the official pre-trained weights. This test could be skipped though, if you are able to prove the training-time correctness and check the second milestone. -->

  - [x] A full README

    <!-- As this template does. -->

- [x] Milestone 2: Indicates a successful model implementation.

  - [x] Training-time correctness

    <!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

- [ ] Milestone 3: Good to be a part of our core package!

  - [ ] Type hints and docstrings

    <!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/mmdet/datasets/transforms/transforms.py#L41-L169) -->

  - [ ] Unit tests

    <!-- Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/tests/test_datasets/test_transforms/test_transforms.py#L35-L88) -->

  - [ ] Code polishing

    <!-- Refactor your code according to reviewer's comment. -->

  - [ ] Metafile.yml

    <!-- It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/metafile.yml) -->

  - [ ] Move your modules into the core package following the codebase's file hierarchy structure.

    <!-- In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/README.md) -->

  - [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
projects/ViTDet/configs/lsj-100e_coco-instance.py
+135 (new file; the file name is inferred from the `_base_` reference in the config below)

_base_ = [
    '../../../configs/_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
image_size = (1024, 1024)

backend_args = None

train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomResize',
        scale=image_size,
        ratio_range=(0.1, 2.0),
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_type='absolute_range',
        crop_size=image_size,
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
    dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=image_size, keep_ratio=True),
    dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=4,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric=['bbox', 'segm'],
    format_only=False)
test_evaluator = val_evaluator

optim_wrapper = dict(
    type='AmpOptimWrapper',
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg={
        'decay_rate': 0.7,
        'decay_type': 'layer_wise',
        'num_layers': 12,
    },
    optimizer=dict(
        type='AdamW',
        lr=0.0001,
        betas=(0.9, 0.999),
        weight_decay=0.1,
    ))

# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
max_iters = 184375
interval = 5000
dynamic_intervals = [(max_iters // interval * interval + 1, max_iters)]
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=250),
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_iters,
        by_epoch=False,
        # 88 ep = 163889 iters * 64 images/iter / 118000 images/ep
        # 96 ep = 177546 iters * 64 images/iter / 118000 images/ep
        milestones=[163889, 177546],
        gamma=0.1)
]

train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=max_iters,
    val_interval=interval,
    dynamic_intervals=dynamic_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(
    logger=dict(type='LoggerHook', interval=50),
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        save_last=True,
        interval=interval,
        max_keep_ckpts=5))
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend')
]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)

auto_scale_lr = dict(base_batch_size=64)
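
As a hedged sanity check of the schedule comments above (plain Python, not part of the commit; COCO train2017 has roughly 118k images, and each iteration consumes 64 of them):

```python
images_per_iter = 64       # 16 GPUs x 4 images per GPU
images_per_epoch = 118000  # approximate size of COCO train2017

def epochs(iters):
    return iters * images_per_iter / images_per_epoch

print(epochs(184375))  # 100.0  -> max_iters
print(epochs(163889))  # ~88.9  -> first LR drop ('88 ep' in the comment)
print(epochs(177546))  # ~96.3  -> second LR drop ('96 ep' in the comment)

# dynamic_intervals: validate every 5000 iters up to iter 180001; after that
# the val interval becomes max_iters, so the tail only validates at the end.
max_iters, interval = 184375, 5000
print(max_iters // interval * interval + 1)  # 180001
```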
projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_lsj-100e.py
+60 (new file)

_base_ = [
    '../../../configs/_base_/models/mask-rcnn_r50_fpn.py',
    './lsj-100e_coco-instance.py',
]

custom_imports = dict(imports=['projects.ViTDet.vitdet'])

backbone_norm_cfg = dict(type='LN', requires_grad=True)
norm_cfg = dict(type='LN2d', requires_grad=True)
image_size = (1024, 1024)
batch_augments = [
    dict(type='BatchFixedSizePad', size=image_size, pad_mask=True)
]

# model settings
model = dict(
    data_preprocessor=dict(pad_size_divisor=32, batch_augments=batch_augments),
    backbone=dict(
        _delete_=True,
        type='ViT',
        img_size=1024,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        drop_path_rate=0.1,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_cfg=backbone_norm_cfg,
        window_block_indexes=[
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        use_rel_pos=True,
        init_cfg=dict(
            type='Pretrained', checkpoint='mae_pretrain_vit_base.pth')),
    neck=dict(
        _delete_=True,
        type='SimpleFPN',
        backbone_channel=768,
        in_channels=[192, 384, 768, 768],
        out_channels=256,
        num_outs=5,
        norm_cfg=norm_cfg),
    rpn_head=dict(num_convs=2),
    roi_head=dict(
        bbox_head=dict(
            type='Shared4Conv1FCBBoxHead',
            conv_out_channels=256,
            norm_cfg=norm_cfg),
        mask_head=dict(norm_cfg=norm_cfg)))

custom_hooks = [dict(type='Fp16CompresssionHook')]
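
For reference, `window_block_indexes` lists the blocks that use windowed attention; every block left out attends globally. A small hedged sketch (plain Python, values copied from the config above) makes the pattern explicit:

```python
# Blocks not listed in window_block_indexes use global attention.
depth = 12
window_block_indexes = [0, 1, 3, 4, 6, 7, 9, 10]
global_blocks = [i for i in range(depth) if i not in window_block_indexes]
print(global_blocks)  # [2, 5, 8, 11] -> one global block per 3-block group
```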

projects/ViTDet/vitdet/__init__.py
+9 (new file)

from .fp16_compression_hook import Fp16CompresssionHook
from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor
from .simple_fpn import SimpleFPN
from .vit import LN2d, ViT

__all__ = [
    'LayerDecayOptimizerConstructor', 'ViT', 'SimpleFPN', 'LN2d',
    'Fp16CompresssionHook'
]
projects/ViTDet/vitdet/fp16_compression_hook.py
+25 (new file; the file name is inferred from the import in __init__.py above)

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook

from mmdet.registry import HOOKS


@HOOKS.register_module()
class Fp16CompresssionHook(Hook):
    """Support fp16 gradient compression in DDP mode.

    In detectron2, ViTDet uses fp16 gradient compression during training.
    The hook reduces training time and improves box mAP: with it, training
    time drops from 3 days to 2 days and box mAP rises from 51.4 to 51.6.
    """

    def before_train(self, runner):

        if runner.distributed:
            if runner.cfg.get('model_wrapper_cfg') is None:
                from torch.distributed.algorithms.ddp_comm_hooks import \
                    default as comm_hooks
                runner.model.register_comm_hook(
                    state=None, hook=comm_hooks.fp16_compress_hook)
                runner.logger.info('use fp16 compression in DDP mode')
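
Outside MMEngine, the same technique is available on any `DistributedDataParallel` model. A minimal hedged sketch in plain PyTorch (assumes a process group is already initialized, e.g. via `torchrun`):

```python
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(nn.Linear(16, 16).cuda())
# Gradients are cast to fp16 before the all-reduce (halving communication
# volume), then cast back to the original dtype after averaging.
model.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
```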
projects/ViTDet/vitdet/layer_decay_optimizer_constructor.py
+109 (new file; the file name is inferred from the import in __init__.py above)

# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import List

import torch.nn as nn
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor

from mmdet.registry import OPTIM_WRAPPER_CONSTRUCTORS


def get_layer_id_for_vit(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.

    Args:
        var_name (str): The name (key) of the parameter.
        max_layer_id (int): Maximum layer id.
    Returns:
        int: The id number corresponding to different learning rate in
        ``LayerDecayOptimizerConstructor``.
    """
    if var_name.startswith('backbone'):
        if 'patch_embed' in var_name or 'pos_embed' in var_name:
            return 0
        elif '.blocks.' in var_name:
            layer_id = int(var_name.split('.')[2]) + 1
            return layer_id
        else:
            return max_layer_id + 1
    else:
        return max_layer_id + 1


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
    # Different learning rates are set for different layers of backbone.
    # Note: Currently, this optimizer constructor is built for ViT.

    def add_params(self, params: List[dict], module: nn.Module,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """
        logger = MMLogger.get_current_instance()

        parameter_groups = {}
        logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        logger.info('Build LayerDecayOptimizerConstructor '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if name.startswith('backbone.blocks') and 'norm' in name:
                group_name = 'no_decay'
                this_weight_decay = 0.
            elif 'pos_embed' in name:
                group_name = 'no_decay_pos_embed'
                this_weight_decay = 0
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay

            layer_id = get_layer_id_for_vit(
                name, self.paramwise_cfg.get('num_layers'))
            logger.info(f'set param {name} as id {layer_id}')

            group_name = f'layer_{layer_id}_{group_name}'
            this_lr_multi = 1.

            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - 1 - layer_id)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr * this_lr_multi,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)

        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
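
To see what the constructor produces with this project's settings (`decay_rate=0.7`, `num_layers=12`, base lr `1e-4` from the config above), here is a small worked example. Note that `add_params` internally uses `num_layers + 2` ids: id 0 for `patch_embed`/`pos_embed`, ids 1–12 for the transformer blocks, and id 13 for everything outside the backbone:

```python
decay_rate, base_lr = 0.7, 1e-4
num_layers = 12 + 2  # as computed in add_params above

for layer_id, where in [(0, 'patch_embed / pos_embed'), (1, 'blocks.0'),
                        (12, 'blocks.11'), (13, 'neck, heads, ...')]:
    scale = decay_rate**(num_layers - 1 - layer_id)
    print(f'id {layer_id:2d} ({where}): lr = {scale * base_lr:.2e}')
# id 0 gets 0.7**13 ~= 0.0097 of the base lr; id 13 keeps the full 1e-4.
```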

projects/ViTDet/vitdet/simple_fpn.py
+102 (new file)

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class SimpleFPN(BaseModule):
    """Simple Feature Pyramid Network for ViTDet."""

    def __init__(self,
                 backbone_channel: int,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.backbone_channel = backbone_channel
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        # 4x upsampling: two stride-2 transposed convolutions
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2),
            build_norm_layer(norm_cfg, self.backbone_channel // 2)[1],
            nn.GELU(),
            nn.ConvTranspose2d(self.backbone_channel // 2,
                               self.backbone_channel // 4, 2, 2))
        # 2x upsampling
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2))
        # identity: keeps the backbone stride
        self.fpn3 = nn.Sequential(nn.Identity())
        # 2x downsampling
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.num_ins):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, input: Tensor) -> tuple:
        """Forward function.

        Args:
            input (Tensor): Features from the upstream network, 4D-tensor
        Returns:
            tuple: Feature maps, each is a 4D-tensor.
        """
        # build FPN
        inputs = []
        inputs.append(self.fpn1(input))
        inputs.append(self.fpn2(input))
        inputs.append(self.fpn3(input))
        inputs.append(self.fpn4(input))

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build outputs
        # part 1: from original levels
        outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)]

        # part 2: add extra levels
        if self.num_outs > len(outs):
            for i in range(self.num_outs - self.num_ins):
                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        return tuple(outs)
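
As a quick hedged shape check (assumes MMDetection and its dependencies are installed and the snippet runs from the repository root; `BN` is used here instead of the project's `LN2d` purely to keep the example short):

```python
import torch

from projects.ViTDet.vitdet import SimpleFPN

# ViT-B on a 1024x1024 image with patch size 16 emits one 64x64 feature map.
neck = SimpleFPN(
    backbone_channel=768,
    in_channels=[192, 384, 768, 768],
    out_channels=256,
    num_outs=5,
    norm_cfg=dict(type='BN'))

x = torch.randn(1, 768, 64, 64)
for level, out in enumerate(neck(x)):
    print(level, tuple(out.shape))
# 0 (1, 256, 256, 256)  <- 4x upsampled by fpn1
# 1 (1, 256, 128, 128)  <- 2x upsampled by fpn2
# 2 (1, 256, 64, 64)    <- identity (fpn3)
# 3 (1, 256, 32, 32)    <- 2x downsampled by fpn4
# 4 (1, 256, 16, 16)    <- extra max-pooled level
```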

projects/ViTDet/vitdet/vit.py
+448 (new file)

Large diffs are not rendered by default.
