Skip to content

Commit 887e644

Browse files
committed
resolving some errors
1 parent 9402f01 commit 887e644

6 files changed

+133
-81
lines changed

include/abstract_filter_store.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ template <typename label_type> class AbstractFilterStore
1616
DISKANN_DLLEXPORT virtual bool detect_common_filters(uint32_t point_id, bool search_invocation,
1717
const std::vector<label_type> &incoming_labels) = 0;
1818

19-
DISKANN_DLLEXPORT virtual const std::vector<label_type> &get_labels_by_point_id(const location_t point_id) = 0;
19+
DISKANN_DLLEXPORT virtual const std::vector<label_type> &get_labels_by_point(const location_t point_id) = 0;
2020
DISKANN_DLLEXPORT virtual const tsl::robin_set<label_type> &get_all_label_set() = 0;
2121
// Throws: out of range exception
2222
DISKANN_DLLEXPORT virtual void add_label_to_point(const location_t point_id, label_type label) = 0;
@@ -31,11 +31,11 @@ template <typename label_type> class AbstractFilterStore
3131

3232
// TODO: in future we may accept a set or vector of universal labels
3333
DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0;
34-
DISKANN_DLLEXPORT virtual const label_type get_universal_label() const = 0;
34+
DISKANN_DLLEXPORT virtual void set_universal_labels(const std::vector<std::string> &universal_labels) = 0;
35+
// DISKANN_DLLEXPORT virtual const label_type get_universal_label() const = 0;
3536

3637
// takes the raw label file, then generates the internal mapping file and keeps the mapping info
37-
DISKANN_DLLEXPORT virtual size_t load_raw_labels(const std::string &raw_labels_file,
38-
const std::string &universal_label) = 0;
38+
DISKANN_DLLEXPORT virtual size_t load_raw_labels(const std::string &raw_labels_file) = 0;
3939

4040
DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0;
4141
DISKANN_DLLEXPORT virtual void save_medoids(const std::string &save_path) = 0;
@@ -47,6 +47,7 @@ template <typename label_type> class AbstractFilterStore
4747
DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0;
4848
DISKANN_DLLEXPORT virtual size_t load_medoids(const std::string &labels_to_medoid_file) = 0;
4949
DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0;
50+
DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0;
5051

5152
private:
5253
size_t _num_points;

include/in_mem_filter_store.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
1414
bool detect_common_filters(uint32_t point_id, bool search_invocation,
1515
const std::vector<label_type> &incoming_labels) override;
1616

17-
const std::vector<label_type> &get_labels_by_point_id(const location_t point_id);
17+
const std::vector<label_type> &get_labels_by_point(const location_t point_id);
1818
const tsl::robin_set<label_type> &get_all_label_set();
1919
// Throws: out of range exception
2020
void add_label_to_point(const location_t point_id, label_type label);
@@ -28,10 +28,11 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
2828

2929
// TODO: in future we may accept a set or vector of universal labels
3030
void set_universal_label(label_type universal_label);
31-
const label_type get_universal_label() const;
31+
void set_universal_labels(const std::vector<std::string> &raw_universal_labels);
32+
// const label_type get_universal_label() const;
3233

3334
// ideally takes the raw label file, then generates the internal mapping file and keeps the mapping info
34-
size_t load_raw_labels(const std::string &raw_labels_file, const std::string &universal_label);
35+
size_t load_raw_labels(const std::string &raw_labels_file);
3536

3637
void save_labels(const std::string &save_path, const size_t total_points);
3738
void save_medoids(const std::string &save_path);
@@ -43,6 +44,7 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
4344
size_t load_labels(const std::string &labels_file);
4445
size_t load_medoids(const std::string &labels_to_medoid_file);
4546
void load_label_map(const std::string &labels_map_file);
47+
void load_universal_labels(const std::string &universal_labels_file);
4648

