Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting multi-filter OR search support from the DLVS branch #546

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 182 additions & 13 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common_includes.h"
#include <boost/program_options.hpp>

#include "utils.h"
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
Expand All @@ -28,9 +29,126 @@
#endif

#define WARMUP false
#define DISKANN_DEBUG_INDIVIDUAL_RESULTS

namespace po = boost::program_options;

#ifdef DISKANN_DEBUG_PRINT_RETSET
void dump_retset(uint64_t test_id, uint64_t query_num, diskann::QueryStats *stats, const std::string &result_output_prefix)
{
std::stringstream ss;
if (stats != nullptr)
{
for (int i = 0; i < query_num; i++)
{
ss << i << "\t";
for (int j = 0; j < (stats + i)->query_retset.size(); j++)
{
ss << "(" << (stats + i)->query_retset[j].id << ", " << (stats + i)->query_retset[j].distance
<< "), ";
}
ss << std::endl;
}

}
std::string results_file = result_output_prefix + "_L" + std::to_string(test_id) + "_retset.tsv";
std::ofstream writer(results_file);
writer << ss.str() << std::endl;
writer.close();
}
#endif

#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
void dump_individual_results(uint64_t test_id, uint64_t query_num, uint32_t *gt_ids, float *gt_dists, uint64_t gt_dim,
const std::vector<uint32_t> &query_result_ids,
const std::vector<float> &query_result_dists, uint64_t recall_at,
const std::string &result_output_prefix)
{
uint32_t cumulative_dist_matches = 0;
uint32_t cumulative_id_matches = 0;
std::stringstream results_stream;
std::stringstream per_query_stats_stream;

per_query_stats_stream << "query_id\tid_matches\tdist_matches\ttotal_matches\trecall" << std::endl;
for (int qid = 0; qid < query_num; qid++)
{
results_stream << qid << "\t";
uint32_t per_query_dist_matches = 0;
uint32_t per_query_id_matches = 0;

for (uint64_t i = 0; i < recall_at; i++)
{
auto rindex = qid * recall_at + i;
results_stream << "(" << query_result_ids[rindex] << "," << query_result_dists[rindex] << ",";

bool id_match = false;
bool dist_match = false;
for (uint64_t j = 0; j < recall_at; j++)
{
auto gindex = qid * gt_dim + j;
if (query_result_ids[rindex] == gt_ids[gindex])
{
per_query_id_matches++;
id_match = true;
break;
}
else if (query_result_dists[rindex] / gt_dists[gindex] <= 1.0f)
{
per_query_dist_matches++;
dist_match = true;
break;
}
}
std::string code = "X";
if (id_match)
{
code = "I";
}
else if (dist_match)
{
code = "D";
}
results_stream << code << "),";
}

results_stream << std::endl;

cumulative_id_matches += per_query_id_matches;
cumulative_dist_matches += per_query_dist_matches;
per_query_stats_stream << qid << "\t" << per_query_id_matches << "\t" << per_query_dist_matches << "\t"
<< per_query_id_matches + per_query_dist_matches << "\t"
<< (per_query_id_matches + per_query_dist_matches) * 1.0f / recall_at << std::endl;
}
{

std::string results_file = result_output_prefix + "_L" + std::to_string(test_id) + "_results.tsv";
std::ofstream out(results_file);
out << results_stream.str() << std::endl;
}
{
std::string per_query_stats_file = result_output_prefix + "_L" + std::to_string(test_id) + "_query_stats.tsv";
std::ofstream out(per_query_stats_file);
out << per_query_stats_stream.str() << std::endl;
}
}

void write_gt_to_tsv(const std::string &cur_result_path, uint64_t query_num, uint32_t *gt_ids, float *gt_dists,
uint64_t gt_dim)
{
std::ofstream gt_out(cur_result_path + "_gt.tsv");
for (int i = 0; i < query_num; i++)
{
gt_out << i << "\t";
for (int j = 0; j < gt_dim; j++)
{
gt_out << "(" << gt_ids[i * gt_dim + j] << "," << gt_dists[i * gt_dim + j] << "),";
}
gt_out << std::endl;
}
}
#endif


void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
Expand All @@ -47,6 +165,44 @@ void print_stats(std::string category, std::vector<float> percentiles, std::vect
diskann::cout << std::endl;
}

template<typename T, typename LabelT>
void parse_labels_of_query(const std::string &filters_for_query,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<LabelT> &label_ids_for_query)
{
std::vector<std::string> label_strs_for_query;
diskann::split_string(filters_for_query, FILTER_OR_SEPARATOR, label_strs_for_query);
for (auto &label_str_for_query : label_strs_for_query)
{
label_ids_for_query.push_back(pFlashIndex->get_converted_label(label_str_for_query));
}
}

template<typename T, typename LabelT>
void populate_label_ids(const std::vector<std::string> &filters_of_queries,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<std::vector<LabelT>> &label_ids_of_queries, bool apply_one_to_all, uint32_t query_count)
{
if (apply_one_to_all)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_queries[0], pFlashIndex, label_ids_of_query);
for (uint32_t i = 0; i < query_count; i++)
{
label_ids_of_queries.push_back(label_ids_of_query);
}
}
else
{
for (auto &filters_of_query : filters_of_queries)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_query, pFlashIndex, label_ids_of_query);
label_ids_of_queries.push_back(label_ids_of_query);
}
}
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
Expand Down Expand Up @@ -173,6 +329,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
diskann::cout << "..done" << std::endl;
}

