Skip to content

Commit 735f7ff

Browse files
authored
Adding re-ranking for image retrieval (microsoft#515)
1 parent 0b503c4 commit 735f7ff

File tree

8 files changed

+462
-58
lines changed

8 files changed

+462
-58
lines changed

NOTICE.txt

+26
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,29 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of
500500
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
501501
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
502502

503+
--
504+
505+
https://github.com/layumi/Person_reID_baseline_pytorch
506+
507+
508+
MIT License
509+
510+
Copyright (c) 2018 Zhedong Zheng
511+
512+
Permission is hereby granted, free of charge, to any person obtaining a copy
513+
of this software and associated documentation files (the "Software"), to deal
514+
in the Software without restriction, including without limitation the rights
515+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
516+
copies of the Software, and to permit persons to whom the Software is
517+
furnished to do so, subject to the following conditions:
518+
519+
The above copyright notice and this permission notice shall be included in all
520+
copies or substantial portions of the Software.
521+
522+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
523+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
524+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
525+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
526+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
527+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
528+
SOFTWARE.

scenarios/similarity/02_state_of_the_art.ipynb

+103-51
Large diffs are not rendered by default.

scenarios/similarity/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ Below are a subset of popular papers in the field with reported accuracies on st
2727
| [Classification is a Strong Baseline for DeepMetric Learning](https://arxiv.org/abs/1811.12649) <br> (Implemented in this repository) | BMVC 2019 | No | **84%** (512-dim) <br> **89%** (2048-dim) | 61% (512-dim) <br> **65%** (2048-dim) | **78%** (512-dim) <br> **80%** (2048-dim) |
2828

2929

30+
## Re-ranking
31+
32+
In addition to the SOTA method introduced above, we provide an implementation of a popular re-ranking approach published in the CVPR 2017 paper [Re-ranking Person Re-identification with k-reciprocal Encoding](http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf). Re-ranking is a post-processing step to improve retrieval accuracy. The proposed approach is fast, fully automatic, unsupervised, and shown to outperform other state-of-the-art methods with regards to accuracy.
33+
34+
3035
## Frequently asked questions
3136

3237
Answers to Frequently Asked Questions such as "How many images do I need to train a model?" or "How to annotate images?" can be found in the [FAQ.md](FAQ.md) file. For image classification specified questions, see the [FAQ.md](../classification/FAQ.md) in the classification folder.

tests/conftest.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from PIL import Image
1717
from torch import tensor
1818
from pathlib import Path
19-
from fastai.vision import cnn_learner, models
19+
from fastai.vision import cnn_learner, DatasetType, models
2020
from fastai.vision.data import ImageList, imagenet_stats
2121
from typing import List, Tuple
2222
from tempfile import TemporaryDirectory
@@ -35,6 +35,7 @@
3535
_apply_threshold,
3636
)
3737
from utils_cv.similarity.data import Urls as is_urls
38+
from utils_cv.similarity.model import compute_features_learner
3839

3940

4041
def path_classification_notebooks():
@@ -279,7 +280,7 @@ def tiny_ic_databunch(tmp_session):
279280
.split_by_rand_pct(valid_pct=0.1, seed=20)
280281
.label_from_folder()
281282
.transform(size=50)
282-
.databunch(bs=16, num_workers = db_num_workers())
283+
.databunch(bs=16, num_workers=db_num_workers())
283284
.normalize(imagenet_stats)
284285
)
285286

@@ -351,7 +352,7 @@ def testing_databunch(tmp_session):
351352
.split_by_rand_pct(valid_pct=0.2, seed=20)
352353
.label_from_folder()
353354
.transform(size=300)
354-
.databunch(bs=16, num_workers = db_num_workers())
355+
.databunch(bs=16, num_workers=db_num_workers())
355356
.normalize(imagenet_stats)
356357
)
357358

@@ -735,6 +736,7 @@ def workspace_region(request):
735736

736737
# ------|-- Similarity ---------------------------------------------
737738

739+
738740
@pytest.fixture(scope="session")
739741
def tiny_is_data_path(tmp_session) -> str:
740742
""" Returns the path to the tiny fridge objects dataset. """
@@ -743,4 +745,14 @@ def tiny_is_data_path(tmp_session) -> str:
743745
fpath=tmp_session,
744746
dest=tmp_session,
745747
exist_ok=True,
746-
)
748+
)
749+
750+
751+
@pytest.fixture(scope="session")
752+
def tiny_ic_databunch_valid_features(tiny_ic_databunch):
753+
learn = cnn_learner(tiny_ic_databunch, models.resnet18)
754+
embedding_layer = learn.model[1][6]
755+
features = compute_features_learner(
756+
tiny_ic_databunch, DatasetType.Valid, learn, embedding_layer
757+
)
758+
return features

