forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbench_ivf_selector.cpp
145 lines (127 loc) · 4.69 KB
/
bench_ivf_selector.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <omp.h>
#include <unistd.h>
#include <memory>
#include <faiss/IVFlib.h>
#include <faiss/IndexIVF.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/index_factory.h>
#include <faiss/index_io.h>
#include <faiss/utils/random.h>
#include <faiss/utils/utils.h>
/************************
* This benchmark attempts to measure the runtime overhead to use an IDSelector
* over doing an unconditional sequential scan. Unfortunately the results of the
* benchmark also depend a lot on the parallel_mode and the way
* search_with_parameters works.
*/
int main() {
using idx_t = faiss::idx_t;
int d = 64;
size_t nb = 1024 * 1024;
size_t nq = 512 * 16;
size_t k = 10;
std::vector<float> data((nb + nq) * d);
float* xb = data.data();
float* xq = data.data() + nb * d;
faiss::rand_smooth_vectors(nb + nq, d, data.data(), 1234);
std::unique_ptr<faiss::Index> index;
// const char *index_key = "IVF1024,Flat";
const char* index_key = "IVF1024,SQ8";
printf("index_key=%s\n", index_key);
std::string stored_name =
std::string("/tmp/bench_ivf_selector_") + index_key + ".faissindex";
if (access(stored_name.c_str(), F_OK) != 0) {
printf("creating index\n");
index.reset(faiss::index_factory(d, index_key));
double t0 = faiss::getmillisecs();
index->train(nb, xb);
double t1 = faiss::getmillisecs();
index->add(nb, xb);
double t2 = faiss::getmillisecs();
printf("Write %s\n", stored_name.c_str());
faiss::write_index(index.get(), stored_name.c_str());
} else {
printf("Read %s\n", stored_name.c_str());
index.reset(faiss::read_index(stored_name.c_str()));
}
faiss::IndexIVF* index_ivf = static_cast<faiss::IndexIVF*>(index.get());
index->verbose = true;
for (int tt = 0; tt < 3; tt++) {
if (tt == 1) {
index_ivf->parallel_mode = 3;
} else {
index_ivf->parallel_mode = 0;
}
if (tt == 2) {
printf("set single thread\n");
omp_set_num_threads(1);
}
printf("parallel_mode=%d\n", index_ivf->parallel_mode);
std::vector<float> D1(nq * k);
std::vector<idx_t> I1(nq * k);
{
double t2 = faiss::getmillisecs();
index->search(nq, xq, k, D1.data(), I1.data());
double t3 = faiss::getmillisecs();
printf("search time, no selector: %.3f ms\n", t3 - t2);
}
std::vector<float> D2(nq * k);
std::vector<idx_t> I2(nq * k);
{
double t2 = faiss::getmillisecs();
faiss::IVFSearchParameters params;
faiss::ivflib::search_with_parameters(
index.get(), nq, xq, k, D2.data(), I2.data(), ¶ms);
double t3 = faiss::getmillisecs();
printf("search time with nullptr selector: %.3f ms\n", t3 - t2);
}
FAISS_THROW_IF_NOT(I1 == I2);
FAISS_THROW_IF_NOT(D1 == D2);
{
double t2 = faiss::getmillisecs();
faiss::IVFSearchParameters params;
faiss::IDSelectorAll sel;
params.sel = &sel;
faiss::ivflib::search_with_parameters(
index.get(), nq, xq, k, D2.data(), I2.data(), ¶ms);
double t3 = faiss::getmillisecs();
printf("search time with selector: %.3f ms\n", t3 - t2);
}
FAISS_THROW_IF_NOT(I1 == I2);
FAISS_THROW_IF_NOT(D1 == D2);
std::vector<float> D3(nq * k);
std::vector<idx_t> I3(nq * k);
{
int nt = omp_get_max_threads();
double t2 = faiss::getmillisecs();
faiss::IVFSearchParameters params;
#pragma omp parallel for if (nt > 1)
for (idx_t slice = 0; slice < nt; slice++) {
idx_t i0 = nq * slice / nt;
idx_t i1 = nq * (slice + 1) / nt;
if (i1 > i0) {
faiss::ivflib::search_with_parameters(
index.get(),
i1 - i0,
xq + i0 * d,
k,
D3.data() + i0 * k,
I3.data() + i0 * k,
¶ms);
}
}
double t3 = faiss::getmillisecs();
printf("search time with null selector + manual parallel: %.3f ms\n",
t3 - t2);
}
FAISS_THROW_IF_NOT(I1 == I3);
FAISS_THROW_IF_NOT(D1 == D3);
}
return 0;
}