Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Support bitset filter for Brute Force #560

Merged
merged 21 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1ba31da
[Feat] Support `bitset` filter for Brute Force
rhdong Jan 8, 2025
4e30bd2
TBR
rhdong Dec 9, 2024
cbc5d38
fix the bits type to uint32_t
rhdong Jan 8, 2025
3a5d4e0
skip half test cases when cuSparse version < 12.0.1
rhdong Jan 8, 2025
8a45192
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 10, 2025
e79b1e3
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 15, 2025
8c0031a
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 15, 2025
4c53846
Revert "customized raft"
rhdong Jan 15, 2025
4a53e94
Merge remote-tracking branch 'rhdong/rhdong/bf-bitset' into rhdong/bf…
rhdong Jan 15, 2025
85d2dfc
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 16, 2025
5ef5bc5
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 21, 2025
f53d1ce
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 22, 2025
9beb58f
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 22, 2025
36bae13
optimization by comments
rhdong Jan 23, 2025
1fcc7de
Merge branch 'branch-25.02' into rhdong/bf-bitset
cjnolet Jan 24, 2025
b58f2a5
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 27, 2025
6c7b583
add usage example for bitset_filter
rhdong Jan 27, 2025
7c4d50e
Merge remote-tracking branch 'origin/branch-25.02' into rhdong/bf-bitset
rhdong Jan 29, 2025
3ecccfb
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 30, 2025
6cc5059
Merge branch 'branch-25.02' into rhdong/bf-bitset
rhdong Jan 30, 2025
4243fb4
reorder comments and remove filtering namespace
rhdong Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 140 additions & 25 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,28 @@ auto build(raft::resources const& handle,
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`:
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* ...
* // Use the same allocator across multiple searches to reduce the number of
* // cuda memory allocations
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
* ...
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
Expand All @@ -350,9 +363,17 @@ auto build(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter An optional device bitmap filter function with a `row-major` layout and
* the shape of [n_queries, index->size()], which means the filter will use the first
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -379,15 +400,28 @@ void search(raft::resources const& handle,
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`:
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* ...
* // Use the same allocator across multiple searches to reduce the number of
* // cuda memory allocations
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
* ...
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
Expand All @@ -397,8 +431,17 @@ void search(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
* given
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -421,15 +464,51 @@ void search(raft::resources const& handle,
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -452,15 +531,51 @@ void search(raft::resources const& handle,
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand Down
34 changes: 29 additions & 5 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cstdint>
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;

namespace filtering {

enum class FilterType { None, Bitmap, Bitset };

struct base_filter {
virtual ~base_filter() = default;
virtual ~base_filter() = default;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice no changes have been made to brute_force.hpp. Ideally, we'll at at least be listing out in the docs which filters are supported, right? Otherwise this is going to be very confusing for users. Also, can we set the default to bitset filter? I suspect most users will want bitset.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve just added the comments. I believe using bitset as the default setting might not be ideal if we don't have enough input from end-users. Perhaps we should discuss this in the team group, as I noticed that the none filter is also set as the default in CAGRA.

Copy link
Member

@cjnolet cjnolet Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe using bitset as the default setting might not be ideal if we don't have enough input from end-users.

I think you may have misunderstood me. The none filter is fine as the default for the the search functions, but for the code example in the docs, we should make sure we use a bitset and leave bitmap to users who need it. FAISS doesn't even support a bitmap and users aren't asking for it generally. It's good to keep it exposed for users who might need it.

virtual FilterType get_filter_type() const = 0;
};

/* A filter that filters nothing. This is the default behavior. */
Expand All @@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter {
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::None; }
};

/**
Expand Down Expand Up @@ -513,15 +519,24 @@ struct ivf_to_sample_filter {
*/
template <typename bitmap_t, typename index_t>
struct bitmap_filter : public base_filter {
using view_t = cuvs::core::bitmap_view<bitmap_t, index_t>;

// View of the bitset to use as a filter
const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_view_;
const view_t bitmap_view_;

bitmap_filter(const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering);
bitmap_filter(const view_t bitmap_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitmap; }

view_t view() const { return bitmap_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand All @@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter {
*/
template <typename bitset_t, typename index_t>
struct bitset_filter : public base_filter {
using view_t = cuvs::core::bitset_view<bitset_t, index_t>;

// View of the bitset to use as a filter
const cuvs::core::bitset_view<bitset_t, index_t> bitset_view_;
const view_t bitset_view_;

bitset_filter(const cuvs::core::bitset_view<bitset_t, index_t> bitset_for_filtering);
bitset_filter(const view_t bitset_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitset; }

view_t view() const { return bitset_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void _search(cuvsResources_t res,
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to keep the filter immutable, don' we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is to be compatible with the bitset filter utilized in CAGRA and other algorithms. The immutable is implemented on the search API by restricting the filter to be const cuvs::neighbors::filtering::base_filter. Please refer to cagra, here

Copy link
Member

@cjnolet cjnolet Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using obj<const type> everywhere in the codebase for containers to immutable types so we should be doing that here as well. If the CAGRA API is not doing this, it should be.

using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>;
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>;

auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
Expand All @@ -85,14 +85,14 @@ void _search(cuvsResources_t res,
distances_mds,
cuvs::neighbors::filtering::none_sample_filter{});
} else if (prefilter.type == BITMAP) {
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter(
prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(),
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
const auto prefilter = cuvs::neighbors::filtering::bitmap_filter(
prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(),
queries_mds.extent(0),
index_ptr->dataset().extent(0)));
cuvs::neighbors::brute_force::search(
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view);
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter);
} else {
RAFT_FAIL("Unsupported prefilter type: BITSET");
}
Expand Down
Loading