Skip to content

Commit 2822532

Browse files
authored
libraft and pylibraft API for CAGRA build and HNSW search (#2022)
Closes #1772. Authors: Divye Gala (https://github.com/divyegala), Corey J. Nolet (https://github.com/cjnolet), Vyas Ramasubramani (https://github.com/vyasr). Approvers: Micka (https://github.com/lowener), Corey J. Nolet (https://github.com/cjnolet). URL: #2022
1 parent e272176 commit 2822532

34 files changed

+1691
-158
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
types_or: [python, cython]
2828
additional_dependencies: ["flake8-force"]
2929
- repo: https://github.com/pre-commit/mirrors-mypy
30-
rev: 'v0.971'
30+
rev: 'v1.3.0'
3131
hooks:
3232
- id: mypy
3333
additional_dependencies: [types-cachetools]

cpp/CMakeLists.txt

+9
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON)
5757
option(BUILD_TESTS "Build raft unit-tests" ON)
5858
option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF)
5959
option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF)
60+
option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON)
6061
option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
6162
option(CUDA_ENABLE_LINEINFO
6263
"Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF
@@ -195,6 +196,10 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH)
195196
rapids_cpm_gbench()
196197
endif()
197198

199+
if(BUILD_CAGRA_HNSWLIB)
200+
include(cmake/thirdparty/get_hnswlib.cmake)
201+
endif()
202+
198203
# ##################################################################################################
199204
# * raft ---------------------------------------------------------------------
200205
add_library(raft INTERFACE)
@@ -203,6 +208,9 @@ add_library(raft::raft ALIAS raft)
203208
target_include_directories(
204209
raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:include>"
205210
)
211+
if(BUILD_CAGRA_HNSWLIB)
212+
target_link_libraries(raft INTERFACE hnswlib::hnswlib)
213+
endif()
206214

207215
if(NOT BUILD_CPU_ONLY)
208216
# Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target.
@@ -425,6 +433,7 @@ if(RAFT_COMPILE_LIBRARY)
425433
src/raft_runtime/neighbors/cagra_search.cu
426434
src/raft_runtime/neighbors/cagra_serialize.cu
427435
src/raft_runtime/neighbors/eps_neighborhood.cu
436+
src/raft_runtime/neighbors/hnsw.cpp
428437
src/raft_runtime/neighbors/ivf_flat_build.cu
429438
src/raft_runtime/neighbors/ivf_flat_search.cu
430439
src/raft_runtime/neighbors/ivf_flat_serialize.cu

cpp/bench/ann/CMakeLists.txt

+3-13
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,7 @@ endfunction()
225225

226226
if(RAFT_ANN_BENCH_USE_HNSWLIB)
227227
ConfigureAnnBench(
228-
NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp
229-
LINKS
230-
hnswlib::hnswlib
228+
NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib
231229
)
232230

233231
endif()
@@ -276,12 +274,7 @@ endif()
276274

277275
if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
278276
ConfigureAnnBench(
279-
NAME
280-
RAFT_CAGRA_HNSWLIB
281-
PATH
282-
bench/ann/src/raft/raft_cagra_hnswlib.cu
283-
LINKS
284-
raft::compiled
277+
NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled
285278
hnswlib::hnswlib
286279
)
287280
endif()
@@ -336,10 +329,7 @@ endif()
336329

337330
if(RAFT_ANN_BENCH_USE_GGNN)
338331
include(cmake/thirdparty/get_glog.cmake)
339-
ConfigureAnnBench(
340-
NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu
341-
LINKS glog::glog ggnn::ggnn
342-
)
332+
ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn)
343333
endif()
344334

345335
# ##################################################################################################

cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ struct hnsw_dist_t<uint8_t> {
5252
using type = int;
5353
};
5454

55+
template <>
56+
struct hnsw_dist_t<int8_t> {
57+
using type = int;
58+
};
59+
5560
template <typename T>
5661
class HnswLib : public ANN<T> {
5762
public:
@@ -135,7 +140,7 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
135140
space_ = std::make_shared<hnswlib::L2Space>(dim_);
136141
}
137142
} else if constexpr (std::is_same_v<T, uint8_t>) {
138-
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
143+
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
139144
}
140145

141146
appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
@@ -205,7 +210,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
205210
space_ = std::make_shared<hnswlib::L2Space>(dim_);
206211
}
207212
} else if constexpr (std::is_same_v<T, uint8_t>) {
208-
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
213+
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
209214
}
210215

