# bench_big_batch_ivf.py
# (forked from facebookresearch/faiss)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import time
import faiss
import numpy as np
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.big_batch_search import big_batch_search
parser = argparse.ArgumentParser()


def aa(*flags, **options):
    """Register one command-line option on the current argument group.

    Every positional and keyword argument is forwarded unchanged to
    ``group.add_argument``. Returns ``None``.
    """
    # ``group`` is a module-level name resolved at call time, so rebinding it
    # between calls decides which argument group receives the option.
    group.add_argument(*flags, **options)
# --- command-line options -------------------------------------------------

group = parser.add_argument_group('dataset options')
# --dim used to be parsed but ignored: the dataset was always built with 32
# dimensions. It is now honored, and the default is 32 so that a run without
# --dim produces the same dataset as before.
aa('--dim', type=int, default=32, help="dimension of the synthetic vectors")
aa('--size', default="S", help="dataset size preset: S, M or L")

group = parser.add_argument_group('index options')
aa('--nlist', type=int, default=100)
aa('--factory_string', default="", help="overrides nlist")
aa('--k', type=int, default=10)
aa('--nprobe', type=int, default=5)
aa('--nt', type=int, default=-1, help="nb search threads")
aa('--method', default="pairwise_distances", help="")

args = parser.parse_args()
print("args:", args)

# --- dataset --------------------------------------------------------------

# Size preset -> the three count arguments passed to SyntheticDataset after
# the dimension (train / database / query vector counts, per the
# faiss.contrib.datasets.SyntheticDataset(d, nt, nb, nq) signature).
dataset_sizes = {
    "S": (2000, 4000, 1000),
    "M": (20000, 40000, 10000),
    "L": (200000, 400000, 100000),
}

if args.size not in dataset_sizes:
    raise RuntimeError(f"dataset size {args.size} not supported")

ds = SyntheticDataset(args.dim, *dataset_sizes[args.size])

# Short aliases used by the rest of the script.
nlist = args.nlist
nprobe = args.nprobe
k = args.k
# State shared by tic() and toc(): a (label, start timestamp) pair for the
# timer that is currently running, or None if tic() has not been called yet.
tictoc = None


def tic(name):
    """Start a timer labelled *name*.

    The label is printed followed by a carriage return (no newline), so the
    line printed by the matching toc() overwrites it on a terminal.
    """
    global tictoc
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    tictoc = (name, time.perf_counter())
    print(name, end="\r", flush=True)


def toc():
    """Stop the timer started by the last tic() and report it.

    Prints "<name>: <seconds> s" and returns the elapsed time in seconds
    as a float.

    Raises:
        RuntimeError: if tic() has not been called beforehand.
    """
    if tictoc is None:
        raise RuntimeError("toc() called before tic()")
    name, t0 = tictoc
    dt = time.perf_counter() - t0
    print(f"{name}: {dt:.3f} s")
    return dt
# Summary of the run; nlist printed here is still the --nlist value.
print(f"dataset {ds}, {nlist=:} {nprobe=:} {k=:}")

# An explicit --factory_string takes precedence over the default
# "IVF<nlist>,Flat" description built from --nlist.
custom_factory = args.factory_string != ""
factory_string = args.factory_string if custom_factory else f"IVF{nlist},Flat"

print(f"instantiate {factory_string}")
index = faiss.index_factory(ds.d, factory_string)

if custom_factory:
    # With a user-supplied factory string, nlist is read back from the index
    # rather than taken from --nlist.
    nlist = index.nlist
    print("nlist", nlist)
# Train the index, then add the database vectors, timing each stage.
# The data getter is called inside the timed region, as in a hand-written
# tic / call / toc sequence.
for stage, run_stage, fetch_vectors in (
    ("train", index.train, ds.get_train),
    ("add", index.add, ds.get_database),
):
    tic(stage)
    run_stage(fetch_vectors())
    toc()
# Only override the faiss thread count when --nt was given explicitly
# (-1 is the "leave unchanged" default).
if args.nt != -1:
    print("setting nb of threads to", args.nt)
    faiss.omp_set_num_threads(args.nt)

# Baseline: the index's own search(), timed for comparison with the
# big-batch search below. (A stray no-op `index.nprobe` expression statement
# that preceded the assignment has been removed.)
tic("reference search")
index.nprobe = nprobe
Dref, Iref = index.search(ds.get_queries(), k)
t_ref = toc()
# Run the same queries through big_batch_search and time it.
tic("block search")
Dnew, Inew = big_batch_search(
    index,
    ds.get_queries(),
    k,
    method=args.method,
    verbose=10,
)
t_tot = toc()

# The two result sets must agree: fewer than 1e-4 of the returned ids may
# differ, and distances must match to 4 decimals.
mismatch_fraction = (Inew != Iref).sum() / Iref.size
assert mismatch_fraction < 1e-4
np.testing.assert_almost_equal(Dnew, Dref, decimal=4)

speedup = t_ref / t_tot
print(f"total block search time {t_tot:.3f} s, speedup {speedup:.3f}x")