std::vector<std::vector<LabelT>> per_query_label_ids;
if (filtered_search)
{
populate_label_ids(query_filters, _pFlashIndex, per_query_label_ids, (query_filters.size() == 1), query_num );
}



diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);

Expand Down Expand Up @@ -236,19 +400,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
}
else
{
LabelT label_for_search;
if (query_filters.size() == 1)
{ // one label for all queries
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
}
else
{ // one label for each query
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
use_reorder_data, stats + i);
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, per_query_label_ids[i],
search_io_limit, use_reorder_data, stats + i);
gopalrs marked this conversation as resolved.
Show resolved Hide resolved
}
}
auto e = std::chrono::high_resolution_clock::now();
Expand All @@ -270,25 +425,40 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.cpu_us; });

auto mean_hops = diskann::get_mean_stats<uint32_t>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_hops; });
gopalrs marked this conversation as resolved.
Show resolved Hide resolved

double recall = 0;
if (calc_recall_flag)
{
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, recall_at);
best_recall = std::max(recall, best_recall);
}
#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
dump_individual_results(test_id, query_num, gt_ids, gt_dists, gt_dim, query_result_ids[test_id],
query_result_dists[test_id], recall_at, result_output_prefix);
#endif
#ifdef DISKANN_DEBUG_PRINT_RETSET
dump_retset(test_id, query_num, stats, result_output_prefix);
#endif

diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << std::endl;
diskann::cout << std::setw(16) << recall << std::endl ;
}
else
{
diskann::cout << std::endl;
}
delete[] stats;
}
#ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS
write_gt_to_tsv(result_output_prefix, query_num, gt_ids, gt_dists, gt_dim);
#endif

diskann::cout << "Done searching. Now saving results " << std::endl;
uint64_t test_id = 0;
Expand Down Expand Up @@ -443,7 +613,6 @@ int main(int argc, char **argv)
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}

try
{
if (!query_filters.empty() && label_type == "ushort")
Expand Down
4 changes: 4 additions & 0 deletions include/percentile_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ struct QueryStats
unsigned n_cmps = 0; // # cmps
unsigned n_cache_hits = 0; // # cache_hits
unsigned n_hops = 0; // # search hops

#ifdef DISKANN_DEBUG_PRINT_RETSET
std::vector<Neighbor> query_retset; //copy of the retset to debug PQ distances.
#endif
};

template <typename T>
Expand Down
36 changes: 27 additions & 9 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

#pragma once
#include <unordered_map>
#include "common_includes.h"

#include "aligned_file_reader.h"
Expand All @@ -17,6 +18,11 @@
#include "tsl/robin_set.h"

#define FULL_PRECISION_REORDER_MULTIPLIER 3
#define DEFAULT_VISITED_RESERVE_SIZE 4096
//default max filters per query is set to the same
//as what we expect Bing to provide. If this is overkill,
//it can be set by clients in the load() function
#define DEFAULT_MAX_FILTERS_PER_QUERY 4096

namespace diskann
{
Expand All @@ -29,19 +35,28 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT ~PQFlashIndex();

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix);
DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix,
uint32_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);
#else
// load compressed data, and obtains the handle to the disk-resident index
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix,
uint32_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);
#endif

DISKANN_DLLEXPORT void load_labels(const std::string &disk_index_filepath);
DISKANN_DLLEXPORT void load_label_medoid_map(const std::string &labels_to_medoids_filepath,
std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string &dummy_map_filepath, std::istream &dummy_map_stream);

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
const char *compressed_filepath);
const char *compressed_filepath,
uint32_t max_filters_per_query);
#else
DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
const char *pivots_filepath, const char *compressed_filepath);
const char *pivots_filepath, const char *compressed_filepath,
uint32_t max_filters_per_query);
#endif

DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
Expand Down Expand Up @@ -77,7 +92,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_filter, const std::vector<LabelT> &filter_labels,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr);

Expand Down Expand Up @@ -110,13 +125,16 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = DEFAULT_VISITED_RESERVE_SIZE,
uint64_t max_filters_per_query = DEFAULT_MAX_FILTERS_PER_QUERY);

DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector<LabelT> &label_ids);
void load_label_map(std::basic_istream<char> &map_reader,
std::unordered_map<std::string, LabelT> &string_to_int_map);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
Expand Down Expand Up @@ -181,7 +199,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// chunk_size = chunk size of each dimension chunk
// pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks]
uint8_t *data = nullptr;
uint64_t _n_chunks;
uint64_t _n_chunks = 0;
FixedChunkPQTable _pq_table;

// distance comparator
Expand All @@ -199,7 +217,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// we can optionally have multiple starting points
uint32_t *_medoids = nullptr;
// defaults to 1
size_t _num_medoids;
size_t _num_medoids = 1;
// by default, it is empty. If there are multiple
// centroids, we pick the medoid corresponding to the
// closest centroid as the starting point of search
Expand Down
4 changes: 2 additions & 2 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ template <typename T> class SSDQueryScratch : public AbstractScratch<T>
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;

SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_filters_per_query);
~SSDQueryScratch();

void reset();
Expand All @@ -162,7 +162,7 @@ template <typename T> class SSDThreadData
SSDQueryScratch<T> scratch;
IOContext ctx;

SSDThreadData(size_t aligned_dim, size_t visited_reserve);
SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_filters_per_query);
void clear();
};

Expand Down
Loading
Loading