4749
private:
4850
size_t _num_points;
@@ -53,10 +55,12 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
5355
// medoids
5456
std::unordered_map<label_type, uint32_t> _label_to_medoid_id;
5557
std::unordered_map<uint32_t, uint32_t> _medoid_counts; // medoids only happen for filtered index
58+
5659
// universal label
5760
bool _use_universal_label = false;
5861
label_type _universal_label = 0; // this is the internal mapping, may not always be true in future
5962
tsl::robin_set<label_type> _universal_labels_set;
63+
std::set<std::string> _raw_universal_labels;
6064

6165
// populates pts_to labels and _labels from given label file
6266
size_t parse_label_file(const std::string &label_file);

include/index.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
101101

102102
// Filtered Support
103103
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
104-
const std::string &universal_label, const size_t num_points_to_load,
104+
const size_t num_points_to_load,
105105
const std::vector<TagT> &tags = std::vector<TagT>());
106106

107-
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
107+
// DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
108+
DISKANN_DLLEXPORT void set_universal_labels(const std::vector<std::string> &raw_labels);
108109

109110
// Get converted integer label from string to int map (_label_map)
110111
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);

src/disk_utils.cpp

+34-27
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
615615
double sampling_rate, double ram_budget, std::string mem_index_path,
616616
std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
617617
uint32_t num_threads, bool use_filters, const std::string &label_file,
618-
const std::string &labels_to_medoids_file, const std::string &universal_label,
618+
const std::string &disk_labels_to_medoids_file, const std::string &universal_label,
619619
const uint32_t Lf)
620620
{
621621
size_t base_num, base_dim;
@@ -642,22 +642,22 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
642642
_index.build(base_file.c_str(), base_num);
643643
else
644644
{
645-
// if (universal_label != "")
646-
//{ // indicates no universal label
647-
// LabelT unv_label_as_num = 0;
648-
// _index.set_universal_label(unv_label_as_num);
649-
// }
650-
_index.build_filtered_index(base_file.c_str(), label_file, universal_label, base_num);
645+
if (universal_label != "")
646+
{ // a non-empty string means a universal label was supplied
647+
// LabelT unv_label_as_num = 0;
648+
_index.set_universal_labels({universal_label});
649+
}
650+
_index.build_filtered_index(base_file.c_str(), label_file, base_num);
651651
}
652652
_index.save(mem_index_path.c_str());
653653

654654
if (use_filters)
655655
{
656656
// need to copy the labels_to_medoids file to the specified input
657657
// file
658-
std::remove(labels_to_medoids_file.c_str());
658+
std::remove(disk_labels_to_medoids_file.c_str());
659659
std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt";
660-
copy_file(mem_labels_to_medoid_file, labels_to_medoids_file);
660+
copy_file(mem_labels_to_medoid_file, disk_labels_to_medoids_file);
661661
std::remove(mem_labels_to_medoid_file.c_str());
662662
}
663663

@@ -712,12 +712,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
712712
else
713713
{
714714
diskann::extract_shard_labels(label_file, shard_ids_file, shard_labels_file);
715-
// if (universal_label != "")
716-
//{ // indicates no universal label
717-
// LabelT unv_label_as_num = 0;
718-
// _index.set_universal_label(unv_label_as_num);
719-
// }
720-
_index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, universal_label, shard_base_pts);
715+
if (universal_label != "")
716+
{ // a non-empty string means a universal label was supplied
717+
// LabelT unv_label_as_num = 0;
718+
_index.set_universal_labels({universal_label});
719+
}
720+
_index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts);
721721
}
722722
_index.save(shard_index_file.c_str());
723723
// copy universal label file from first shard to the final destination
@@ -738,7 +738,7 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
738738
timer.reset();
739739
diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", merged_index_prefix + "_subshard-",
740740
"_ids_uint32.bin", num_parts, R, mem_index_path, medoids_file, use_filters,
741-
labels_to_medoids_file);
741+
disk_labels_to_medoids_file);
742742
diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl;
743743

