@@ -10,10 +10,15 @@ namespace hnswlib {
10
10
for (unsigned i = 0 ; i < qty; i++) {
11
11
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
12
12
}
13
- return ( 1 . 0f - res) ;
13
+ return res;
14
14
15
15
}
16
16
17
+ static float
18
+ InnerProductDistance (const void *pVect1, const void *pVect2, const void *qty_ptr) {
19
+ return 1 .0f - InnerProduct (pVect1, pVect2, qty_ptr);
20
+ }
21
+
17
22
#if defined(USE_AVX)
18
23
19
24
// Favor using AVX if available.
@@ -61,8 +66,13 @@ namespace hnswlib {
61
66
62
67
_mm_store_ps (TmpRes, sum_prod);
63
68
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];;
64
- return 1 .0f - sum;
65
- }
69
+ return sum;
70
+ }
71
+
72
+ static float
73
+ InnerProductDistanceSIMD4ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
74
+ return 1 .0f - InnerProductSIMD4ExtAVX (pVect1v, pVect2v, qty_ptr);
75
+ }
66
76
67
77
#endif
68
78
@@ -121,7 +131,12 @@ namespace hnswlib {
121
131
_mm_store_ps (TmpRes, sum_prod);
122
132
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
123
133
124
- return 1 .0f - sum;
134
+ return sum;
135
+ }
136
+
137
+ static float
138
+ InnerProductDistanceSIMD4ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
139
+ return 1 .0f - InnerProductSIMD4ExtSSE (pVect1v, pVect2v, qty_ptr);
125
140
}
126
141
127
142
#endif
@@ -156,7 +171,12 @@ namespace hnswlib {
156
171
_mm512_store_ps (TmpRes, sum512);
157
172
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ] + TmpRes[8 ] + TmpRes[9 ] + TmpRes[10 ] + TmpRes[11 ] + TmpRes[12 ] + TmpRes[13 ] + TmpRes[14 ] + TmpRes[15 ];
158
173
159
- return 1 .0f - sum;
174
+ return sum;
175
+ }
176
+
177
+ static float
178
+ InnerProductDistanceSIMD16ExtAVX512 (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
179
+ return 1 .0f - InnerProductSIMD16ExtAVX512 (pVect1v, pVect2v, qty_ptr);
160
180
}
161
181
162
182
#endif
@@ -196,15 +216,20 @@ namespace hnswlib {
196
216
_mm256_store_ps (TmpRes, sum256);
197
217
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
198
218
199
- return 1 .0f - sum;
219
+ return sum;
220
+ }
221
+
222
+ static float
223
+ InnerProductDistanceSIMD16ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
224
+ return 1 .0f - InnerProductSIMD16ExtAVX (pVect1v, pVect2v, qty_ptr);
200
225
}
201
226
202
227
#endif
203
228
204
229
#if defined(USE_SSE)
205
230
206
- static float
207
- InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
231
+ static float
232
+ InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
208
233
float PORTABLE_ALIGN32 TmpRes[8 ];
209
234
float *pVect1 = (float *) pVect1v;
210
235
float *pVect2 = (float *) pVect2v;
@@ -245,17 +270,24 @@ namespace hnswlib {
245
270
_mm_store_ps (TmpRes, sum_prod);
246
271
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
247
272
248
- return 1 .0f - sum;
273
+ return sum;
274
+ }
275
+
276
+ static float
277
+ InnerProductDistanceSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
278
+ return 1 .0f - InnerProductSIMD16ExtSSE (pVect1v, pVect2v, qty_ptr);
249
279
}
250
280
251
281
#endif
252
282
253
283
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254
284
DISTFUNC<float > InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
255
285
DISTFUNC<float > InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
286
+ DISTFUNC<float > InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
287
+ DISTFUNC<float > InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
256
288
257
289
static float
258
- InnerProductSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
290
+ InnerProductDistanceSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
259
291
size_t qty = *((size_t *) qty_ptr);
260
292
size_t qty16 = qty >> 4 << 4 ;
261
293
float res = InnerProductSIMD16Ext (pVect1v, pVect2v, &qty16);
@@ -264,11 +296,11 @@ namespace hnswlib {
264
296
265
297
size_t qty_left = qty - qty16;
266
298
float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
267
- return res + res_tail - 1 . 0f ;
299
+ return 1 . 0f - ( res + res_tail) ;
268
300
}
269
301
270
302
static float
271
- InnerProductSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
303
+ InnerProductDistanceSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
272
304
size_t qty = *((size_t *) qty_ptr);
273
305
size_t qty4 = qty >> 2 << 2 ;
274
306
@@ -279,7 +311,7 @@ namespace hnswlib {
279
311
float *pVect2 = (float *) pVect2v + qty4;
280
312
float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
281
313
282
- return res + res_tail - 1 . 0f ;
314
+ return 1 . 0f - ( res + res_tail) ;
283
315
}
284
316
#endif
285
317
@@ -290,30 +322,37 @@ namespace hnswlib {
290
322
size_t dim_;
291
323
public:
292
324
InnerProductSpace (size_t dim) {
293
- fstdistfunc_ = InnerProduct ;
325
+ fstdistfunc_ = InnerProductDistance ;
294
326
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
295
327
#if defined(USE_AVX512)
296
- if (AVX512Capable ())
328
+ if (AVX512Capable ()) {
297
329
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
298
- else if (AVXCapable ())
330
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
331
+ } else if (AVXCapable ()) {
299
332
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
333
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
334
+ }
300
335
#elif defined(USE_AVX)
301
- if (AVXCapable ())
336
+ if (AVXCapable ()) {
302
337
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
338
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
339
+ }
303
340
#endif
304
341
#if defined(USE_AVX)
305
- if (AVXCapable ())
342
+ if (AVXCapable ()) {
306
343
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
344
+ InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
345
+ }
307
346
#endif
308
347
309
348
if (dim % 16 == 0 )
310
- fstdistfunc_ = InnerProductSIMD16Ext ;
349
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext ;
311
350
else if (dim % 4 == 0 )
312
- fstdistfunc_ = InnerProductSIMD4Ext ;
351
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext ;
313
352
else if (dim > 16 )
314
- fstdistfunc_ = InnerProductSIMD16ExtResiduals ;
353
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals ;
315
354
else if (dim > 4 )
316
- fstdistfunc_ = InnerProductSIMD4ExtResiduals ;
355
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals ;
317
356
#endif
318
357
dim_ = dim;
319
358
data_size_ = dim * sizeof (float );
@@ -334,5 +373,4 @@ namespace hnswlib {
334
373
~InnerProductSpace () {}
335
374
};
336
375
337
-
338
376
}
0 commit comments