Skip to content

Commit 69be8e7

Browse files
authored
YOLOv5 v4.0 Release (#1837)
* Update C3 module * Update C3 module * Update C3 module * Update C3 module * update * update * update * update * update * update * update * update * update * updates * updates * updates * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * updates * updates * updates * updates * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update datasets * update * update * update * update attempt_downlaod() * merge * merge * update * update * update * update * update * update * update * update * update * update * parameterize eps * comments * gs-multiple * update * max_nms implemented * Create one_cycle() function * update * update * update * update * update * update * update * update study.png * update study.png * Update datasets.py
1 parent 0e341c5 commit 69be8e7

23 files changed

+489
-125
lines changed

README.md

+20-15
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,40 @@
44

55
![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)
66

7-
This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.
7+
This repository represents Ultralytics open-source research into future object detection methods, and incorporates lessons learned and best practices evolved over thousands of hours of training and evolution on anonymized client datasets. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.
88

9-
<img src="https://user-images.githubusercontent.com/26833433/90187293-6773ba00-dd6e-11ea-8f90-cd94afc0427f.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8.
9+
<img src="https://user-images.githubusercontent.com/26833433/103594689-455e0e00-4eae-11eb-9cdf-7d753e2ceeeb.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8.
1010

11+
- **January 5, 2021**: [v4.0 release](https://github.com/ultralytics/yolov5/releases/tag/v4.0): nn.SiLU() activations, [Weights & Biases](https://wandb.ai/) logging, [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/) integration.
1112
- **August 13, 2020**: [v3.0 release](https://github.com/ultralytics/yolov5/releases/tag/v3.0): nn.Hardswish() activations, data autodownload, native AMP.
1213
- **July 23, 2020**: [v2.0 release](https://github.com/ultralytics/yolov5/releases/tag/v2.0): improved model definition, training and mAP.
1314
- **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, improved speed and mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972).
1415
- **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145).
15-
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP).
16-
- **May 27, 2020**: Public release. YOLOv5 models are SOTA among all known YOLO implementations.
1716

1817

1918
## Pretrained Checkpoints
2019

21-
| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | GFLOPS |
22-
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
23-
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) | 37.0 | 37.0 | 56.2 | **2.4ms** | **416** || 7.5M | 17.5
24-
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) | 44.3 | 44.3 | 63.2 | 3.4ms | 294 || 21.8M | 52.3
25-
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) | 47.7 | 47.7 | 66.5 | 4.4ms | 227 || 47.8M | 117.2
26-
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) | **49.2** | **49.2** | **67.7** | 6.9ms | 145 || 89.0M | 221.5
27-
| | | | | | || |
28-
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA|**50.8**| **50.8** | **68.9** | 25.5ms | 39 || 89.0M | 801.0
20+
| Model | size | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>V100</sub> | FPS<sub>V100</sub> || params | GFLOPS |
21+
|---------- |------ |------ |------ |------ | -------- | ------| ------ |------ | :------: |
22+
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) |640 |36.8 |36.8 |55.6 |**2.2ms** |**455** ||7.3M |17.0
23+
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) |640 |44.5 |44.5 |63.1 |2.9ms |345 ||21.4M |51.3
24+
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) |640 |48.1 |48.1 |66.4 |3.8ms |264 ||47.0M |115.4
25+
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) |640 |**50.1** |**50.1** |**68.7** |6.0ms |167 ||87.7M |218.8
26+
| | | | | | | || |
27+
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA |832 |**51.9** |**51.9** |**69.6** |24.9ms |40 ||87.7M |1005.3
28+
29+
<!---
30+
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |640 |49.0 |49.0 |67.4 |4.1ms |244 ||77.2M |117.7
31+
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |1280 |53.0 |53.0 |70.8 |12.3ms |81 ||77.2M |117.7
32+
--->
2933

3034
** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results denote val2017 accuracy.
3135
** All AP numbers are for single-model single-scale without ensemble or TTA. **Reproduce mAP** by `python test.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65`
3236
** Speed<sub>GPU</sub> averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) V100 instance, and includes image preprocessing, FP16 inference, postprocessing and NMS. NMS is 1-2ms/img. **Reproduce speed** by `python test.py --data coco.yaml --img 640 --conf 0.25 --iou 0.45`
3337
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).
3438
** Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) runs at 3 image sizes. **Reproduce TTA** by `python test.py --data coco.yaml --img 832 --iou 0.65 --augment`
3539

40+
3641
## Requirements
3742

3843
Python 3.8 or later with all [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) dependencies installed, including `torch>=1.7`. To install run:
@@ -106,21 +111,21 @@ import torch
106111
from PIL import Image
107112

108113
# Model
109-
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS
114+
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
110115

111116
# Images
112117
img1 = Image.open('zidane.jpg')
113118
img2 = Image.open('bus.jpg')
114119
imgs = [img1, img2] # batched list of images
115120

116121
# Inference
117-
prediction = model(imgs, size=640) # includes NMS
122+
result = model(imgs)
118123
```
119124

120125

121126
## Training
122127

123-
Download [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) and run command below. Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices).
128+
Run commands below to reproduce results on [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) dataset (dataset auto-downloads on first use). Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices).
124129
```bash
125130
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 64
126131
yolov5m 40

data/coco.yaml

+9-9
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ test: ../coco/test-dev2017.txt # 20288 of 40670 images, submit to https://compe
1818
nc: 80
1919

2020
# class names
21-
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
22-
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
23-
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
24-
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
25-
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
26-
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
27-
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
28-
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
29-
'hair drier', 'toothbrush']
21+
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
22+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
23+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
24+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
25+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
26+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
27+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
28+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
29+
'hair drier', 'toothbrush' ]
3030

