From 8a3329e4291d9873c892de07f97246aff0e12a82 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Sat, 23 Dec 2023 00:27:17 +0530 Subject: [PATCH] GetAllowLessThanKResults in context class --- include/parameters.h | 10 ++++++++-- src/index.cpp | 12 ++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/include/parameters.h b/include/parameters.h index edde5df9c..8e5c99455 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -139,8 +139,8 @@ enum State : uint8_t template class IndexSearchContext { public: - IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX) - : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown) + IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX, bool allowLessThanKResults = false) + : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown), _allowLessThankResults(allowLessThanKResults) { _use_filter = false; _label = (LabelT)0; @@ -198,6 +198,11 @@ template class IndexSearchContext return _stats; } + bool GetAllowLessThanKResults() + { + return _allowLessThankResults; + } + private: uint32_t _time_limit_in_microseconds; uint32_t _io_limit; @@ -206,6 +211,7 @@ template class IndexSearchContext LabelT _label; Timer _timer; QueryStats _stats; + bool _allowLessThankResults; }; } // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 9655a074e..544d2f31b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2251,10 +2251,14 @@ std::pair Index::search(const T *query, con break; } - if (pos <= K) + if (pos <= K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); } + else if(pos < K) + { + context.SetState(State::Failure); + } else { context.SetState(State::Failure); @@ -2386,10 +2390,14 @@ std::pair Index::search_with_filters(const if (pos == K) break; } - if (pos <= K) + if (pos <= K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); } + else if(pos < K) + { + context.SetState(State::Failure); + } else { context.SetState(State::Failure);