Skip to content

Commit

Permalink
GetAllowLessThanKResults in context class
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Dec 22, 2023
1 parent 77b432f commit 8a3329e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
10 changes: 8 additions & 2 deletions include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ enum State : uint8_t
template <typename LabelT = uint32_t> class IndexSearchContext
{
public:
IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX)
: _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown)
IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX, bool allowLessThanKResults = false)
: _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown), _allowLessThankResults(allowLessThanKResults)
{
_use_filter = false;
_label = (LabelT)0;
Expand Down Expand Up @@ -198,6 +198,11 @@ template <typename LabelT = uint32_t> class IndexSearchContext
return _stats;
}

bool GetAllowLessThanKResults()
{
return _allowLessThankResults;
}

private:
uint32_t _time_limit_in_microseconds;
uint32_t _io_limit;
Expand All @@ -206,6 +211,7 @@ template <typename LabelT = uint32_t> class IndexSearchContext
LabelT _label;
Timer _timer;
QueryStats _stats;
bool _allowLessThankResults;
};

} // namespace diskann
12 changes: 10 additions & 2 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2251,10 +2251,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
break;
}

if (pos <= K)
if (pos <= K && context.GetAllowLessThanKResults())
{
context.SetState(State::Success);
}
else if(pos < K)
{
context.SetState(State::Failure);
}
else
{
context.SetState(State::Failure);
Expand Down Expand Up @@ -2386,10 +2390,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
if (pos == K)
break;
}
if (pos <= K)
if (pos <= K && context.GetAllowLessThanKResults())
{
context.SetState(State::Success);
}
else if(pos < K)
{
context.SetState(State::Failure);
}
else
{
context.SetState(State::Failure);
Expand Down

0 comments on commit 8a3329e

Please sign in to comment.