|
42 | 42 | from models.common import DetectMultiBackend
|
43 | 43 | from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams, LoadScreenshots
|
44 | 44 | from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
45 |
| - increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh) |
| 45 | + increment_path, non_max_suppression,yolov8_non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh) |
46 | 46 | from utils.plots import Annotator, colors, save_one_box
|
47 | 47 | from utils.torch_utils import select_device, smart_inference_mode, TracedModel
|
48 | 48 |
|
@@ -76,6 +76,7 @@ def run(
|
76 | 76 | half=False, # use FP16 half-precision inference
|
77 | 77 | dnn=False, # use OpenCV DNN for ONNX inference
|
78 | 78 | trace=False,
|
| 79 | + v8_det=False, |
79 | 80 | vid_stride=1, # video frame-rate stride
|
80 | 81 | ):
|
81 | 82 | source = str(source)
|
@@ -131,8 +132,8 @@ def run(
|
131 | 132 |
|
132 | 133 | # NMS
|
133 | 134 | with dt[2]:
|
134 |
| - pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) |
135 |
| - |
| 135 | + pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) if not v8_det else yolov8_non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) |
| 136 | + |
136 | 137 | # Second-stage classifier (optional)
|
137 | 138 | # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
|
138 | 139 |
|
@@ -247,6 +248,7 @@ def parse_opt():
|
247 | 248 | parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
|
248 | 249 | parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
|
249 | 250 | parser.add_argument('--trace', action='store_true', help='trace model')
|
| 251 | + parser.add_argument('--v8_det', action='store_true', help='trace model') |
250 | 252 | opt = parser.parse_args()
|
251 | 253 | opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
|
252 | 254 | print_args(vars(opt))
|
|
0 commit comments