Skip to content

Commit ade8c0c

Browse files
committed
add yolov8 core,waiting update later
1 parent f96a5a7 commit ade8c0c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+18413
-744
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
## <div align="left">🚀 yolov5_reserach PLUS— High-level</div>
1919

2020

21-
更新升级中.....
21+
更新升级中.....add V8 core
2222

2323
### <div align="left">⭐新闻板块【实时更新/计划】</div>
24+
- 2023/3/1 add v8 core:春节期间看了下V8,最近V8又更新了,由于近半年项目比较多也是耽误了好久(原版本是将V8的所有功能全部融合到了V5的代码中,但是训练的时候发生了问题,排查发现问题发生在V5的数据读取处理,所以暂时使用V8的训练结构代码,也便于区分),然后抓紧时间不停更新;
2425
- 2022/11/23 修复已知BUG,V7.0版本更新兼容,年底比较忙后续忙完业务会大更新~
2526
- 2022/10/20 修复适配V7结构和额外任务引起的一些代码问题,实时更新V5的代码优化部分,添加了工具grad_cam在tools目录。
2627
- 2022/9/19 修复已知BUG,更新了实时的V5BUG修复和代码优化融合验证,核心检测、分类、分割的部分CI验证,关键点检测实测训练正常,基本功能整理完毕。

classify/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
download, increment_path, init_seeds, print_args, yaml_save)
4545
from utils.loggers import GenericLogger
4646
from utils.plots import imshow_cls
47-
from utils.torch_utils import (ModelEMA, model_info, reshape_classifier_output, select_device, smart_DDP,
47+
from utils.torch_utils import (ModelEMA, de_parallel,model_info, reshape_classifier_output, select_device, smart_DDP,
4848
smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)
4949

5050
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -259,7 +259,7 @@ def train(opt, device):
259259
# Plot examples
260260
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
261261
pred = torch.max(ema.ema(images.to(device)), 1)[1]
262-
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
262+
file = imshow_cls(images, labels, pred, de_parallel(model).names, verbose=False, f=save_dir / 'test_images.jpg')
263263

264264
# Log results
265265
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}

detect.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from models.common import DetectMultiBackend
4343
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams, LoadScreenshots
4444
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
45-
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
45+
increment_path, non_max_suppression,yolov8_non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
4646
from utils.plots import Annotator, colors, save_one_box
4747
from utils.torch_utils import select_device, smart_inference_mode, TracedModel
4848

@@ -76,6 +76,7 @@ def run(
7676
half=False, # use FP16 half-precision inference
7777
dnn=False, # use OpenCV DNN for ONNX inference
7878
trace=False,
79+
v8_det=False,
7980
vid_stride=1, # video frame-rate stride
8081
):
8182
source = str(source)
@@ -131,8 +132,8 @@ def run(
131132

132133
# NMS
133134
with dt[2]:
134-
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
135-
135+
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) if not v8_det else yolov8_non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
136+
136137
# Second-stage classifier (optional)
137138
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
138139

@@ -247,6 +248,7 @@ def parse_opt():
247248
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
248249
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
249250
parser.add_argument('--trace', action='store_true', help='trace model')
251+
parser.add_argument('--v8_det', action='store_true', help='trace model')
250252
opt = parser.parse_args()
251253
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
252254
print_args(vars(opt))

0 commit comments

Comments
 (0)