diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 83668e0ea..c98500815 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -115,11 +115,10 @@ template class PQFlashIndex DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); private: - DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id); + DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); std::unordered_map load_label_map(const std::string &map_file); DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels); DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels); - DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); @@ -222,12 +221,11 @@ template class PQFlashIndex // filter support uint32_t *_pts_to_label_offsets = nullptr; - uint32_t *_pts_to_labels = nullptr; - tsl::robin_set _labels; + uint32_t *_pts_to_label_counts = nullptr; + LabelT *_pts_to_labels = nullptr; std::unordered_map> _filter_to_medoid_ids; - bool _use_universal_label; - uint32_t _universal_filter_num; - std::vector _filter_list; + bool _use_universal_label = false; + LabelT _universal_filter_label; tsl::robin_set _dummy_pts; tsl::robin_set _has_dummy_pts; tsl::robin_map _dummy_to_real_map; diff --git a/src/index.cpp b/src/index.cpp index b4ebe1dda..0b10cc9a0 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1754,6 +1754,10 @@ LabelT Index::get_converted_label(const std::string &raw_label) { return _label_map[raw_label]; } + if (_use_universal_label) + { + return _universal_label; + } std::stringstream stream; stream << "Unable to find label in the Label Map"; diskann::cerr << stream.str() << std::endl; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index e76debcdf..e26df08d0 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -81,7 +81,10 @@ template PQFlashIndex::~PQFlashIndex() { delete[] _pts_to_label_offsets; } - + if (_pts_to_label_counts != nullptr) + { + delete[] _pts_to_label_counts; + } if (_pts_to_labels != nullptr) { delete[] _pts_to_labels; @@ -536,21 +539,6 @@ template void PQFlashIndex::use_medoids } } -template -inline int32_t PQFlashIndex::get_filter_number(const LabelT &filter_label) -{ - int idx = -1; - for (uint32_t i = 0; i < _filter_list.size(); i++) - { - if (_filter_list[i] == filter_label) - { - idx = i; - break; - } - } - return idx; -} - template void PQFlashIndex::generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads) @@ -559,30 +547,22 @@ void PQFlashIndex::generate_random_labels(std::vector &labels labels.clear(); labels.resize(num_labels); - uint64_t num_total_labels = - _pts_to_label_offsets[_num_points - 1] + _pts_to_labels[_pts_to_label_offsets[_num_points - 1]]; + uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, num_total_labels); - - tsl::robin_set skip_locs; - for (uint32_t i = 0; i < _num_points; i++) + if (num_total_labels == 0) { - skip_locs.insert(_pts_to_label_offsets[i]); + std::stringstream stream; + stream << "No labels found in data. Not sampling random labels "; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } + std::uniform_int_distribution dis(0, num_total_labels - 1); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) for (int64_t i = 0; i < num_labels; i++) { - bool found_flag = false; - while (!found_flag) - { - uint64_t rnd_loc = dis(gen); - if (skip_locs.find(rnd_loc) == skip_locs.end()) - { - found_flag = true; - labels[i] = _filter_list[_pts_to_labels[rnd_loc]]; - } - } + uint64_t rnd_loc = dis(gen); + labels[i] = (LabelT)_pts_to_labels[rnd_loc]; } } @@ -613,6 +593,10 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la { return _label_map[filter_label]; } + if (_use_universal_label) + { + return _universal_filter_label; + } std::stringstream stream; stream << "Unable to find label in the Label Map"; diskann::cerr << stream.str() << std::endl; @@ -646,14 +630,14 @@ void PQFlashIndex::get_label_file_metadata(std::string map_file, uint } template -inline bool PQFlashIndex::point_has_label(uint32_t point_id, uint32_t label_id) +inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT label_id) { uint32_t start_vec = _pts_to_label_offsets[point_id]; - uint32_t num_lbls = _pts_to_labels[start_vec]; + uint32_t num_lbls = _pts_to_label_counts[point_id]; bool ret_val = false; for (uint32_t i = 0; i < num_lbls; i++) { - if (_pts_to_labels[start_vec + 1 + i] == label_id) + if (_pts_to_labels[start_vec + i] == label_id) { ret_val = true; break; @@ -679,18 +663,18 @@ void PQFlashIndex::parse_label_file(const std::string &label_file, si get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; - _pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels]; - uint32_t counter = 0; + _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; + _pts_to_labels = new LabelT[num_total_labels]; + uint32_t labels_seen_so_far = 0; while (std::getline(infile, line)) { std::istringstream iss(line); std::vector lbls(0); - _pts_to_label_offsets[line_cnt] = counter; - uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter]; + _pts_to_label_offsets[line_cnt] = labels_seen_so_far; + uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; num_lbls_in_cur_pt = 0; - counter++; getline(iss, token, '\t'); std::istringstream new_iss(token); while (getline(new_iss, token, ',')) @@ -698,19 +682,8 @@ void PQFlashIndex::parse_label_file(const std::string &label_file, si token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); LabelT token_as_num = (LabelT)std::stoul(token); - if (_labels.find(token_as_num) == _labels.end()) - { - _filter_list.emplace_back(token_as_num); - } - int32_t filter_num = get_filter_number(token_as_num); - if (filter_num == -1) - { - diskann::cout << "Error!! " << std::endl; - exit(-1); - } - _pts_to_labels[counter++] = filter_num; + _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; num_lbls_in_cur_pt++; - _labels.insert(token_as_num); } if (num_lbls_in_cur_pt == 0) @@ -726,16 +699,8 @@ void PQFlashIndex::parse_label_file(const std::string &label_file, si template void PQFlashIndex::set_universal_label(const LabelT &label) { - int32_t temp_filter_num = get_filter_number(label); - if (temp_filter_num == -1) - { - diskann::cout << "Error, could not find universal label." << std::endl; - } - else - { - _use_universal_label = true; - _universal_filter_num = (uint32_t)temp_filter_num; - } + _use_universal_label = true; + _universal_filter_label = label; } #ifdef EXEC_ENV_OLS @@ -1178,22 +1143,6 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - int32_t filter_num = 0; - if (use_filter) - { - filter_num = get_filter_number(filter_label); - if (filter_num < 0) - { - if (!_use_universal_label) - { - return; - } - else - { - filter_num = _universal_filter_num; - } - } - } uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) @@ -1443,7 +1392,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num)) + if (use_filter && !(point_has_label(id, filter_label)) && + (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; float dist = dist_scratch[m]; @@ -1505,7 +1455,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num)) + if (use_filter && !(point_has_label(id, filter_label)) && + (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; float dist = dist_scratch[m];