From d7f01a11a3bec56e455e0cb869b8385ff7af3528 Mon Sep 17 00:00:00 2001 From: chenxinfeng Date: Fri, 29 Apr 2022 15:12:38 +0800 Subject: [PATCH 1/2] fix bug when mask number == 0 --- mmdet2trt/apis/inference.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/mmdet2trt/apis/inference.py b/mmdet2trt/apis/inference.py index 6ac84aa..93e3434 100644 --- a/mmdet2trt/apis/inference.py +++ b/mmdet2trt/apis/inference.py @@ -182,19 +182,20 @@ def forward(self, img, img_metas, *args, **kwargs): masks = masks.detach().cpu().numpy() num_classes = len(self.CLASSES) class_agnostic = True - segms_results = [] - for i in range(batch_size): - segms_results = FCNMaskHead.get_seg_masks( - Addict( - num_classes=num_classes, - class_agnostic=class_agnostic), - masks, - old_dets, - labels, - rcnn_test_cfg=Addict(mask_thr_binary=0.5), - ori_shape=img_metas[i]['ori_shape'], - scale_factor=scale_factor, - rescale=rescale) + segms_results = [[] for _ in range(num_classes)] + if num_dets>0: + for i in range(batch_size): + segms_results = FCNMaskHead.get_seg_masks( + Addict( + num_classes=num_classes, + class_agnostic=class_agnostic), + masks, + old_dets, + labels, + rcnn_test_cfg=Addict(mask_thr_binary=0.5), + ori_shape=img_metas[i]['ori_shape'], + scale_factor=scale_factor, + rescale=rescale) results.append((dets_results, segms_results)) else: results.append(dets_results) From b531cb1dcac881a2045f43222fc4823e13d6efe6 Mon Sep 17 00:00:00 2001 From: chenxinfeng Date: Tue, 5 Sep 2023 20:49:09 +0800 Subject: [PATCH 2/2] fix in dataset class --- mmdet2trt/apis/inference.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mmdet2trt/apis/inference.py b/mmdet2trt/apis/inference.py index 93e3434..f3db70b 100644 --- a/mmdet2trt/apis/inference.py +++ b/mmdet2trt/apis/inference.py @@ -77,9 +77,11 @@ def get_classes_from_config(model_cfg): data_cfg = model_cfg.data def get_module_from_train_val(train_val_cfg): - while train_val_cfg.type == 'RepeatDataset' or \ - train_val_cfg.type == 'MultiImageMixDataset': - train_val_cfg = train_val_cfg.dataset + while train_val_cfg.type in ('RepeatDataset', + 'MultiImageMixDataset', + 'ConcatDataset'): + train_val_cfg = train_val_cfg.datasets[0] if hasattr( + train_val_cfg, 'datasets') else train_val_cfg.dataset return module_dict[train_val_cfg.type] data_cfg_type_list = ['train', 'val', 'test']