forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbench_scalar_quantizer.py
82 lines (64 loc) · 2.65 KB
/
bench_scalar_quantizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import numpy as np
import faiss
from datasets import load_sift1M
print("load data")
xb, xq, xt, gt = load_sift1M()
# xq is a 2-D (nq, d) array; d is reused as the vector dimension everywhere
nq, d = xq.shape
# number of coarse centroids (inverted lists) given to every IVF index below
ncent = 256
# (attribute name, enum value) for every QT_* type on faiss.ScalarQuantizer
variants = [
    (qt_name, getattr(faiss.ScalarQuantizer, qt_name))
    for qt_name in dir(faiss.ScalarQuantizer)
    if qt_name[:3] == 'QT_'
]
# one flat L2 coarse quantizer, passed to (and shared by) every IVF index
quantizer = faiss.IndexFlatL2(d)
if False:
    # Disabled benchmark: an IndexIVFFlat baseline followed by one
    # IndexIVFScalarQuantizer per QT_* type. Before each phase it prints the
    # time elapsed since that index's construction started, then the
    # fraction of queries whose first ground-truth id gt[:, 0] appears in
    # the top 1 / 10 / 100 results.
    for name, qtype in [('flat', 0)] + variants:
        print("============== test", name)
        t0 = time.time()
        if name != 'flat':
            index = faiss.IndexIVFScalarQuantizer(
                quantizer, d, ncent, qtype, faiss.METRIC_L2)
        else:
            index = faiss.IndexIVFFlat(
                quantizer, d, ncent, faiss.METRIC_L2)
        index.nprobe = 16
        print("[%.3f s] train" % (time.time() - t0))
        index.train(xt)
        print("[%.3f s] add" % (time.time() - t0))
        index.add(xb)
        print("[%.3f s] search" % (time.time() - t0))
        D, I = index.search(xq, 100)
        print("[%.3f s] eval" % (time.time() - t0))
        recalls = ["%.4f" % ((I[:, :rank] == gt[:, :1]).sum() / float(nq))
                   for rank in (1, 10, 100)]
        # same output as one "%.4f " print per rank followed by a newline
        print(' '.join(recalls) + ' ')
if True:
    # For every QT_* scalar-quantizer type, sweep the range statistic
    # (ScalarQuantizer.rangestat) over several values of its argument, and
    # report the fraction of queries whose first ground-truth id appears in
    # the top 1 / 10 / 100 results, plus the search wall time.
    for name, qtype in variants:
        print("============== test", name)
        for rsname, vals in [('RS_minmax',
                              [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
                             ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
                             ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
                             ('RS_optim', [0.0])]:
            for val in vals:
                print("%-15s %5g " % (rsname, val), end=' ')
                index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                      qtype, faiss.METRIC_L2)
                index.nprobe = 16
                index.sq.rangestat = getattr(faiss.ScalarQuantizer,
                                             rsname)
                # rangestat_arg is a field of the ScalarQuantizer (index.sq),
                # next to rangestat. Setting it on the index object itself
                # never reaches the quantizer, so every val would train with
                # the same default argument and the sweep would be flat.
                index.sq.rangestat_arg = val
                # NOTE(review): the coarse quantizer is shared, so presumably
                # only the first train() call clusters it and later indexes
                # reuse those centroids; only the SQ ranges are re-trained.
                index.train(xt)
                index.add(xb)
                # only the search phase is timed
                t0 = time.time()
                D, I = index.search(xq, 100)
                t1 = time.time()
                for rank in 1, 10, 100:
                    n_ok = (I[:, :rank] == gt[:, :1]).sum()
                    print("%.4f" % (n_ok / float(nq)), end=' ')
                print(" %.3f s" % (t1 - t0))