744744
// delete tempFiles
@@ -1159,14 +1159,16 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11591159
std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
11601160
std::string mem_index_path = index_prefix_path + "_mem.index";
11611161
std::string disk_index_path = index_prefix_path + "_disk.index";
1162+
11621163
std::string medoids_path = disk_index_path + "_medoids.bin";
11631164
std::string centroids_path = disk_index_path + "_centroids.bin";
11641165

1165-
std::string labels_to_medoids_path = disk_index_path + "_labels_to_medoids.txt";
1166+
std::string disk_labels_to_medoids_path = disk_index_path + "_labels_to_medoids.txt";
11661167
std::string mem_labels_file = mem_index_path + "_labels.txt";
11671168
std::string disk_labels_file = disk_index_path + "_labels.txt";
11681169
std::string mem_univ_label_file = mem_index_path + "_universal_label.txt";
11691170
std::string disk_univ_label_file = disk_index_path + "_universal_label.txt";
1171+
std::string mem_labels_int_map_file = mem_index_path + "_labels_map.txt";
11701172
std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt";
11711173
std::string dummy_remap_file = disk_index_path + "_dummy_remap.txt"; // remap will be used if we break-up points of
11721174
// high label-density to create copies
@@ -1232,19 +1234,19 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
12321234
std::string augmented_data_file, augmented_labels_file;
12331235
if (use_filters)
12341236
{
1235-
convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file,
1236-
universal_label);
1237+
/* convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file,
1238+
universal_label);*/
12371239
augmented_data_file = index_prefix_path + "_augmented_data.bin";
12381240
augmented_labels_file = index_prefix_path + "_augmented_labels.txt";
12391241
if (filter_threshold != 0)
12401242
{
12411243
dummy_remap_file = index_prefix_path + "_dummy_remap.txt";
1242-
breakup_dense_points<T>(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file,
1244+
breakup_dense_points<T>(data_file_to_use, labels_file_original, filter_threshold, augmented_data_file,
12431245
augmented_labels_file,
12441246
dummy_remap_file); // RKNOTE: This has large memory footprint,
12451247
// need to make this streaming
12461248
data_file_to_use = augmented_data_file;
1247-
labels_file_to_use = augmented_labels_file;
1249+
labels_file_original = augmented_labels_file;
12481250
}
12491251
}
12501252

@@ -1287,10 +1289,10 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
12871289
#endif
12881290

12891291
timer.reset();
1290-
diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
1291-
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
1292-
build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use,
1293-
labels_to_medoids_path, universal_label, Lf);
1292+
diskann::build_merged_vamana_index<T, LabelT>(
1293+
data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path,
1294+
centroids_path, build_pq_bytes, use_opq, num_threads, use_filters, labels_file_original,
1295+
disk_labels_to_medoids_path, universal_label, Lf);
12941296
diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl;
12951297

12961298
timer.reset();
@@ -1315,16 +1317,21 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13151317
gen_random_slice<T>(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate);
13161318
if (use_filters)
13171319
{
1318-
copy_file(labels_file_to_use, disk_labels_file);
1320+
// copy labels file
1321+
copy_file(mem_labels_file, disk_labels_file);
13191322
std::remove(mem_labels_file.c_str());
1323+
// copy universal label
13201324
if (universal_label != "")
13211325
{
13221326
copy_file(mem_univ_label_file, disk_univ_label_file);
13231327
std::remove(mem_univ_label_file.c_str());
13241328
}
1329+
// copy map file
1330+
copy_file(mem_labels_int_map_file, disk_labels_int_map_file);
1331+
std::remove(mem_labels_int_map_file.c_str());
1332+
13251333
std::remove(augmented_data_file.c_str());
13261334
std::remove(augmented_labels_file.c_str());
1327-
std::remove(labels_file_to_use.c_str());
13281335
}
13291336

13301337
std::remove(mem_index_path.c_str());

0 commit comments

Comments
 (0)