tests/unit/similarity/test_similarity_metrics.py

+29
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from utils_cv.similarity.data import comparative_set_builder
99
from utils_cv.similarity.metrics import (
1010
compute_distances,
11+
evaluate,
1112
positive_image_ranks,
1213
recall_at_k,
1314
vector_distance,
@@ -64,3 +65,31 @@ def test_recall_at_k():
6465
assert recall_at_k(rank_list, 3) == 60
6566
assert recall_at_k(rank_list, 6) == 100
6667
assert recall_at_k(rank_list, 10) == 100
68+
69+
70+
def test_evaluate(tiny_ic_databunch, tiny_ic_databunch_valid_features):
71+
(rank_accs, mAP) = evaluate(
72+
tiny_ic_databunch.valid_ds,
73+
tiny_ic_databunch_valid_features,
74+
use_rerank=False,
75+
)
76+
assert 0 <= mAP <= 1.0
77+
assert len(rank_accs) == 6
78+
assert max(rank_accs) <= 1.001
79+
assert min(rank_accs) >= -0.001
80+
for i in range(len(rank_accs) - 1):
81+
rank_accs[i] <= rank_accs[i + 1]
82+
83+
(rank_accs, ap) = evaluate(
84+
tiny_ic_databunch.valid_ds,
85+
tiny_ic_databunch_valid_features,
86+
use_rerank=True,
87+
rerank_k1=2,
88+
rerank_k2=3,
89+
)
90+
assert 0 <= mAP <= 1.0
91+
assert len(rank_accs) == 6
92+
assert max(rank_accs) <= 1.001
93+
assert min(rank_accs) >= -0.001
94+
for i in range(len(rank_accs) - 1):
95+
rank_accs[i] <= rank_accs[i + 1]

utils_cv/similarity/metrics.py

+53-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
3-
4-
from typing import List
5-
3+
from typing import Dict, List
64
import numpy as np
75
import scipy
86

7+
from fastai.vision import LabelList
8+
from .references.evaluate import evaluate_with_query_set
9+
910

1011
def vector_distance(
1112
vec1: np.ndarray,
@@ -105,3 +106,52 @@ def recall_at_k(ranks: List[int], k: int) -> float:
105106
below_threshold = [x for x in ranks if x <= k]
106107
percent_in_top_k = round(100.0 * len(below_threshold) / len(ranks), 1)
107108
return percent_in_top_k
109+
110+
111+
def evaluate(
112+
data: LabelList,
113+
features: Dict[str, np.array],
114+
use_rerank=False,
115+
rerank_k1=20,
116+
rerank_k2=6,
117+
rerank_lambda=0.3,
118+
):
119+
"""
120+
Computes rank@1 through rank@10 accuracy as well as mAP, optionally with re-ranking
121+
post-processor to improve accuracy (see the re-ranking implementation for more info).
122+
123+
Args:
124+
data: Fastai's image labellist
125+
features: Dictionary of DNN features for each image
126+
use_rerank: use re-ranking
127+
rerank_k1, rerank_k2, rerank_lambda: re-ranking parameters
128+
Returns:
129+
rank_accs: accuracy at rank1 through rank10
130+
mAP: average precision
131+
132+
"""
133+
134+
labels = np.array([data.y[i].obj for i in range(len(data.y))])
135+
features = np.array([features[str(s)] for s in data.items])
136+
137+
# Assign each image into its own group. This serves as id during evaluation to
138+
# ensure a query image is not compared to itself during rank computation.
139+
# For the market-1501 dataset, the group ids can be used to ensure that a query
140+
# can not match to an image taken from the same camera.
141+
groups = np.array(range(len(labels)))
142+
assert len(labels) == len(groups) == features.shape[0]
143+
144+
# Run evaluation
145+
rank_accs, mAP = evaluate_with_query_set(
146+
labels,
147+
groups,
148+
features,
149+
labels,
150+
groups,
151+
features,
152+
use_rerank,
153+
rerank_k1,
154+
rerank_k2,
155+
rerank_lambda,
156+
)
157+
return rank_accs, mAP
+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Most of the code in this file is copied and slightly modified from:
2+
# https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/evaluate.py
3+
4+
import numpy as np
5+
import time
6+
import torch
7+
8+
from .re_ranking import re_ranking
9+
10+
11+
# Note: the Market1501 dataset has a slightly different evaluation procedure which can be used
12+
# by setting is_market1501=True.
13+
def evaluate_with_query_set(
14+
gallery_labels,
15+
gallery_groups,
16+
gallery_features,
17+
query_labels,
18+
query_groups,
19+
query_features,
20+
use_rerank=False,
21+
rerank_k1=20,
22+
rerank_k2=6,
23+
rerank_lambda=0.3,
24+
is_market1501=False,
25+
):
26+
27+
# Init
28+
ap = 0.0
29+
CMC = torch.IntTensor(len(gallery_labels)).zero_()
30+
31+
# Compute pairwise distance
32+
q_g_dist = np.dot(query_features, np.transpose(gallery_features))
33+
34+
# Improve pairwise distances using re-ranking
35+
if use_rerank:
36+
print("Calculate re-ranked distances..")
37+
q_q_dist = np.dot(query_features, np.transpose(query_features))
38+
g_g_dist = np.dot(gallery_features, np.transpose(gallery_features))
39+
since = time.time()
40+
distances = re_ranking(
41+
q_g_dist, q_q_dist, g_g_dist, k1=rerank_k1, k2=rerank_k2, lambda_value=rerank_lambda,
42+
)
43+
time_elapsed = time.time() - since
44+
print(
45+
"Reranking complete in {:.0f}m {:.0f}s".format(
46+
time_elapsed // 60, time_elapsed % 60
47+
)
48+
)
49+
else:
50+
distances = -q_g_dist
51+
52+
# Compute accuracies
53+
norm = 0
54+
skip = 1 # set to >1 to only consider a subset of the query images
55+
for i in range(len(query_labels))[::skip]:
56+
ap_tmp, CMC_tmp = evaluate_helper(
57+
distances[i, :],
58+
query_labels[i],
59+
query_groups[i],
60+
gallery_labels,
61+
gallery_groups,
62+
is_market1501,
63+
)
64+
if CMC_tmp[0] == -1:
65+
continue
66+
norm += 1
67+
ap += ap_tmp
68+
CMC = CMC + CMC_tmp
69+
70+
# Print accuracy. Note that Market1501 normalizes by dividing over number of query images.
71+
if is_market1501:
72+
norm = len(query_labels) / float(skip)
73+
ap = ap / norm
74+
CMC = CMC.float()
75+
CMC = CMC / norm
76+
print(
77+
"Rank@1:{:.1f}, rank@5:{:.1f}, mAP:{:.2f}".format(100 * CMC[0], 100 * CMC[4], ap)
78+
)
79+
80+
return (CMC, ap)
81+
82+
83+
# Explanation:
84+
# - query_index: all images in the reference set with the same label as the query image ("true match")
85+
# - camera_index: all images which share the same group (called "camera" since the code was originally written for the Market-1501 dataset).
86+
# - junk_index2: all reference images with the same group ("camera") as the query are considered "false matches".
87+
# - junk_index1: for the market1501 dataset, images with label -1 should be ignored.
88+
def evaluate_helper(score, ql, qc, gl, gc, is_market1501=False):
89+
assert type(gl) == np.ndarray, "Input gl has to be a numpy ndarray"
90+
assert type(gc) == np.ndarray, "Input gc has to be a numpy ndarray"
91+
92+
# Sort scores
93+
index = np.argsort(score) # from small to large
94+
95+
# Compare reference images to the query image.
96+
query_index = np.argwhere(gl == ql)
97+
camera_index = np.argwhere(gc == qc)
98+
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
99+
junk_index2 = np.intersect1d(query_index, camera_index)
100+
101+
# For market 1501 dataset, ignore images with label -1
102+
if is_market1501:
103+
junk_index1a = np.argwhere(gl == -1)
104+
junk_index1b = np.argwhere(gl == "-1")
105+
junk_index1 = np.append(junk_index1a, junk_index1b)
106+
junk_index = np.append(junk_index2, junk_index1)
107+
else:
108+
junk_index = junk_index2
109+
110+
CMC_tmp = compute_mAP(index, good_index, junk_index)
111+
return CMC_tmp
112+
113+
114+
def compute_mAP(index, good_index, junk_index):
115+
ap = 0
116+
cmc = torch.IntTensor(len(index)).zero_()
117+
if good_index.size == 0: # if empty
118+
cmc[0] = -1
119+
return ap, cmc
120+
121+
# remove junk_index
122+
mask = np.in1d(index, junk_index, invert=True)
123+
index = index[mask]
124+
125+
# find good_index index
126+
ngood = len(good_index)
127+
mask = np.in1d(index, good_index)
128+
rows_good = np.argwhere(mask) # == True)
129+
rows_good = rows_good.flatten()
130+
131+
cmc[rows_good[0] :] = 1
132+
for i in range(ngood):
133+
d_recall = 1.0 / ngood
134+
precision = (i + 1) * 1.0 / (rows_good[i] + 1)
135+
if rows_good[i] != 0:
136+
old_precision = i * 1.0 / rows_good[i]
137+
else:
138+
old_precision = 1.0
139+
ap = ap + d_recall * (old_precision + precision) / 2
140+
141+
return ap, cmc

0 commit comments

Comments
 (0)