Commit 3706079

Implement HistogramLoss with tests

1 parent: 6b4d4d1

File tree: 4 files changed (+252, −0 lines)
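For context: the commit adds the histogram loss of the reference implementation linked in the tests (valerystrizh/pytorch-histogram-loss). As inferred from compute_density and compute_loss below, pairwise cosine similarities $s_{ij} \in [-1, 1]$ are soft-assigned to a uniform grid of nodes $t_r = r\Delta - 1$ by linear interpolation, and the loss estimates the probability that a random negative pair is more similar than a random positive pair. A sketch of the quantities, not stated in the commit itself but read directly off the code:

$$
\delta_{ij,r} =
\begin{cases}
(s_{ij} - t_{r-1})/\Delta & \text{if } s_{ij} \in [t_{r-1}, t_r],\\
(t_{r+1} - s_{ij})/\Delta & \text{if } s_{ij} \in [t_r, t_{r+1}],\\
0 & \text{otherwise,}
\end{cases}
\qquad
p_r = \frac{1}{|S|} \sum_{(i,j) \in S} \delta_{ij,r},
$$

$$
L = \sum_{r} p^{-}_{r}\, \Phi^{+}_{r},
\qquad
\Phi^{+}_{r} = \sum_{q \le r} p^{+}_{q},
$$

where $S$ ranges over the positive pairs for $p^{+}$ and the negative pairs for $p^{-}$.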

Diff for: .gitignore (+1 line)

@@ -6,6 +6,7 @@ dist/
 *.egg-info/
 site/
 venv/
+**/.vscode
 .ipynb_checkpoints
 examples/notebooks/dataset
 examples/notebooks/CIFAR10_Dataset

Diff for: src/pytorch_metric_learning/losses/__init__.py (+1 line)

@@ -8,6 +8,7 @@
 from .cross_batch_memory import CrossBatchMemory
 from .fast_ap_loss import FastAPLoss
 from .generic_pair_loss import GenericPairLoss
+from .histogram_loss import HistogramLoss
 from .instance_loss import InstanceLoss
 from .intra_pair_variance_loss import IntraPairVarianceLoss
 from .large_margin_softmax_loss import LargeMarginSoftmaxLoss

Diff for: src/pytorch_metric_learning/losses/histogram_loss.py (new file, +86 lines)

import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


def filter_pairs(*tensors: torch.Tensor):
    # Treat each pair as unordered and drop duplicates
    t = torch.stack(tensors)
    t, _ = torch.sort(t, dim=0)
    t = torch.unique(t, dim=1)
    return t.tolist()


class HistogramLoss(BaseMetricLossFunction):
    def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
        super().__init__(**kwargs)
        assert (
            n_bins is not None or delta is not None
        ), "delta and n_bins cannot both be None"

        if delta is not None and n_bins is not None:
            assert (
                delta == 2 / n_bins
            ), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}"

        self.delta = delta if delta is not None else 2 / n_bins
        self.add_to_recordable_attributes(name="num_bins", is_stat=True)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, ref_labels, t_per_anchor="all"
        )
        anchor_idx, positive_idx, negative_idx = indices_tuple
        if len(anchor_idx) == 0:
            return self.zero_losses()
        mat = self.distance(embeddings, ref_emb)

        anchor_positive_idx = filter_pairs(anchor_idx, positive_idx)
        anchor_negative_idx = filter_pairs(anchor_idx, negative_idx)
        ap_dists = mat[anchor_positive_idx]
        an_dists = mat[anchor_negative_idx]

        p_pos = self.compute_density(ap_dists)
        phi = torch.cumsum(p_pos, dim=0)

        p_neg = self.compute_density(an_dists)
        return {
            "loss": {
                "losses": torch.sum(p_neg * phi),
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

    def compute_density(self, distances):
        size = distances.size(0)
        r_star = torch.floor(
            (distances.float() + 1) / self.delta
        )  # Indices of the bins containing the values of the distances
        r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long)

        delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta
        delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta
        delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances)
        delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances)

        # One node per bin boundary: n_bins + 1 nodes spanning [-1, 1]
        density = torch.zeros(round(1 + 2 / self.delta))
        density = c_f.to_device(density, tensor=distances, dtype=distances.dtype)

        # For each node, sum the contributions of the bins whose ending node is this one
        density.scatter_add_(0, r_star + 1, delta_ijr_a)
        # For each node, sum the contributions of the bins whose starting node is this one
        density.scatter_add_(0, r_star, delta_ijr_b)
        return density / size

    def get_default_distance(self):
        return CosineSimilarity()
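For reference, a minimal usage sketch of the new loss, mirroring the test below; the batch size, embedding dimension, number of classes, and n_bins=100 are arbitrary illustrative choices, not values from the commit:

import torch
from pytorch_metric_learning.losses import HistogramLoss

embeddings = torch.nn.functional.normalize(torch.randn(32, 128, requires_grad=True))
labels = torch.randint(0, 4, (32,))
loss_fn = HistogramLoss(n_bins=100)  # delta is derived as 2 / 100 = 0.02
loss = loss_fn(embeddings, labels)
loss.backward()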

Diff for: tests/losses/test_histogram_loss.py (new file, +164 lines)

import unittest

import torch
from numpy.testing import assert_almost_equal

from pytorch_metric_learning.losses import HistogramLoss

from .. import TEST_DEVICE, TEST_DTYPES
from ..zzz_testing_utils.testing_utils import angle_to_coord


######################################
###### ORIGINAL IMPLEMENTATION #######
######################################
# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py.
# This code is copied from the official PyTorch implementation
# so that we can make sure our implementation returns the same result.
# Some minor changes were made to avoid errors during testing.
# Every change to the original code is noted and explained.
class OriginalImplementationHistogramLoss(torch.nn.Module):
    def __init__(self, num_steps, cuda=True):
        super(OriginalImplementationHistogramLoss, self).__init__()
        self.step = 2 / (num_steps - 1)
        self.eps = 1 / num_steps
        self.cuda = cuda
        self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1)
        self.tsize = self.t.size()[0]
        if self.cuda:
            self.t = self.t.cuda()

    def forward(self, features, classes):
        def histogram(inds, size):
            s_repeat_ = s_repeat.clone()
            indsa = (
                (s_repeat_floor - (self.t - self.step) > -self.eps)
                & (s_repeat_floor - (self.t - self.step) < self.eps)
                & inds
            )
            assert (
                indsa.nonzero().size()[0] == size
            ), "Another number of bins should be used"
            zeros = torch.zeros((1, indsa.size()[1])).byte()
            if self.cuda:
                zeros = zeros.cuda()
            indsb = torch.cat((indsa, zeros))[1:, :].to(
                dtype=torch.bool
            )  # Added to avoid a bug with uint8 masks
            s_repeat_[~(indsb | indsa)] = 0
            # indsa corresponds to the first condition of the second equation of the paper
            self.t = self.t.to(
                dtype=s_repeat_.dtype
            )  # Added to avoid errors when using half precision
            s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step
            # indsb corresponds to the second condition of the second equation of the paper
            s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step

            return s_repeat_.sum(1) / size

        classes_size = classes.size()[0]
        classes_eq = (
            classes.repeat(classes_size, 1)
            == classes.view(-1, 1).repeat(1, classes_size)
        ).data
        dists = torch.mm(features, features.transpose(0, 1))
        assert (
            (dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item()
        ) == 0, "L2 normalization should be used"
        s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte()
        if self.cuda:
            s_inds = s_inds.cuda()
        classes_eq = classes_eq.to(
            device=s_inds.device
        )  # Added to avoid errors when using only CPU
        pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
        neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
        pos_size = classes_eq[s_inds].sum().item()
        neg_size = (~classes_eq[s_inds]).sum().item()
        s = dists[s_inds].view(1, -1)
        s_repeat = s.repeat(self.tsize, 1)
        s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()

        histogram_pos = histogram(pos_inds, pos_size)
        assert_almost_equal(
            histogram_pos.sum().item(),
            1,
            decimal=1,
            err_msg="Not good positive histogram",
            verbose=True,
        )
        histogram_neg = histogram(neg_inds, neg_size)
        assert_almost_equal(
            histogram_neg.sum().item(),
            1,
            decimal=1,
            err_msg="Not good negative histogram",
            verbose=True,
        )
        histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(
            1, histogram_pos.size()[0]
        )
        histogram_pos_inds = torch.tril(
            torch.ones(histogram_pos_repeat.size()), -1
        ).byte()
        if self.cuda:
            histogram_pos_inds = histogram_pos_inds.cuda()
        histogram_pos_repeat[histogram_pos_inds] = 0
        histogram_pos_cdf = histogram_pos_repeat.sum(0)
        loss = torch.sum(histogram_neg * histogram_pos_cdf)

        return loss


class TestHistogramLoss(unittest.TestCase):
    def test_histogram_loss(self):
        for dtype in TEST_DTYPES:
            embeddings = torch.randn(
                5,
                32,
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)
            embeddings = torch.nn.functional.normalize(embeddings)
            labels = torch.LongTensor([0, 0, 1, 1, 2])

            num_steps = 5 if dtype == torch.float16 else 21
            num_bins = num_steps - 1
            loss_func = HistogramLoss(n_bins=num_bins)

            loss = loss_func(embeddings, labels)

            original_loss_func = OriginalImplementationHistogramLoss(
                num_steps=num_steps
            )
            correct_loss = original_loss_func(embeddings, labels)

            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))

    def test_with_no_valid_triplets(self):
        loss_funcA = HistogramLoss(n_bins=4)
        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.LongTensor([0, 1, 2, 3, 4])
            lossA = loss_funcA(embeddings, labels)
            self.assertEqual(lossA, 0)

    def test_assertion_raises(self):
        with self.assertRaises(AssertionError):
            _ = HistogramLoss()

        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=1, delta=0.5)

        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=10, delta=0.4)
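A note on the constructor contract exercised by test_assertion_raises: at least one of n_bins or delta must be given, and if both are given they must satisfy delta = 2 / n_bins (the similarity range [-1, 1] divided into n_bins bins). A hypothetical illustration, not taken from the tests:

from pytorch_metric_learning.losses import HistogramLoss

HistogramLoss(n_bins=100)              # delta is derived as 2 / 100 = 0.02
HistogramLoss(delta=0.02)              # equivalent
HistogramLoss(n_bins=100, delta=0.02)  # accepted, since 0.02 == 2 / 100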