3131
# Print classes
3232
# with open('data/coco.yaml') as f:

data/coco128.yaml

+9-9
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ val: ../coco128/images/train2017/ # 128 images
1717
nc: 80
1818

1919
# class names
20-
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
21-
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
22-
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
23-
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
24-
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
25-
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
26-
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
27-
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
28-
'hair drier', 'toothbrush']
20+
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
21+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
22+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
23+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
24+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
25+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
26+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
27+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
28+
'hair drier', 'toothbrush' ]

data/voc.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ val: ../VOC/images/val/ # 4952 images
1717
nc: 20
1818

1919
# class names
20-
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
21-
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
20+
names: [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
21+
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]

models/common.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
3030
super(Conv, self).__init__()
3131
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
3232
self.bn = nn.BatchNorm2d(c2)
33-
self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
33+
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
3434

3535
def forward(self, x):
3636
return self.act(self.bn(self.conv(x)))
@@ -105,9 +105,39 @@ class Focus(nn.Module):
105105
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
106106
super(Focus, self).__init__()
107107
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
108+
# self.contract = Contract(gain=2)
108109

109110
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
110111
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
112+
# return self.conv(self.contract(x))
113+
114+
115+
class Contract(nn.Module):
116+
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
117+
def __init__(self, gain=2):
118+
super().__init__()
119+
self.gain = gain
120+
121+
def forward(self, x):
122+
N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
123+
s = self.gain
124+
x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
125+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
126+
return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
127+
128+
129+
class Expand(nn.Module):
130+
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
131+
def __init__(self, gain=2):
132+
super().__init__()
133+
self.gain = gain
134+
135+
def forward(self, x):
136+
N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
137+
s = self.gain
138+
x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
139+
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
140+
return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
111141

112142

113143
class Concat(nn.Module):
@@ -253,20 +283,13 @@ def tolist(self):
253283
return x
254284

255285

256-
class Flatten(nn.Module):
257-
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
258-
@staticmethod
259-
def forward(x):
260-
return x.view(x.size(0), -1)
261-
262-
263286
class Classify(nn.Module):
264287
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
265288
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
266289
super(Classify, self).__init__()
267290
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
268291
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
269-
self.flat = Flatten()
292+
self.flat = nn.Flatten()
270293

271294
def forward(self, x):
272295
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list

models/experimental.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def forward(self, x, augment=False):
105105
for module in self:
106106
y.append(module(x, augment)[0])
107107
# y = torch.stack(y).max(0)[0] # max ensemble
108-
# y = torch.cat(y, 1) # nms ensemble
109-
y = torch.stack(y).mean(0) # mean ensemble
108+
# y = torch.stack(y).mean(0) # mean ensemble
109+
y = torch.cat(y, 1) # nms ensemble
110110
return y, None # inference, train output
111111

112112

models/hub/anchors.yaml

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Default YOLOv5 anchors for COCO data
2+
3+
4+
# P5 -------------------------------------------------------------------------------------------------------------------
5+
# P5-640:
6+
anchors_p5_640:
7+
- [ 10,13, 16,30, 33,23 ] # P3/8
8+
- [ 30,61, 62,45, 59,119 ] # P4/16
9+
- [ 116,90, 156,198, 373,326 ] # P5/32
10+
11+
12+
# P6 -------------------------------------------------------------------------------------------------------------------
13+
# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
14+
anchors_p6_640:
15+
- [ 9,11, 21,19, 17,41 ] # P3/8
16+
- [ 43,32, 39,70, 86,64 ] # P4/16
17+
- [ 65,131, 134,130, 120,265 ] # P5/32
18+
- [ 282,180, 247,354, 512,387 ] # P6/64
19+
20+
# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
21+
anchors_p6_1280:
22+
- [ 19,27, 44,40, 38,94 ] # P3/8
23+
- [ 96,68, 86,152, 180,137 ] # P4/16
24+
- [ 140,301, 303,264, 238,542 ] # P5/32
25+
- [ 436,615, 739,380, 925,792 ] # P6/64
26+
27+
# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
28+
anchors_p6_1920:
29+
- [ 28,41, 67,59, 57,141 ] # P3/8
30+
- [ 144,103, 129,227, 270,205 ] # P4/16
31+
- [ 209,452, 455,396, 358,812 ] # P5/32
32+
- [ 653,922, 1109,570, 1387,1187 ] # P6/64
33+
34+
35+
# P7 -------------------------------------------------------------------------------------------------------------------
36+
# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
37+
anchors_p7_640:
38+
- [ 11,11, 13,30, 29,20 ] # P3/8
39+
- [ 30,46, 61,38, 39,92 ] # P4/16
40+
- [ 78,80, 146,66, 79,163 ] # P5/32
41+
- [ 149,150, 321,143, 157,303 ] # P6/64
42+
- [ 257,402, 359,290, 524,372 ] # P7/128
43+
44+
# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
45+
anchors_p7_1280:
46+
- [ 19,22, 54,36, 32,77 ] # P3/8
47+
- [ 70,83, 138,71, 75,173 ] # P4/16
48+
- [ 165,159, 148,334, 375,151 ] # P5/32
49+
- [ 334,317, 251,626, 499,474 ] # P6/64
50+
- [ 750,326, 534,814, 1079,818 ] # P7/128
51+
52+
# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
53+
anchors_p7_1920:
54+
- [ 29,34, 81,55, 47,115 ] # P3/8
55+
- [ 105,124, 207,107, 113,259 ] # P4/16
56+
- [ 247,238, 222,500, 563,227 ] # P5/32
57+
- [ 501,476, 376,939, 749,711 ] # P6/64
58+
- [ 1126,489, 801,1222, 1618,1227 ] # P7/128

0 commit comments

Comments
 (0)