@@ -518,3 +518,51 @@ def _refine_roi(self, x: Tuple[Tensor], rois: Tensor,
518
518
for i in range (len (batch_img_metas ))
519
519
]
520
520
return rois , cls_scores , bbox_preds
521
+
522
+ def forward (self , x : Tuple [Tensor ], rpn_results_list : InstanceList ,
523
+ batch_data_samples : SampleList ) -> tuple :
524
+ """Network forward process. Usually includes backbone, neck and head
525
+ forward without any post-processing.
526
+
527
+ Args:
528
+ x (List[Tensor]): Multi-level features that may have different
529
+ resolutions.
530
+ rpn_results_list (list[:obj:`InstanceData`]): List of region
531
+ proposals.
532
+ batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
533
+ the meta information of each image and corresponding
534
+ annotations.
535
+
536
+ Returns
537
+ tuple: A tuple of features from ``bbox_head`` and ``mask_head``
538
+ forward.
539
+ """
540
+ results = ()
541
+ batch_img_metas = [
542
+ data_samples .metainfo for data_samples in batch_data_samples
543
+ ]
544
+ proposals = [rpn_results .bboxes for rpn_results in rpn_results_list ]
545
+ num_proposals_per_img = tuple (len (p ) for p in proposals )
546
+ rois = bbox2roi (proposals )
547
+ # bbox head
548
+ if self .with_bbox :
549
+ rois , cls_scores , bbox_preds = self ._refine_roi (
550
+ x , rois , batch_img_metas , num_proposals_per_img )
551
+ results = results + (cls_scores , bbox_preds )
552
+ # mask head
553
+ if self .with_mask :
554
+ aug_masks = []
555
+ rois = torch .cat (rois )
556
+ for stage in range (self .num_stages ):
557
+ mask_results = self ._mask_forward (stage , x , rois )
558
+ mask_preds = mask_results ['mask_preds' ]
559
+ mask_preds = mask_preds .split (num_proposals_per_img , 0 )
560
+ aug_masks .append ([m .sigmoid ().detach () for m in mask_preds ])
561
+
562
+ merged_masks = []
563
+ for i in range (len (batch_img_metas )):
564
+ aug_mask = [mask [i ] for mask in aug_masks ]
565
+ merged_mask = merge_aug_masks (aug_mask , batch_img_metas [i ])
566
+ merged_masks .append (merged_mask )
567
+ results = results + (merged_masks , )
568
+ return results
0 commit comments