
Commit b950b46

COCO mAP metric (#2901)
* Keep only cocomap-related changes, i.e. ObjectDetectionMap and its dependencies
* Some improvements: removed allow_multiple..., renamed average_operand, renamed _measure_recall... to _compute_recall...
* Update docs
* Fix a bug in docs (the docs had some nasty errors)
* Fix a tiny bug related to allgather
* Fix a few bugs
* Redesign code: the generic detection logic is removed and only the COCO logic remains; tests are updated
* Remove all_gather with different shape
* Add test for all_gather_with_different_shape func
* A few improvements
* Add an output transform and apply a review comment
* Add a test for the output_transform
* Remove 'flavor' because DeciAI, Ultralytics, Detectron and pycocotools all use the 'max-precision' approach
* Revert Metric change and a few bug fixes
* A tiny improvement in local variable names
* Add max_dep and area_range
* Some improvements
* Improvement in code
* Some improvements
* Fix a bug; some improvements; improve docs
* Fix metrics.rst
* Remove @OverRide which is for 3.12
* Fix mypy issues
* Fix two tests
* Fix a typo in tests
* Fix dist tests
* Add common obj. det. metrics
* Change an annotation for the sake of M1 python3.8
* Use an if check on torch.double usages for the MPS backend
* Fix a typo
* Fix a bug related to tensors on same devices
* Fix a bug related to MPS and torch.double
* Fix a bug related to MPS
* Fix a bug related to MPS
* Fix a bug related to MPS
* Resolve MPS's lack of cummax
* Revert MPS fallback
* Apply comments
* Revert unnecessary changes
* Apply review comments
* Skip MPS on test_integration as well

---------

Co-authored-by: vfdev <[email protected]>
1 parent 8255a0f commit b950b46

11 files changed: +2185 / -5 lines
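
The headline addition is the ObjectDetectionAvgPrecisionRecall metric. As a rough orientation, a minimal hedged sketch of attaching it to an evaluation Engine in the usual ignite way follows; the default constructor arguments, the dummy evaluator and the exact shape of its (y_pred, y) output are assumptions for illustration, not details taken from this commit.

from ignite.engine import Engine
from ignite.metrics import ObjectDetectionAvgPrecisionRecall

# Placeholder evaluator; in practice its process_function should yield the
# COCO-style detection (y_pred, y) pairs the metric expects (assumption).
evaluator = Engine(lambda engine, batch: batch)

ap_ar = ObjectDetectionAvgPrecisionRecall()  # assuming the defaults are usable as-is
ap_ar.attach(evaluator, "AP_AR")  # results land in state.metrics["AP_AR"] after a run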

docs/source/metrics.rst (+4)

@@ -332,13 +332,17 @@ Complete list of metrics
 Frequency
 Loss
 MeanAbsoluteError
+MeanAveragePrecision
 MeanPairwiseDistance
 MeanSquaredError
 metric.Metric
 metric_group.MetricGroup
 metrics_lambda.MetricsLambda
 MultiLabelConfusionMatrix
 MutualInformation
+ObjectDetectionAvgPrecisionRecall
+CommonObjectDetectionMetrics
+vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list
 precision.Precision
 PSNR
 recall.Recall

ignite/metrics/__init__.py (+10)

@@ -21,6 +21,7 @@
 from ignite.metrics.loss import Loss
 from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy
 from ignite.metrics.mean_absolute_error import MeanAbsoluteError
+from ignite.metrics.mean_average_precision import MeanAveragePrecision
 from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
 from ignite.metrics.mean_squared_error import MeanSquaredError
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
@@ -39,6 +40,11 @@
 from ignite.metrics.running_average import RunningAverage
 from ignite.metrics.ssim import SSIM
 from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
+from ignite.metrics.vision.object_detection_average_precision_recall import (
+    coco_tensor_list_to_dict_list,
+    CommonObjectDetectionMetrics,
+    ObjectDetectionAvgPrecisionRecall,
+)

 __all__ = [
     "Metric",
@@ -90,4 +96,8 @@
     "PrecisionRecallCurve",
     "RocCurve",
     "ROC_AUC",
+    "MeanAveragePrecision",
+    "ObjectDetectionAvgPrecisionRecall",
+    "CommonObjectDetectionMetrics",
+    "coco_tensor_list_to_dict_list",
 ]
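
Since the new names are re-exported at the package root and listed in __all__ above, they become importable directly from ignite.metrics. A quick sanity check, using the names verbatim from this diff:

# The names added to __all__ above should resolve from the package root.
from ignite.metrics import (
    CommonObjectDetectionMetrics,
    MeanAveragePrecision,
    ObjectDetectionAvgPrecisionRecall,
    coco_tensor_list_to_dict_list,
)

print(MeanAveragePrecision.__name__, ObjectDetectionAvgPrecisionRecall.__name__)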

ignite/metrics/mean_average_precision.py (+394)

Large diffs are not rendered by default.

ignite/metrics/metric.py (+1, -1)

@@ -873,7 +873,7 @@ def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> bool:
     return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x])


-def _to_batched_tensor(x: Union[torch.Tensor, float], device: Optional[torch.device] = None) -> torch.Tensor:
+def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor:
     if isinstance(x, torch.Tensor):
         return x.unsqueeze(dim=0)
     return torch.tensor([x], device=device)
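
For context on the one-line change above: the helper's body is untouched and only the annotation is widened from float to Number, so integer and other numbers.Number scalars match the signature. A standalone sketch reproducing the function from the hunk, for illustration only:

from numbers import Number
from typing import Optional, Union

import torch


# Copy of the helper as it appears in the hunk above.
def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        return x.unsqueeze(dim=0)  # tensor -> add a leading batch dimension
    return torch.tensor([x], device=device)  # scalar Number -> 1-element tensor


print(_to_batched_tensor(3))                    # tensor([3]); an int now matches the annotation
print(_to_batched_tensor(2.5))                  # tensor([2.5000])
print(_to_batched_tensor(torch.ones(4)).shape)  # torch.Size([1, 4])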

ignite/metrics/metric_group.py (+14, -4)

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Sequence
+from typing import Any, Callable, Dict, Sequence, Tuple

 import torch

@@ -15,6 +15,11 @@ class MetricGroup(Metric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. `output_transform` of each metric in the group is also
             called upon its update.
+        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
+            true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-output as
+            ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for
+            ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``. Alternatively, ``output_transform`` can be used to handle
+            this.

     Examples:
         We construct a group of metrics, attach them to the engine at once and retrieve their result.
@@ -34,13 +39,18 @@ class MetricGroup(Metric):

             # And also altogether
             state.metrics["eval_metrics"]
+
+    .. versionchanged:: 0.5.2
+        ``skip_unrolling`` argument is added.
     """

-    _state_dict_all_req_keys = ("metrics",)
+    _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",)

-    def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
+    def __init__(
+        self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False
+    ):
         self.metrics = metrics
-        super(MetricGroup, self).__init__(output_transform=output_transform)
+        super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)

     def reset(self) -> None:
         for m in self.metrics.values():
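
A short hedged usage sketch of MetricGroup with the new skip_unrolling argument; the chosen metrics and the dummy evaluator are illustrative, not taken from the diff.

from ignite.engine import Engine
from ignite.metrics import MeanAbsoluteError, MeanSquaredError
from ignite.metrics.metric_group import MetricGroup

# Dummy evaluator whose process_function simply forwards (y_pred, y) batches.
evaluator = Engine(lambda engine, batch: batch)

group = MetricGroup(
    {"mae": MeanAbsoluteError(), "mse": MeanSquaredError()},
    skip_unrolling=False,  # new argument; the default keeps the previous behavior
)
group.attach(evaluator, "eval_metrics")  # per-metric and grouped results, as in the docstring example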

ignite/metrics/vision/__init__.py (+3)

@@ -0,0 +1,3 @@
+from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall
+
+__all__ = ["ObjectDetectionAvgPrecisionRecall"]
