forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbench_pq_tables.py
78 lines (61 loc) · 2.17 KB
/
bench_pq_tables.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import os
import numpy as np
import faiss
def cpu_model_name(cpuinfo_path="/proc/cpuinfo"):
    """Return the first line of *cpuinfo_path* containing 'model name'.

    Pure-Python replacement for ``grep -m1 'model name' < /proc/cpuinfo``.
    It matches the same (first, substring) line but needs no shell
    subprocess. It also degrades quietly on systems without that file.

    Args:
        cpuinfo_path: file to scan; defaults to the Linux ``/proc/cpuinfo``.

    Returns:
        The matching line without its trailing newline, or ``None`` if the
        file cannot be opened or contains no such line.
    """
    try:
        with open(cpuinfo_path, encoding="utf-8", errors="replace") as f:
            for line in f:
                if "model name" in line:
                    return line.rstrip("\n")
    except OSError:
        # Missing or unreadable file (non-Linux host): CPU info is optional.
        return None
    return None


# Print the CPU model so the timings below can be attributed to hardware.
_cpu_model = cpu_model_name()
if _cpu_model is not None:
    print(_cpu_model)
def format_tab(x):
    """Render a 2-D table as text: tab-separated cells, one row per line.

    Every cell is formatted with the ``%g`` conversion. An empty table
    yields the empty string.
    """
    lines = []
    for row in x:
        cells = ["%g" % value for value in row]
        lines.append("\t".join(cells))
    return "\n".join(lines)
def run_bench(d, dsub, nbit=8, metric=None, nrun=100, nxs=(1, 10, 100)):
    """Time PQ look-up-table computation for one (d, dsub, nbit, metric) config.

    Trains a ``faiss.ProductQuantizer`` with ``M = d // dsub`` sub-quantizers
    of ``nbit`` bits on random data. For every batch size ``nx`` in *nxs* it
    then calls the table routine selected by *metric* ``nrun`` times. The
    first ``nrun // 5`` calls are treated as warm-up and discarded. One
    summary line (mean and std in ms for each ``nx``) is printed.

    Args:
        d: vector dimension; must be a positive multiple of *dsub*.
        dsub: number of dimensions handled by each sub-quantizer.
        nbit: bits per sub-quantizer code.
        metric: ``faiss.METRIC_INNER_PRODUCT`` (``compute_inner_prod_tables``)
            or ``faiss.METRIC_L2`` (``compute_distance_tables``).
        nrun: calls per ``nx``, warm-up included; must be >= 1.
        nxs: batch sizes (number of query vectors) to time.

    Returns:
        List with the mean time in milliseconds for each ``nx`` in *nxs*.

    Raises:
        ValueError: for an unsupported *metric*, a *d* that is not a
            positive multiple of *dsub*, or ``nrun < 1``.
    """
    # Validate before the comparatively expensive PQ training. The former
    # ``assert False`` fired only after training, and under ``python -O`` it
    # was stripped, so the loop then timed nothing but the allocation.
    if metric not in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
        raise ValueError(f"unsupported metric: {metric!r}")
    # Otherwise M = d // dsub would silently use a different sub-vector size
    # than the dsub reported in the output.
    if dsub <= 0 or d <= 0 or d % dsub != 0:
        raise ValueError(f"d={d} must be a positive multiple of dsub={dsub}")
    if nrun < 1:
        raise ValueError(f"nrun must be >= 1, got {nrun}")

    M = d // dsub
    pq = faiss.ProductQuantizer(d, M, nbit)
    pq.train(faiss.randn((max(1000, pq.ksub * 50), d), 123))
    # Bind the table routine once so the timed loop contains only the call.
    if metric == faiss.METRIC_INNER_PRODUCT:
        compute_tables = pq.compute_inner_prod_tables
    else:
        compute_tables = pq.compute_distance_tables
    sp = faiss.swig_ptr
    nwarmup = nrun // 5  # the first 20% of the runs are warm-up

    print(f"d={d} dsub={dsub} ksub={pq.ksub}", end="\t")
    res = []
    for nx in nxs:
        x = faiss.randn((nx, d), 555)
        # Allocate the output table and wrap the raw pointers outside the
        # timed region. Previously np.zeros (and possibly first-touch page
        # faults on large tables) was counted as table-computation time.
        new_tab = np.zeros((nx, M, pq.ksub), "float32")
        x_ptr, tab_ptr = sp(x), sp(new_tab)
        times = []
        for run in range(nrun):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            t0 = time.perf_counter()
            compute_tables(nx, x_ptr, tab_ptr)
            t1 = time.perf_counter()
            if run >= nwarmup:
                times.append(t1 - t0)
        times = np.array(times) * 1000  # seconds -> milliseconds
        print(f"nx={nx}: {np.mean(times):.3f} ms (± {np.std(times):.4f})",
              end="\t")
        res.append(times.mean())
    print()
    return res
# Run every configuration single-threaded first, then with 32 OpenMP
# threads. 32 is a good setting for an Intel(R) Xeon(R) CPU E5-2698 v4
# @ 2.20GHz.
for nthread in (1, 32):
    faiss.omp_set_num_threads(nthread)
    for metric in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
        print("============= nthread=", nthread, "metric=", metric)
        # One row per (dsub, nbit, M) configuration. The columns are the
        # per-nx mean times returned by run_bench.
        rows = []
        for dsub in (2, 4, 8):
            for nbit in (4, 8):
                for M in (8, 20):
                    rows.append(run_bench(M * dsub, dsub, nbit, metric))
        print("formated result:")
        print(format_tab(np.array(rows)))