Skip to content

Commit

Permalink
Working draft PR for cleaning up disk-based filter search (#414)
Browse files Browse the repository at this point in the history
* Made changes to clean up filter-number conversion and fixed a bug with universal filter search

* minor typecast fix

---------

Co-authored-by: rakri <[email protected]>
  • Loading branch information
rakri and rakri authored Aug 30, 2023
1 parent 353e538 commit fa6c279
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 88 deletions.
12 changes: 5 additions & 7 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,10 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id);
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels);
DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

Expand Down Expand Up @@ -222,12 +221,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

// filter support
uint32_t *_pts_to_label_offsets = nullptr;
uint32_t *_pts_to_labels = nullptr;
tsl::robin_set<LabelT> _labels;
uint32_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
std::unordered_map<LabelT, std::vector<uint32_t>> _filter_to_medoid_ids;
bool _use_universal_label;
uint32_t _universal_filter_num;
std::vector<LabelT> _filter_list;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<uint32_t> _dummy_pts;
tsl::robin_set<uint32_t> _has_dummy_pts;
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
Expand Down
4 changes: 4 additions & 0 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,10 @@ LabelT Index<T, TagT, LabelT>::get_converted_label(const std::string &raw_label)
{
return _label_map[raw_label];
}
if (_use_universal_label)
{
return _universal_label;
}
std::stringstream stream;
stream << "Unable to find label in the Label Map";
diskann::cerr << stream.str() << std::endl;
Expand Down
113 changes: 32 additions & 81 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ template <typename T, typename LabelT> PQFlashIndex<T, LabelT>::~PQFlashIndex()
{
delete[] _pts_to_label_offsets;
}

if (_pts_to_label_counts != nullptr)
{
delete[] _pts_to_label_counts;
}
if (_pts_to_labels != nullptr)
{
delete[] _pts_to_labels;
Expand Down Expand Up @@ -536,21 +539,6 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::use_medoids
}
}

// Looks up the position of `filter_label` inside `_filter_list`.
// Scans the list front to back and returns the index of the first entry that
// compares equal to `filter_label`, or -1 when no entry matches.
template <typename T, typename LabelT>
inline int32_t PQFlashIndex<T, LabelT>::get_filter_number(const LabelT &filter_label)
{
    uint32_t pos = 0;
    while (pos < _filter_list.size())
    {
        if (_filter_list[pos] == filter_label)
        {
            // First match wins; later duplicates (if any) are never reached.
            return (int32_t)pos;
        }
        pos++;
    }
    // No entry matched the requested label.
    return -1;
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads)
Expand All @@ -559,30 +547,22 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
labels.clear();
labels.resize(num_labels);

uint64_t num_total_labels =
_pts_to_label_offsets[_num_points - 1] + _pts_to_labels[_pts_to_label_offsets[_num_points - 1]];
uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1];
std::mt19937 gen(rd());
std::uniform_int_distribution<uint64_t> dis(0, num_total_labels);

tsl::robin_set<uint64_t> skip_locs;
for (uint32_t i = 0; i < _num_points; i++)
if (num_total_labels == 0)
{
skip_locs.insert(_pts_to_label_offsets[i]);
std::stringstream stream;
stream << "No labels found in data. Not sampling random labels ";
diskann::cerr << stream.str() << std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
std::uniform_int_distribution<uint64_t> dis(0, num_total_labels - 1);

#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
for (int64_t i = 0; i < num_labels; i++)
{
bool found_flag = false;
while (!found_flag)
{
uint64_t rnd_loc = dis(gen);
if (skip_locs.find(rnd_loc) == skip_locs.end())
{
found_flag = true;
labels[i] = _filter_list[_pts_to_labels[rnd_loc]];
}
}
uint64_t rnd_loc = dis(gen);
labels[i] = (LabelT)_pts_to_labels[rnd_loc];
}
}

Expand Down Expand Up @@ -613,6 +593,10 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
{
return _label_map[filter_label];
}
if (_use_universal_label)
{
return _universal_filter_label;
}
std::stringstream stream;
stream << "Unable to find label in the Label Map";
diskann::cerr << stream.str() << std::endl;
Expand Down Expand Up @@ -646,14 +630,14 @@ void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::string map_file, uint
}

template <typename T, typename LabelT>
inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, uint32_t label_id)
inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, LabelT label_id)
{
uint32_t start_vec = _pts_to_label_offsets[point_id];
uint32_t num_lbls = _pts_to_labels[start_vec];
uint32_t num_lbls = _pts_to_label_counts[point_id];
bool ret_val = false;
for (uint32_t i = 0; i < num_lbls; i++)
{
if (_pts_to_labels[start_vec + 1 + i] == label_id)
if (_pts_to_labels[start_vec + i] == label_id)
{
ret_val = true;
break;
Expand All @@ -679,38 +663,27 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si
get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels);

_pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
_pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels];
uint32_t counter = 0;
_pts_to_label_counts = new uint32_t[num_pts_in_label_file];
_pts_to_labels = new LabelT[num_total_labels];
uint32_t labels_seen_so_far = 0;

while (std::getline(infile, line))
{
std::istringstream iss(line);
std::vector<uint32_t> lbls(0);

_pts_to_label_offsets[line_cnt] = counter;
uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter];
_pts_to_label_offsets[line_cnt] = labels_seen_so_far;
uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt];
num_lbls_in_cur_pt = 0;
counter++;
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
if (_labels.find(token_as_num) == _labels.end())
{
_filter_list.emplace_back(token_as_num);
}
int32_t filter_num = get_filter_number(token_as_num);
if (filter_num == -1)
{
diskann::cout << "Error!! " << std::endl;
exit(-1);
}
_pts_to_labels[counter++] = filter_num;
_pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num;
num_lbls_in_cur_pt++;
_labels.insert(token_as_num);
}

if (num_lbls_in_cur_pt == 0)
Expand All @@ -726,16 +699,8 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si

template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label)
{
int32_t temp_filter_num = get_filter_number(label);
if (temp_filter_num == -1)
{
diskann::cout << "Error, could not find universal label." << std::endl;
}
else
{
_use_universal_label = true;
_universal_filter_num = (uint32_t)temp_filter_num;
}
_use_universal_label = true;
_universal_filter_label = label;
}

#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -1178,22 +1143,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
const uint32_t io_limit, const bool use_reorder_data,
QueryStats *stats)
{
int32_t filter_num = 0;
if (use_filter)
{
filter_num = get_filter_number(filter_label);
if (filter_num < 0)
{
if (!_use_universal_label)
{
return;
}
else
{
filter_num = _universal_filter_num;
}
}
}

uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS)
Expand Down Expand Up @@ -1443,7 +1392,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
continue;

if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num))
if (use_filter && !(point_has_label(id, filter_label)) &&
(!_use_universal_label || !point_has_label(id, _universal_filter_label)))
continue;
cmps++;
float dist = dist_scratch[m];
Expand Down Expand Up @@ -1505,7 +1455,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
continue;

if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num))
if (use_filter && !(point_has_label(id, filter_label)) &&
(!_use_universal_label || !point_has_label(id, _universal_filter_label)))
continue;
cmps++;
float dist = dist_scratch[m];
Expand Down

0 comments on commit fa6c279

Please sign in to comment.