211216
appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(

cpp/cmake/patches/hnswlib.diff

+57
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,63 @@
105105
}
106106
}
107107
}
108+
diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
109+
index 4413537..c3240f3 100644
110+
--- a/hnswlib/space_l2.h
111+
+++ b/hnswlib/space_l2.h
112+
@@ -252,13 +252,14 @@ namespace hnswlib {
113+
~L2Space() {}
114+
};
115+
116+
+ template <typename T>
117+
static int
118+
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
119+
120+
size_t qty = *((size_t *) qty_ptr);
121+
int res = 0;
122+
- unsigned char *a = (unsigned char *) pVect1;
123+
- unsigned char *b = (unsigned char *) pVect2;
124+
+ T *a = (T *) pVect1;
125+
+ T *b = (T *) pVect2;
126+
127+
qty = qty >> 2;
128+
for (size_t i = 0; i < qty; i++) {
129+
@@ -279,11 +280,12 @@ namespace hnswlib {
130+
return (res);
131+
}
132+
133+
+ template <typename T>
134+
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
135+
size_t qty = *((size_t*)qty_ptr);
136+
int res = 0;
137+
- unsigned char* a = (unsigned char*)pVect1;
138+
- unsigned char* b = (unsigned char*)pVect2;
139+
+ T* a = (T*)pVect1;
140+
+ T* b = (T*)pVect2;
141+
142+
for(size_t i = 0; i < qty; i++)
143+
{
144+
@@ -294,6 +296,7 @@ namespace hnswlib {
145+
return (res);
146+
}
147+
148+
+ template <typename T>
149+
class L2SpaceI : public SpaceInterface<int> {
150+
151+
DISTFUNC<int> fstdistfunc_;
152+
@@ -302,10 +305,10 @@ namespace hnswlib {
153+
public:
154+
L2SpaceI(size_t dim) {
155+
if(dim % 4 == 0) {
156+
- fstdistfunc_ = L2SqrI4x;
157+
+ fstdistfunc_ = L2SqrI4x<T>;
158+
}
159+
else {
160+
- fstdistfunc_ = L2SqrI;
161+
+ fstdistfunc_ = L2SqrI<T>;
162+
}
163+
dim_ = dim;
164+
data_size_ = dim * sizeof(unsigned char);
108165
diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h
109166
index 5e1a4a5..4195ebd 100644
110167
--- a/hnswlib/visited_list_pool.h

cpp/cmake/thirdparty/get_hnswlib.cmake

+5-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function(find_and_configure_hnswlib)
3030
rapids_cpm_find(
3131
hnswlib ${PKG_VERSION}
3232
GLOBAL_TARGETS hnswlib::hnswlib
33+
BUILD_EXPORT_SET raft-exports
34+
INSTALL_EXPORT_SET raft-exports
3335
CPM_ARGS
3436
GIT_REPOSITORY ${PKG_REPOSITORY}
3537
GIT_TAG ${PKG_PINNED_TAG}
@@ -51,11 +53,13 @@ function(find_and_configure_hnswlib)
5153
# write export rules
5254
rapids_export(
5355
BUILD hnswlib
56+
VERSION ${PKG_VERSION}
5457
EXPORT_SET hnswlib-exports
5558
GLOBAL_TARGETS hnswlib
5659
NAMESPACE hnswlib::)
5760
rapids_export(
5861
INSTALL hnswlib
62+
VERSION ${PKG_VERSION}
5963
EXPORT_SET hnswlib-exports
6064
GLOBAL_TARGETS hnswlib
6165
NAMESPACE hnswlib::)
@@ -74,5 +78,5 @@ endif()
7478
find_and_configure_hnswlib(VERSION 0.6.2
7579
REPOSITORY ${RAFT_HNSWLIB_GIT_REPOSITORY}
7680
PINNED_TAG ${RAFT_HNSWLIB_GIT_TAG}
77-
EXCLUDE_FROM_ALL ON
81+
EXCLUDE_FROM_ALL OFF
7882
)

cpp/include/raft/neighbors/cagra_serialize.cuh

+20-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,13 +32,14 @@ namespace raft::neighbors::cagra {
3232
*
3333
* @code{.cpp}
3434
* #include <raft/core/resources.hpp>
35+
* #include <raft/neighbors/cagra_serialize.hpp>
3536
*
3637
* raft::resources handle;
3738
*
3839
* // create an output stream
3940
* std::ostream os(std::cout.rdbuf());
40-
* // create an index with `auto index = cagra::build(...);`
41-
* raft::serialize(handle, os, index);
41+
* // create an index with `auto index = raft::cagra::build(...);`
42+
* raft::cagra::serialize(handle, os, index);
4243
* @endcode
4344
*
4445
* @tparam T data element type
@@ -66,13 +67,14 @@ void serialize(raft::resources const& handle,
6667
*
6768
* @code{.cpp}
6869
* #include <raft/core/resources.hpp>
70+
* #include <raft/neighbors/cagra_serialize.hpp>
6971
*
7072
* raft::resources handle;
7173
*
7274
* // create a string with a filepath
7375
* std::string filename("/path/to/index");
74-
* // create an index with `auto index = cagra::build(...);`
75-
* raft::serialize(handle, filename, index);
76+
* // create an index with `auto index = raft::cagra::build(...);`
77+
* raft::cagra::serialize(handle, filename, index);
7678
* @endcode
7779
*
7880
* @tparam T data element type
@@ -100,13 +102,14 @@ void serialize(raft::resources const& handle,
100102
*
101103
* @code{.cpp}
102104
* #include <raft/core/resources.hpp>
105+
* #include <raft/neighbors/cagra_serialize.hpp>
103106
*
104107
* raft::resources handle;
105108
*
106109
* // create an output stream
107110
* std::ostream os(std::cout.rdbuf());
108-
* // create an index with `auto index = cagra::build(...);`
109-
* raft::serialize_to_hnswlib(handle, os, index);
111+
* // create an index with `auto index = raft::cagra::build(...);`
112+
* raft::cagra::serialize_to_hnswlib(handle, os, index);
110113
* @endcode
111114
*
112115
* @tparam T data element type
@@ -120,25 +123,26 @@ void serialize(raft::resources const& handle,
120123
template <typename T, typename IdxT>
121124
void serialize_to_hnswlib(raft::resources const& handle,
122125
std::ostream& os,
123-
const index<T, IdxT>& index)
126+
const raft::neighbors::cagra::index<T, IdxT>& index)
124127
{
125128
detail::serialize_to_hnswlib<T, IdxT>(handle, os, index);
126129
}
127130

128131
/**
129-
* Write the CAGRA built index as a base layer HNSW index to file
132+
 * Save a CAGRA built index in hnswlib base-layer-only serialized format
130133
*
131134
* Experimental, both the API and the serialization format are subject to change.
132135
*
133136
* @code{.cpp}
134137
* #include <raft/core/resources.hpp>
138+
* #include <raft/neighbors/cagra_serialize.hpp>
135139
*
136140
* raft::resources handle;
137141
*
138142
* // create a string with a filepath
139143
* std::string filename("/path/to/index");
140-
* // create an index with `auto index = cagra::build(...);`
141-
* raft::serialize_to_hnswlib(handle, filename, index);
144+
* // create an index with `auto index = raft::cagra::build(...);`
145+
* raft::cagra::serialize_to_hnswlib(handle, filename, index);
142146
* @endcode
143147
*
144148
* @tparam T data element type
@@ -152,7 +156,7 @@ void serialize_to_hnswlib(raft::resources const& handle,
152156
template <typename T, typename IdxT>
153157
void serialize_to_hnswlib(raft::resources const& handle,
154158
const std::string& filename,
155-
const index<T, IdxT>& index)
159+
const raft::neighbors::cagra::index<T, IdxT>& index)
156160
{
157161
detail::serialize_to_hnswlib<T, IdxT>(handle, filename, index);
158162
}
@@ -164,14 +168,15 @@ void serialize_to_hnswlib(raft::resources const& handle,
164168
*
165169
* @code{.cpp}
166170
* #include <raft/core/resources.hpp>
171+
* #include <raft/neighbors/cagra_serialize.hpp>
167172
*
168173
* raft::resources handle;
169174
*
170175
* // create an input stream
171176
* std::istream is(std::cin.rdbuf());
172177
* using T = float; // data element type
173178
* using IdxT = int; // type of the index
174-
* auto index = raft::deserialize<T, IdxT>(handle, is);
179+
* auto index = raft::cagra::deserialize<T, IdxT>(handle, is);
175180
* @endcode
176181
*
177182
* @tparam T data element type
@@ -195,14 +200,15 @@ index<T, IdxT> deserialize(raft::resources const& handle, std::istream& is)
195200
*
196201
* @code{.cpp}
197202
* #include <raft/core/resources.hpp>
203+
* #include <raft/neighbors/cagra_serialize.hpp>
198204
*
199205
* raft::resources handle;
200206
*
201207
* // create a string with a filepath
202208
* std::string filename("/path/to/index");
203209
* using T = float; // data element type
204210
* using IdxT = int; // type of the index
205-
* auto index = raft::deserialize<T, IdxT>(handle, filename);
211+
* auto index = raft::cagra::deserialize<T, IdxT>(handle, filename);
206212
* @endcode
207213
*
208214
* @tparam T data element type

0 commit comments

Comments
 (0)