From bc39cd06c0a1d668aaae9b97199d724438d617bf Mon Sep 17 00:00:00 2001 From: Harsha Vardhan Simhadri Date: Wed, 11 Aug 2021 16:36:44 -0700 Subject: [PATCH] Revised inner product (#10) * working towards inner product in memory indices * done with in-memory code * made the inner product distance function return std::float_max if negative * more changes for disk index support * on the way to disk index support for MIPS * works now, need to change the PQ generation for MIPS * now incorporated disk+memory search for inner product * support for mips and l2 * changed inner product to -IP rather than 1/IP * towards adding support for storing PQ vectors in disk index for very large data * towards adding support for storing PQ vectors in disk index for very large data * halfway through PQ-based disk search option * code compiles for disk index pq * fixed some bug * shards are written as and when necessary * sharding is now on demand * minor changes * fixed one malloc bug in parameters * added a vector analyzer util * added missing file * fixed a bug which used L2 instead of inner product in cached beam search * now setting up the normalizing approach * towards pre-processing data * working towards newer inner product * more changes to do MIPS by reducing to L2 with extra coordinate * cleaned up code a bit, need to test everything again * testing underway * added back saturate graph to create denser indices * now we dont sample a new test dataset every iteration for estimating sharding * now num_parts increases by 2 * cleaned up warnings in Debug mode compiler * working towards inner product in memory indices * done with in-memory code * made the inner product distance function return std::float_max if negative * more changes for disk index support * on the way to disk index support for MIPS * works now, need to change the PQ generation for MIPS * now incorporated disk+memory search for inner product * support for mips and l2 * changed inner product to -IP rather than 1/IP * towards adding support for storing PQ vectors in disk index for very large data * towards adding support for storing PQ vectors in disk index for very large data * halfway through PQ-based disk search option * code compiles for disk index pq * fixed some bug * shards are written as and when necessary * sharding is now on demand * minor changes * fixed one malloc bug in parameters * added a vector analyzer util * added missing file * fixed a bug which used L2 instead of inner product in cached beam search * now setting up the normalizing approach * towards pre-processing data * working towards newer inner product * more changes to do MIPS by reducing to L2 with extra coordinate * cleaned up code a bit, need to test everything again * testing underway * added back saturate graph to create denser indices * now we dont sample a new test dataset every iteration for estimating sharding * now num_parts increases by 2 * cleaned up warnings in Debug mode compiler * added a normalizer to vector analysis * fixed one bug for MIPS * addressed all comments of PR * fixed minor typos. now running unit tests * ran clang-format as it doesnt run by default due to LINUX flag not set anywhere * clang introduced a bug in distance.h, fixed itt * added unit tester partially * minor bugfix * finished unit tester * changed back training size to 100K for now, we can increase to 1M later if necessary * added comments for unit_tester.sh * added auto tuning parameters for unit tester * re-ran clang formatting * small change to unit tester * fixed minor bug in unit tester * fixed some formatting on unit tester * started code for range search support in pq_flash_index * added more code for range search in disk index * added range search support * tested range search on small dataset * Update memory_mapper.h * minor edits Co-authored-by: ravishankar --- include/aligned_file_reader.h | 4 +- include/aux_utils.h | 7 +- include/distance.h | 40 ++- include/exceptions.h | 2 +- include/index.h | 5 +- include/memory_mapper.h | 2 +- include/parameters.h | 3 + include/partition_and_pq.h | 28 +- include/percentile_stats.h | 2 +- include/pq_flash_index.h | 43 ++- include/pq_table.h | 91 +++++- include/timer.h | 2 +- include/utils.h | 224 ++++++++++++++- include/windows_aligned_file_reader.h | 5 +- src/aux_utils.cpp | 310 ++++++++++++++------- src/index.cpp | 47 +++- src/linux_aligned_file_reader.cpp | 6 +- src/partition_and_pq.cpp | 292 +++++++++++++++---- src/pq_flash_index.cpp | 259 ++++++++++++----- src/windows_aligned_file_reader.cpp | 2 +- tests/CMakeLists.txt | 10 + tests/build_disk_index.cpp | 33 ++- tests/build_memory_index.cpp | 42 ++- tests/range_search_disk_index.cpp | 324 ++++++++++++++++++++++ tests/search_disk_index.cpp | 60 ++-- tests/search_memory_index.cpp | 57 ++-- tests/test_incremental_index.cpp | 2 +- tests/utils/CMakeLists.txt | 22 +- tests/utils/bin_to_tsv.cpp | 8 +- tests/utils/compute_groundtruth.cpp | 76 +++-- tests/utils/create_disk_layout.cpp | 22 +- tests/utils/float_bin_to_int8.cpp | 1 - tests/utils/gen_random_slice.cpp | 23 +- tests/utils/partition_data.cpp | 7 +- tests/utils/partition_with_ram_budget.cpp | 7 +- tests/utils/tsv_to_bin.cpp | 7 +- tests/utils/uint8_to_float.cpp | 21 ++ tests/utils/vector_analysis.cpp | 148 ++++++++++ unit_tester.sh | 71 +++++ 39 files changed, 1905 insertions(+), 410 deletions(-) create mode 100644 tests/range_search_disk_index.cpp create mode 100644 tests/utils/uint8_to_float.cpp create mode 100644 tests/utils/vector_analysis.cpp create mode 100755 unit_tester.sh diff --git a/include/aligned_file_reader.h b/include/aligned_file_reader.h index c33bacd66..b6b56a012 100644 --- a/include/aligned_file_reader.h +++ b/include/aligned_file_reader.h @@ -18,7 +18,7 @@ typedef io_context_t IOContext; #include #ifndef USE_BING_INFRA -struct IOContext{ +struct IOContext { HANDLE fhandle = NULL; HANDLE iocp = NULL; std::vector reqs; @@ -77,7 +77,7 @@ struct AlignedRead { class AlignedFileReader { protected: tsl::robin_map ctx_map; - std::mutex ctx_mut; + std::mutex ctx_mut; public: // returns the thread-specific context diff --git a/include/aux_utils.h b/include/aux_utils.h index 031a64ba8..f980b7fc5 100644 --- a/include/aux_utils.h +++ b/include/aux_utils.h @@ -29,13 +29,15 @@ typedef int FileHandle; #include "common_includes.h" #include "utils.h" #include "windows_customizations.h" +#include "gperftools/malloc_extension.h" namespace diskann { - const size_t TRAINING_SET_SIZE = 1500000; + const size_t TRAINING_SET_SIZE = 100000; const double SPACE_FOR_CACHED_NODES_IN_GB = 0.25; const double THRESHOLD_FOR_CACHING_IN_GB = 1.0; const uint32_t NUM_NODES_TO_CACHE = 250000; const uint32_t WARMUP_L = 20; + const uint32_t NUM_KMEANS_REPS = 12; template class PQFlashIndex; @@ -44,6 +46,9 @@ namespace diskann { unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, unsigned *our_results, unsigned dim_or, unsigned recall_at); +DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries, std::vector> &groundtruth, + std::vector> &our_results); + DISKANN_DLLEXPORT void read_idmap(const std::string & fname, std::vector &ivecs); diff --git a/include/distance.h b/include/distance.h index 3d403d1e5..086c4286d 100644 --- a/include/distance.h +++ b/include/distance.h @@ -255,11 +255,16 @@ namespace diskann { virtual float compare(const int8_t *a, const int8_t *b, unsigned int length) const { #ifndef _WINDOWS - std::cout << "AVX only supported in Windows build."; - return 0; + int32_t result = 0; +#pragma omp simd reduction(+ : result) aligned(a, b : 8) + for (_s32 i = 0; i < (_s32) length; i++) { + result += ((int32_t)((int16_t) a[i] - (int16_t) b[i])) * + ((int32_t)((int16_t) a[i] - (int16_t) b[i])); + } + return (float) result; } #else - __m128 r = _mm_setzero_ps(); + __m128 r = _mm_setzero_ps(); __m128i r1; while (length >= 16) { r1 = _mm_subs_epi8(_mm_load_si128((__m128i *) a), @@ -273,7 +278,7 @@ namespace diskann { float res = r.m128_f32[0]; if (length >= 8) { - __m128 r2 = _mm_setzero_ps(); + __m128 r2 = _mm_setzero_ps(); __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 8)), _mm_load_si128((__m128i *) (b - 8))); r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3)); @@ -285,7 +290,7 @@ namespace diskann { } if (length >= 4) { - __m128 r2 = _mm_setzero_ps(); + __m128 r2 = _mm_setzero_ps(); __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *) (a - 12)), _mm_load_si128((__m128i *) (b - 12))); r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3)); @@ -302,8 +307,12 @@ namespace diskann { virtual float compare(const float *a, const float *b, unsigned int length) const { #ifndef _WINDOWS - std::cout << "AVX only supported in Windows build."; - return 0; + float result = 0; +#pragma omp simd reduction(+ : result) aligned(a, b : 8) + for (_s32 i = 0; i < (_s32) length; i++) { + result += (a[i] - b[i]) * (a[i] - b[i]); + } + return result; } #else __m128 diff, v1, v2; @@ -328,7 +337,7 @@ namespace diskann { template class DistanceInnerProduct : public Distance { public: - float compare(const T *a, const T *b, unsigned size) const { + float inner_product(const T *a, const T *b, unsigned size) const { float result = 0; #ifdef __GNUC__ #ifdef __AVX__ @@ -426,10 +435,21 @@ namespace diskann { #endif return result; } + float compare(const T *a, const T *b, unsigned size) + const { // since we use normally minimization objective for distance + // comparisons, we are returning 1/x. + float result = inner_product(a, b, size); + // if (result < 0) + // return std::numeric_limits::max(); + // else + return -result; + } }; template - class DistanceFastL2 : public DistanceInnerProduct { + class DistanceFastL2 + : public DistanceInnerProduct { // currently defined only for float. + // templated for future use. public: float norm(const T *a, unsigned size) const { float result = 0; @@ -522,7 +542,7 @@ namespace diskann { using DistanceInnerProduct::compare; float compare(const T *a, const T *b, float norm, unsigned size) const { // not implement - float result = -2 * DistanceInnerProduct::compare(a, b, size); + float result = -2 * DistanceInnerProduct::inner_product(a, b, size); result += norm; return result; } diff --git a/include/exceptions.h b/include/exceptions.h index 0323ac3dc..eefb0f69c 100644 --- a/include/exceptions.h +++ b/include/exceptions.h @@ -12,4 +12,4 @@ namespace diskann { : std::logic_error("Function not yet implemented.") { } }; -} +} // namespace diskann diff --git a/include/index.h b/include/index.h index eedbb1491..fdcac5faf 100644 --- a/include/index.h +++ b/include/index.h @@ -52,7 +52,7 @@ namespace diskann { // Gopal. Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" - DISKANN_DLLEXPORT std::pair search(const T *query, + DISKANN_DLLEXPORT std::pair search(const T * query, const size_t K, const unsigned L, unsigned *indices); @@ -63,7 +63,7 @@ namespace diskann { DISKANN_DLLEXPORT std::pair search_with_tags( const T *query, const size_t K, const unsigned L, TagT *tags, - unsigned frozen_pts, unsigned *indices_buffer = NULL); + unsigned *indices_buffer = NULL); // repositions frozen points to the end of _data - if they have been moved // during deletion @@ -167,6 +167,7 @@ namespace diskann { size_t consolidate_deletes(const Parameters ¶meters); private: + Metric _metric = diskann::L2; size_t _dim; size_t _aligned_dim; T * _data; diff --git a/include/memory_mapper.h b/include/memory_mapper.h index 4ebe6ec62..4ccbf6f28 100644 --- a/include/memory_mapper.h +++ b/include/memory_mapper.h @@ -38,4 +38,4 @@ namespace diskann { ~MemoryMapper(); }; -} \ No newline at end of file +} // namespace diskann diff --git a/include/parameters.h b/include/parameters.h index 42e7d3a46..1cff66285 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -19,6 +19,9 @@ namespace diskann { template inline void Set(const std::string &name, const ParamType &value) { // ParamType *ptr = (ParamType *) malloc(sizeof(ParamType)); + if (params.find(name) != params.end()) { + free(params[name]); + } ParamType *ptr = new ParamType; *ptr = value; params[name] = (void *) ptr; diff --git a/include/partition_and_pq.h b/include/partition_and_pq.h index 45b1e26e6..7273dcf62 100644 --- a/include/partition_and_pq.h +++ b/include/partition_and_pq.h @@ -27,10 +27,9 @@ template void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data, size_t &slice_size); -template -int estimate_cluster_sizes(const std::string data_file, float *pivots, - const size_t num_centers, const size_t dim, - const size_t k_base, +int estimate_cluster_sizes(float *test_data_float, size_t num_test, + float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, std::vector &cluster_sizes); template @@ -38,6 +37,17 @@ int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim, const size_t k_base, std::string prefix_path); +template +int shard_data_into_clusters_only_ids(const std::string data_file, + float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, + std::string prefix_path); + +template +int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename); + template int partition(const std::string data_file, const float sampling_rate, size_t num_centers, size_t max_k_means_reps, @@ -49,12 +59,10 @@ int partition_with_ram_budget(const std::string data_file, size_t graph_degree, const std::string prefix_path, size_t k_base); -DISKANN_DLLEXPORT int generate_pq_pivots(const float *train_data, - size_t num_train, unsigned dim, - unsigned num_centers, - unsigned num_pq_chunks, - unsigned max_k_means_reps, - std::string pq_pivots_path); +DISKANN_DLLEXPORT int generate_pq_pivots( + const float *train_data, size_t num_train, unsigned dim, + unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps, + std::string pq_pivots_path, bool make_zero_mean = false); template int generate_pq_data_from_pivots(const std::string data_file, diff --git a/include/percentile_stats.h b/include/percentile_stats.h index 808546c16..6a7b7cec7 100644 --- a/include/percentile_stats.h +++ b/include/percentile_stats.h @@ -58,4 +58,4 @@ namespace diskann { } return avg / len; } -} +} // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index c7601851c..d119b8197 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -70,7 +70,8 @@ namespace diskann { // Freeing the reader object is now the client's (DiskANNInterface's) // responsibility. DISKANN_DLLEXPORT PQFlashIndex( - std::shared_ptr &fileReader); + std::shared_ptr &fileReader, + diskann::Metric metric = diskann::Metric::L2); DISKANN_DLLEXPORT ~PQFlashIndex(); #ifdef EXEC_ENV_OLS @@ -79,8 +80,8 @@ namespace diskann { const char *disk_index_file); #else // load compressed data, and obtains the handle to the disk-resident index - DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *pq_prefix, - const char *disk_index_file); + DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *pq_prefix, + const char *disk_index_file); #endif DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); @@ -112,10 +113,15 @@ namespace diskann { // implemented DISKANN_DLLEXPORT void cached_beam_search( const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, - float *res_dists, const _u64 beam_width, QueryStats *stats = nullptr, - Distance *output_dist_func = nullptr); - std::shared_ptr &reader; + float *res_dists, const _u64 beam_width, QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range, + const _u64 l_search, _u64* indices, float* distances, + const _u64 beam_width, + QueryStats *stats = nullptr); + + std::shared_ptr &reader; protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads); @@ -129,12 +135,19 @@ namespace diskann { // nbrs of node `i`: ((unsigned*)buf) + 1 _u64 max_node_len = 0, nnodes_per_sector = 0, max_degree = 0; + diskann::Metric metric = diskann::Metric::L2; + float max_base_norm = + 0; // used only for inner product search to re-scale the result value + // (due to the pre-processing of base during index build) // data info _u64 num_points = 0; _u64 data_dim = 0; + _u64 disk_data_dim = 0; // will be different from data_dim only if we use + // PQ for disk data (very large dimensionality) _u64 aligned_dim = 0; + _u64 disk_bytes_per_point = 0; - std::string disk_index_file; + std::string disk_index_file; std::vector> node_visit_counter; // PQ data @@ -142,15 +155,19 @@ namespace diskann { // data: _u8 * n_chunks // chunk_size = chunk size of each dimension chunk // pq_tables = float* [[2^8 * [chunk_size]] * n_chunks] - _u8 * data = nullptr; - _u64 chunk_size; - _u64 n_chunks; - FixedChunkPQTable pq_table; + _u8 * data = nullptr; + _u64 n_chunks; + FixedChunkPQTable pq_table; // distance comparator Distance * dist_cmp = nullptr; Distance *dist_cmp_float = nullptr; + // for very large datasets: we use PQ even for the disk resident index + bool use_disk_index_pq = false; + _u64 disk_pq_n_chunks; + FixedChunkPQTable disk_pq_table; + // medoid/start info uint32_t *medoids = nullptr; // by default it is just one entry point of graph, we @@ -162,11 +179,11 @@ namespace diskann { // closest centroid as the starting point of search // nhood_cache - unsigned *nhood_cache_buf = nullptr; + unsigned * nhood_cache_buf = nullptr; tsl::robin_map<_u32, std::pair<_u32, _u32 *>> nhood_cache; // coord_cache - T *coord_cache_buf = nullptr; + T * coord_cache_buf = nullptr; tsl::robin_map<_u32, T *> coord_cache; // thread-specific scratch diff --git a/include/pq_table.h b/include/pq_table.h index 3cac23c15..84fe1501a 100644 --- a/include/pq_table.h +++ b/include/pq_table.h @@ -6,15 +6,14 @@ #include "utils.h" namespace diskann { - template class FixedChunkPQTable { // data_dim = n_chunks * chunk_size; float* tables = nullptr; // pq_tables = float* [[2^8 * [chunk_size]] * n_chunks] // _u64 n_chunks; // n_chunks = # of chunks ndims is split into // _u64 chunk_size; // chunk_size = chunk size of each dimension chunk - _u64 ndims; // ndims = chunk_size * n_chunks - _u64 n_chunks; + _u64 ndims = 0; // ndims = chunk_size * n_chunks + _u64 n_chunks = 0; _u32* chunk_offsets = nullptr; _u32* rearrangement = nullptr; float* centroid = nullptr; @@ -79,14 +78,15 @@ namespace diskann { #else diskann::load_bin<_u32>(chunk_offset_file, chunk_offsets, numr, numc); #endif - if (numc != 1 || numr != num_chunks + 1) { + if (numc != 1 || (numr != num_chunks + 1 && num_chunks != 0)) { diskann::cerr << "Error loading chunk offsets file. numc: " << numc << " (should be 1). numr: " << numr << " (should be " << num_chunks + 1 << ")" << std::endl; throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__); } - + std::cout << "PQ data has " << numr - 1 << " bytes per point." + << std::endl; this->n_chunks = numr - 1; #ifdef EXEC_ENV_OLS @@ -126,8 +126,11 @@ namespace diskann { } } - void - populate_chunk_distances(const T* query_vec, float* dist_vec) { + _u32 + get_num_chunks() { + return n_chunks; + } + void populate_chunk_distances(const float* query_vec, float* dist_vec) { memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); // chunk wise distance computation for (_u64 chunk = 0; chunk < n_chunks; chunk++) { @@ -137,11 +140,6 @@ namespace diskann { _u64 permuted_dim_in_query = rearrangement[j]; const float* centers_dim_vec = tables_T + (256 * j); for (_u64 idx = 0; idx < 256; idx++) { - // Gopal. Fixing crash in v14 machines. - // float diff = centers_dim_vec[idx] - - // ((float) query_vec[permuted_dim_in_query] - - // centroid[permuted_dim_in_query]); - // chunk_dists[idx] += (diff * diff); double diff = centers_dim_vec[idx] - (query_vec[permuted_dim_in_query] - centroid[permuted_dim_in_query]); @@ -150,5 +148,72 @@ namespace diskann { } } } -}; + + float l2_distance(const float* query_vec, _u8* base_vec) { + float res = 0; + for (_u64 chunk = 0; chunk < n_chunks; chunk++) { + for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { + _u64 permuted_dim_in_query = rearrangement[j]; + const float* centers_dim_vec = tables_T + (256 * j); + float diff = centers_dim_vec[base_vec[chunk]] - + (query_vec[permuted_dim_in_query] - + centroid[permuted_dim_in_query]); + res += diff * diff; + } + } + return res; + } + + float inner_product(const float* query_vec, _u8* base_vec) { + float res = 0; + for (_u64 chunk = 0; chunk < n_chunks; chunk++) { + for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { + _u64 permuted_dim_in_query = rearrangement[j]; + const float* centers_dim_vec = tables_T + (256 * j); + float diff = + centers_dim_vec[base_vec[chunk]] * + query_vec[permuted_dim_in_query]; // assumes centroid is 0 to + // prevent translation errors + res += diff; + } + } + return -res; // returns negative value to simulate distances (max -> min + // conversion) + } + + void inflate_vector(_u8* base_vec, float* out_vec) { + for (_u64 chunk = 0; chunk < n_chunks; chunk++) { + for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { + _u64 original_dim = rearrangement[j]; + const float* centers_dim_vec = tables_T + (256 * j); + out_vec[original_dim] = + centers_dim_vec[base_vec[chunk]] + centroid[original_dim]; + } + } + } + + void populate_chunk_inner_products(const float* query_vec, float* dist_vec) { + memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); + // chunk wise distance computation + for (_u64 chunk = 0; chunk < n_chunks; chunk++) { + // sum (q-c)^2 for the dimensions associated with this chunk + float* chunk_dists = dist_vec + (256 * chunk); + for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { + _u64 permuted_dim_in_query = rearrangement[j]; + const float* centers_dim_vec = tables_T + (256 * j); + for (_u64 idx = 0; idx < 256; idx++) { + double prod = + centers_dim_vec[idx] * + query_vec[permuted_dim_in_query]; // assumes that we are not + // shifting the vectors to mean + // zero, i.e., centroid array + // should be all zeros + chunk_dists[idx] -= + (float) prod; // returning negative to keep the search code clean + // (max inner product vs min distance) + } + } + } + } +}; // namespace diskann } // namespace diskann diff --git a/include/timer.h b/include/timer.h index 4671c33be..bf52ed883 100644 --- a/include/timer.h +++ b/include/timer.h @@ -22,4 +22,4 @@ namespace diskann { .count(); } }; -} +} // namespace diskann diff --git a/include/utils.h b/include/utils.h index 6b9db5bf6..9bf2e14f3 100644 --- a/include/utils.h +++ b/include/utils.h @@ -218,6 +218,12 @@ namespace diskann { } #endif + inline void wait_for_keystroke() { + int a; + std::cout << "Press any number to continue.." << std::endl; + std::cin >> a; + } + template inline void load_bin(const std::string& bin_file, T*& data, size_t& npts, size_t& dim) { @@ -289,6 +295,136 @@ namespace diskann { } } + inline void prune_truthset_for_range(const std::string& bin_file, float range, std::vector> &groundtruth, + size_t& npts) { + _u64 read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." + << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char*) &npts_i32, sizeof(int)); + reader.read((char*) &dim_i32, sizeof(int)); + npts = (unsigned) npts_i32; + _u64 dim = (unsigned) dim_i32; + _u32* ids; + float* dists; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." + << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = + 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + if (truthset_type == -1) { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size + << ", expected: " << expected_file_size_with_dists; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char*) ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) { + dists = new float[npts * dim]; + reader.read((char*) dists, npts * dim * sizeof(float)); + } + float min_dist = std::numeric_limits::max(); + float max_dist = 0; + groundtruth.resize(npts); + for (_u32 i = 0; i < npts; i++) { + groundtruth[i].clear(); + for (_u32 j = 0; j < dim; j++) { + if (dists[i*dim + j] <= range) { + groundtruth[i].emplace_back(ids[i*dim+j]); + } + min_dist = min_dist > dists[i*dim+j] ? dists[i*dim + j] : min_dist; + max_dist = max_dist < dists[i*dim+j] ? dists[i*dim + j] : max_dist; + } + //std::cout<> &groundtruth, _u64 & gt_num) { + _u64 read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." + << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_u32, total_u32; + reader.read((char*) &npts_u32, sizeof(int)); + reader.read((char*) &total_u32, sizeof(int)); + + gt_num = (_u64) npts_u32; + _u64 total_res = (_u64) total_u32; + + diskann::cout << "Metadata: #pts = " << gt_num << ", #total_results = " << total_res << "..." + << std::endl; + + size_t expected_file_size = + 2*sizeof(_u32) + gt_num*sizeof(_u32) + total_res*sizeof(_u32); + + if (actual_file_size != expected_file_size) { + std::stringstream stream; + stream << "Error. File size mismatch in range truthset. actual size: " + << actual_file_size + << ", expected: " << expected_file_size; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + groundtruth.clear(); + groundtruth.resize(gt_num); + std::vector<_u32> gt_count(gt_num); + + + + reader.read((char*) gt_count.data(), sizeof(_u32)*gt_num); + + std::vector<_u32> gt_stats(gt_count); + std::sort(gt_stats.begin(), gt_stats.end()); + + std::cout<<"GT count percentiles:" << std::endl; + for (_u32 p = 0; p < 100; p += 5) + std::cout << "percentile " << p << ": " + << gt_stats[std::floor((p / 100.0) * gt_num)] << std::endl; + std::cout << "percentile 100" + << ": " << gt_stats[gt_num - 1] << std::endl; + + + for (_u32 i = 0; i < gt_num; i++) { + groundtruth[i].clear(); + groundtruth[i].resize(gt_count[i]); + if (gt_count[i]!=0) + reader.read((char*) groundtruth[i].data(), sizeof(_u32)*gt_count[i]); + +// debugging code +/* if (i < 10) { + std::cout< inline void load_bin(MemoryMappedFiles& files, const std::string& bin_file, @@ -410,6 +546,89 @@ namespace diskann { } } + // this function will take in_file of n*d dimensions and save the output as a + // floating point matrix + // with n*(d+1) dimensions. All vectors are scaled by a large value M so that + // the norms are <=1 and the final coordinate is set so that the resulting + // norm (in d+1 coordinates) is equal to 1 this is a classical transformation + // from MIPS to L2 search from "On Symmetric and Asymmetric LSHs for Inner + // Product Search" by Neyshabur and Srebro + + template + float prepare_base_for_inner_products(const std::string in_file, + const std::string out_file) { + std::cout << "Pre-processing base file by adding extra coordinate" + << std::endl; + std::ifstream in_reader(in_file.c_str(), std::ios::binary); + std::ofstream out_writer(out_file.c_str(), std::ios::binary); + _u64 npts, in_dims, out_dims; + float max_norm = 0; + + _u32 npts32, dims32; + in_reader.read((char*) &npts32, sizeof(uint32_t)); + in_reader.read((char*) &dims32, sizeof(uint32_t)); + + npts = npts32; + in_dims = dims32; + out_dims = in_dims + 1; + _u32 outdims32 = (_u32) out_dims; + + out_writer.write((char*) &npts32, sizeof(uint32_t)); + out_writer.write((char*) &outdims32, sizeof(uint32_t)); + + size_t BLOCK_SIZE = 100000; + size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE; + std::unique_ptr in_block_data = + std::make_unique(block_size * in_dims); + std::unique_ptr out_block_data = + std::make_unique(block_size * out_dims); + + std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims); + _u64 num_blocks = DIV_ROUND_UP(npts, block_size); + + std::vector norms(npts, 0); + + for (_u64 b = 0; b < num_blocks; b++) { + _u64 start_id = b * block_size; + _u64 end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; + _u64 block_pts = end_id - start_id; + in_reader.read((char*) in_block_data.get(), + block_pts * in_dims * sizeof(T)); + for (_u64 p = 0; p < block_pts; p++) { + for (_u64 j = 0; j < in_dims; j++) { + norms[start_id + p] += + in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j]; + } + max_norm = + max_norm > norms[start_id + p] ? max_norm : norms[start_id + p]; + } + } + + max_norm = std::sqrt(max_norm); + + in_reader.seekg(2 * sizeof(_u32), std::ios::beg); + for (_u64 b = 0; b < num_blocks; b++) { + _u64 start_id = b * block_size; + _u64 end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; + _u64 block_pts = end_id - start_id; + in_reader.read((char*) in_block_data.get(), + block_pts * in_dims * sizeof(T)); + for (_u64 p = 0; p < block_pts; p++) { + for (_u64 j = 0; j < in_dims; j++) { + out_block_data[p * out_dims + j] = + in_block_data[p * in_dims + j] / max_norm; + } + float res = 1 - (norms[start_id + p] / (max_norm * max_norm)); + res = res <= 0 ? 0 : std::sqrt(res); + out_block_data[p * out_dims + out_dims - 1] = res; + } + out_writer.write((char*) out_block_data.get(), + block_pts * out_dims * sizeof(float)); + } + out_writer.close(); + return max_norm; + } + // plain saves data as npts X ndims array into filename template void save_Tvecs(const char* filename, T* data, size_t npts, size_t ndims) { @@ -495,8 +714,9 @@ inline bool validate_file_size(const std::string& name) { size_t expected_file_size; in.read((char*) &expected_file_size, sizeof(uint64_t)); if (actual_file_size != expected_file_size) { - diskann::cout << "Error loading" << name << ". Expected " - "size (metadata): " + diskann::cout << "Error loading" << name + << ". Expected " + "size (metadata): " << expected_file_size << ", actual file size : " << actual_file_size << ". Exitting." << std::endl; diff --git a/include/windows_aligned_file_reader.h b/include/windows_aligned_file_reader.h index 8fec3d4f0..433d3c0bf 100644 --- a/include/windows_aligned_file_reader.h +++ b/include/windows_aligned_file_reader.h @@ -31,7 +31,7 @@ class WindowsAlignedFileReader : public AlignedFileReader { // Open & close ops // Blocking calls DISKANN_DLLEXPORT virtual void open(const std::string &fname); - DISKANN_DLLEXPORT virtual void close(); + DISKANN_DLLEXPORT virtual void close(); DISKANN_DLLEXPORT virtual void register_thread(); DISKANN_DLLEXPORT virtual void deregister_thread() { @@ -41,8 +41,7 @@ class WindowsAlignedFileReader : public AlignedFileReader { // process batch of aligned requests in parallel // NOTE :: blocking call for the calling thread, but can thread-safe DISKANN_DLLEXPORT virtual void read(std::vector &read_reqs, - IOContext &ctx, - bool async); + IOContext &ctx, bool async); }; #endif // USE_BING_INFRA #endif //_WINDOWS diff --git a/src/aux_utils.cpp b/src/aux_utils.cpp index 6a2990ba6..f143b88ea 100644 --- a/src/aux_utils.cpp +++ b/src/aux_utils.cpp @@ -56,7 +56,9 @@ namespace diskann { } gt.insert(gt_vec, gt_vec + tie_breaker); - res.insert(res_vec, res_vec + recall_at); + res.insert(res_vec, + res_vec + recall_at); // change to recall_at for recall k@k or + // dim_or for k@dim_or unsigned cur_recall = 0; for (auto &v : gt) { if (res.find(v) != res.end()) { @@ -68,6 +70,31 @@ namespace diskann { return total_recall / (num_queries) * (100.0 / recall_at); } + double calculate_range_search_recall(unsigned num_queries, std::vector> &groundtruth, + std::vector> &our_results) { + double total_recall = 0; + std::set gt, res; + + for (size_t i = 0; i < num_queries; i++) { + gt.clear(); + res.clear(); + + gt.insert(groundtruth[i].begin(), groundtruth[i].end()); + res.insert(our_results[i].begin(), our_results[i].end()); + unsigned cur_recall = 0; + for (auto &v : gt) { + if (res.find(v) != res.end()) { + cur_recall++; + } + } + if (gt.size() != 0) + total_recall += ((100.0*cur_recall)/gt.size()); + else + total_recall += 100; + } + return total_recall / (num_queries); + } + template T *generateRandomWarmup(uint64_t warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim) { @@ -152,9 +179,8 @@ namespace diskann { std::ifstream reader(fname.c_str(), std::ios::binary); reader.read((char *) &npts32, sizeof(uint32_t)); reader.read((char *) &dim, sizeof(uint32_t)); - if (dim != 1 || - actual_file_size != - ((size_t) npts32) * sizeof(uint32_t) + 2 * sizeof(uint32_t)) { + if (dim != 1 || actual_file_size != ((size_t) npts32) * sizeof(uint32_t) + + 2 * sizeof(uint32_t)) { std::stringstream stream; stream << "Error reading idmap file. Check if the file is bin file with " "1 dimensional data. Actual: " @@ -209,11 +235,11 @@ namespace diskann { node_shard.push_back(std::make_pair((_u32) node_id, (_u32) shard)); } } - std::sort(node_shard.begin(), node_shard.end(), [](const auto &left, - const auto &right) { - return left.first < right.first || - (left.first == right.first && left.second < right.second); - }); + std::sort(node_shard.begin(), node_shard.end(), + [](const auto &left, const auto &right) { + return left.first < right.first || (left.first == right.first && + left.second < right.second); + }); diskann::cout << "Finished computing node -> shards map" << std::endl; // create cached vamana readers @@ -343,7 +369,7 @@ namespace diskann { template int build_merged_vamana_index(std::string base_file, - diskann::Metric _compareMetric, unsigned L, + diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, @@ -354,20 +380,21 @@ namespace diskann { double full_index_ram = ESTIMATE_RAM_USAGE(base_num, base_dim, sizeof(T), R); if (full_index_ram < ram_budget * 1024 * 1024 * 1024) { - diskann::cout << "Full index fits in RAM, building in one shot" - << std::endl; + diskann::cout << "Full index fits in RAM budget, should consume at most " + << full_index_ram / (1024 * 1024 * 1024) + << "GiBs, so building in one shot" << std::endl; diskann::Parameters paras; paras.Set("L", (unsigned) L); paras.Set("R", (unsigned) R); paras.Set("C", 750); - paras.Set("alpha", 2.0f); + paras.Set("alpha", 1.2f); paras.Set("num_rnds", 2); paras.Set("saturate_graph", 1); paras.Set("save_path", mem_index_path); std::unique_ptr> _pvamanaIndex = std::unique_ptr>( - new diskann::Index(_compareMetric, base_file.c_str())); + new diskann::Index(compareMetric, base_file.c_str())); _pvamanaIndex->build(paras); _pvamanaIndex->save(mem_index_path.c_str()); std::remove(medoids_file.c_str()); @@ -385,6 +412,13 @@ namespace diskann { for (int p = 0; p < num_parts; p++) { std::string shard_base_file = merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; + + std::string shard_ids_file = merged_index_prefix + "_subshard-" + + std::to_string(p) + "_ids_uint32.bin"; + + retrieve_shard_data_from_ids(base_file, shard_ids_file, + shard_base_file); + std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; @@ -392,16 +426,18 @@ namespace diskann { paras.Set("L", L); paras.Set("R", (2 * (R / 3))); paras.Set("C", 750); - paras.Set("alpha", 2.0f); + paras.Set("alpha", 1.2f); paras.Set("num_rnds", 2); paras.Set("saturate_graph", 1); paras.Set("save_path", shard_index_file); std::unique_ptr> _pvamanaIndex = std::unique_ptr>( - new diskann::Index(_compareMetric, shard_base_file.c_str())); + new diskann::Index(compareMetric, shard_base_file.c_str())); _pvamanaIndex->build(paras); _pvamanaIndex->save(shard_index_file.c_str()); + std::remove(shard_base_file.c_str()); + // wait_for_keystroke(); } diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", @@ -429,60 +465,55 @@ namespace diskann { // optimizes the beamwidth to maximize QPS for a given L_search subject to // 99.9 latency not blowing up - template - uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, T - *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, - uint32_t nthreads, uint32_t start_bw) { - uint32_t cur_bw = start_bw; - double max_qps = 0; - uint32_t best_bw = start_bw; - bool stop_flag = false; - - while (!stop_flag) { - std::vector tuning_sample_result_ids_64(tuning_sample_num, - 0); - std::vector tuning_sample_result_dists(tuning_sample_num, - 0); - diskann::QueryStats * stats = new - diskann::QueryStats[tuning_sample_num]; - - auto s = std::chrono::high_resolution_clock::now(); - #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (_s64 i = 0; i < (int64_t) tuning_sample_num; i++) { - pFlashIndex->cached_beam_search( - tuning_sample + (i * tuning_sample_aligned_dim), 1, L, - tuning_sample_result_ids_64.data() + (i * 1), - tuning_sample_result_dists.data() + (i * 1), cur_bw, stats + - i); - } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - double qps = (1.0f * tuning_sample_num) / (1.0f * diff.count()); - - double lat_999 = diskann::get_percentile_stats( - stats, tuning_sample_num, 0.999, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - double mean_latency = diskann::get_mean_stats( - stats, tuning_sample_num, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - if (qps > max_qps && lat_999 < (15000) + mean_latency * 2) { - max_qps = qps; - best_bw = cur_bw; - cur_bw = (uint32_t)(std::ceil)((float) cur_bw * 1.1); - } else { - stop_flag = true; - } - if (cur_bw > 64) - stop_flag = true; - - delete[] stats; + template + uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, T *tuning_sample, + _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, + uint32_t nthreads, uint32_t start_bw) { + uint32_t cur_bw = start_bw; + double max_qps = 0; + uint32_t best_bw = start_bw; + bool stop_flag = false; + + while (!stop_flag) { + std::vector tuning_sample_result_ids_64(tuning_sample_num, 0); + std::vector tuning_sample_result_dists(tuning_sample_num, 0); + diskann::QueryStats * stats = new diskann::QueryStats[tuning_sample_num]; + + auto s = std::chrono::high_resolution_clock::now(); +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) + for (_s64 i = 0; i < (int64_t) tuning_sample_num; i++) { + pFlashIndex->cached_beam_search( + tuning_sample + (i * tuning_sample_aligned_dim), 1, L, + tuning_sample_result_ids_64.data() + (i * 1), + tuning_sample_result_dists.data() + (i * 1), cur_bw, stats + i); } - return best_bw; + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + double qps = (1.0f * tuning_sample_num) / (1.0f * diff.count()); + + double lat_999 = diskann::get_percentile_stats( + stats, tuning_sample_num, 0.999, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + double mean_latency = diskann::get_mean_stats( + stats, tuning_sample_num, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + if (qps > max_qps && lat_999 < (15000) + mean_latency * 2) { + max_qps = qps; + best_bw = cur_bw; + cur_bw = (uint32_t)(std::ceil)((float) cur_bw * 1.1); + } else { + stop_flag = true; + } + if (cur_bw > 64) + stop_flag = true; + + delete[] stats; } + return best_bw; + } template void create_disk_layout(const std::string base_file, @@ -539,18 +570,15 @@ namespace diskann { << std::endl; // SECTOR_LEN buffer for each sector - std::unique_ptr sector_buf = - std::make_unique(SECTOR_LEN); - std::unique_ptr node_buf = - std::make_unique(max_node_len); + std::unique_ptr sector_buf = std::make_unique(SECTOR_LEN); + std::unique_ptr node_buf = std::make_unique(max_node_len); unsigned &nnbrs = *(unsigned *) (node_buf.get() + ndims_64 * sizeof(T)); unsigned *nhood_buf = (unsigned *) (node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(unsigned)); // number of sectors (1 for meta data) - _u64 n_sectors = ROUND_UP(npts_64, nnodes_per_sector) / - nnodes_per_sector; + _u64 n_sectors = ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector; _u64 disk_index_file_size = (n_sectors + 1) * SECTOR_LEN; // write first sector with metadata *(_u64 *) (sector_buf.get() + 0 * sizeof(_u64)) = disk_index_file_size; @@ -584,8 +612,7 @@ namespace diskann { // write coords of node first // T *node_coords = data + ((_u64) ndims_64 * cur_node_id); - base_reader.read((char *) cur_node_coords.get(), sizeof(T) * - ndims_64); + base_reader.read((char *) cur_node_coords.get(), sizeof(T) * ndims_64); memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T)); // write nnbrs @@ -612,7 +639,7 @@ namespace diskann { template bool build_disk_index(const char *dataFilePath, const char *indexFilePath, const char * indexBuildParameters, - diskann::Metric _compareMetric) { + diskann::Metric compareMetric) { std::stringstream parser; parser << std::string(indexBuildParameters); std::string cur_param; @@ -620,17 +647,42 @@ namespace diskann { while (parser >> cur_param) param_list.push_back(cur_param); - if (param_list.size() != 5) { + if (param_list.size() != 5 && param_list.size() != 6) { diskann::cout << "Correct usage of parameters is R (max degree) " "L (indexing list size, better if >= R) B (RAM limit of final " "index in " "GB) M (memory limit while indexing) T (number of threads for " - "indexing)" + "indexing) B' (PQ bytes for disk index: optional parameter for " + "very large dimensional data)" << std::endl; return false; } + if (!std::is_same::value && + compareMetric == diskann::Metric::INNER_PRODUCT) { + std::stringstream stream; + stream << "DiskANN currently only supports floating point data for Max " + "Inner Product Search. " + << std::endl; + throw diskann::ANNException(stream.str(), -1); + } + + _u32 disk_pq_dims = 0; + bool use_disk_pq = false; + + // if there is a 6th parameter, it means we compress the disk index vectors + // also using PQ data (for very large dimensionality data). If the provided + // parameter is 0, it means we store full vectors. + if (param_list.size() == 6) { + disk_pq_dims = atoi(param_list[5].c_str()); + use_disk_pq = true; + if (disk_pq_dims == 0) + use_disk_pq = false; + } + + std::string base_file(dataFilePath); + std::string data_file_to_use = base_file; std::string index_prefix_path(indexFilePath); std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; std::string pq_compressed_vectors_path = @@ -640,6 +692,30 @@ namespace diskann { std::string medoids_path = disk_index_path + "_medoids.bin"; std::string centroids_path = disk_index_path + "_centroids.bin"; std::string sample_base_prefix = index_prefix_path + "_sample"; + std::string disk_pq_pivots_path = + index_prefix_path + + "_disk.index_pq_pivots.bin"; // optional if disk index is also storing + // pq data + std::string disk_pq_compressed_vectors_path = // optional if disk index is + // also storing pq data + index_prefix_path + "_disk.index_pq_compressed.bin"; + + // output a new base file which contains extra dimension with sqrt(1 - + // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on + // disk needed! + if (compareMetric == diskann::Metric::INNER_PRODUCT) { + std::cout << "Using Inner Product search, so need to pre-process base " + "data into temp file. Please ensure there is additional " + "(n*(d+1)*4) bytes for storing pre-processed base vectors, " + "apart from the intermin indices and final index." + << std::endl; + std::string prepped_base = index_prefix_path + "_prepped_base.bin"; + data_file_to_use = prepped_base; + float max_norm_of_base = + diskann::prepare_base_for_inner_products(base_file, prepped_base); + std::string norm_file = disk_index_path + "_max_base_norm.bin"; + diskann::save_bin(norm_file, &max_norm_of_base, 1, 1); + } unsigned R = (unsigned) atoi(param_list[0].c_str()); unsigned L = (unsigned) atoi(param_list[1].c_str()); @@ -673,7 +749,7 @@ namespace diskann { size_t points_num, dim; - diskann::get_bin_metadata(dataFilePath, points_num, dim); + diskann::get_bin_metadata(data_file_to_use.c_str(), points_num, dim); size_t num_pq_chunks = (size_t)(std::floor)(_u64(final_index_ram_limit / points_num)); @@ -692,38 +768,68 @@ namespace diskann { double p_val = ((double) TRAINING_SET_SIZE / (double) points_num); // generates random sample and sets it to train_data and updates // train_size - gen_random_slice(dataFilePath, p_val, train_data, train_size, - train_dim); + gen_random_slice(data_file_to_use.c_str(), p_val, train_data, train_size, + train_dim); + + if (use_disk_pq) { + if (disk_pq_dims > dim) + disk_pq_dims = dim; + + std::cout << "Compressing base for disk-PQ into " << disk_pq_dims + << " chunks " << std::endl; + generate_pq_pivots(train_data, train_size, (uint32_t) dim, 256, + (uint32_t) disk_pq_dims, NUM_KMEANS_REPS, + disk_pq_pivots_path, false); + if (compareMetric == diskann::Metric::INNER_PRODUCT) + generate_pq_data_from_pivots( + data_file_to_use.c_str(), 256, (uint32_t) disk_pq_dims, + disk_pq_pivots_path, disk_pq_compressed_vectors_path); + else + generate_pq_data_from_pivots( + data_file_to_use.c_str(), 256, (uint32_t) disk_pq_dims, + disk_pq_pivots_path, disk_pq_compressed_vectors_path); + } + diskann::cout << "Training data loaded of size " << train_size << std::endl; - diskann::cout << "Training data loaded of size " << train_size << - std::endl; + // don't translate data to make zero mean for PQ compression. We must not + // translate for inner product search. + bool make_zero_mean = true; + if (compareMetric == diskann::Metric::INNER_PRODUCT) + make_zero_mean = false; generate_pq_pivots(train_data, train_size, (uint32_t) dim, 256, - (uint32_t) num_pq_chunks, 15, pq_pivots_path); - generate_pq_data_from_pivots(dataFilePath, 256, (uint32_t) - num_pq_chunks, - pq_pivots_path, + (uint32_t) num_pq_chunks, NUM_KMEANS_REPS, + pq_pivots_path, make_zero_mean); + + generate_pq_data_from_pivots(data_file_to_use.c_str(), 256, + (uint32_t) num_pq_chunks, pq_pivots_path, pq_compressed_vectors_path); delete[] train_data; train_data = nullptr; + MallocExtension::instance()->ReleaseFreeMemory(); diskann::build_merged_vamana_index( - dataFilePath, _compareMetric, L, R, p_val, indexing_ram_budget, - mem_index_path, medoids_path, centroids_path); + data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, + indexing_ram_budget, mem_index_path, medoids_path, centroids_path); - diskann::create_disk_layout(dataFilePath, mem_index_path, - disk_index_path); + if (!use_disk_pq) { + diskann::create_disk_layout(data_file_to_use.c_str(), mem_index_path, + disk_index_path); + } else + diskann::create_disk_layout<_u8>(disk_pq_compressed_vectors_path, + mem_index_path, disk_index_path); double sample_sampling_rate = (150000.0 / points_num); - gen_random_slice(dataFilePath, sample_base_prefix, - sample_sampling_rate); + gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, + sample_sampling_rate); std::remove(mem_index_path.c_str()); + if (use_disk_pq) + std::remove(disk_pq_compressed_vectors_path.c_str()); - auto e = - std::chrono::high_resolution_clock::now(); + auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; diskann::cout << "Indexing time: " << diff.count() << std::endl; @@ -781,26 +887,26 @@ namespace diskann { template DISKANN_DLLEXPORT bool build_disk_index( const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric _compareMetric); + const char *indexBuildParameters, diskann::Metric compareMetric); template DISKANN_DLLEXPORT bool build_disk_index( const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric _compareMetric); + const char *indexBuildParameters, diskann::Metric compareMetric); template DISKANN_DLLEXPORT bool build_disk_index( const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric _compareMetric); + const char *indexBuildParameters, diskann::Metric compareMetric); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric _compareMetric, unsigned L, + std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric _compareMetric, unsigned L, + std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric _compareMetric, unsigned L, + std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file); diff --git a/src/index.cpp b/src/index.cpp index 240587866..38b58da96 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -61,6 +61,9 @@ namespace { std::cout << "Older CPU. Using slow distance computation" << std::endl; return new diskann::SlowDistanceL2Float(); } + } else if (m == diskann::Metric::INNER_PRODUCT) { + std::cout << "Using Inner Product computation" << std::endl; + return new diskann::DistanceInnerProduct(); } else { std::stringstream stream; stream << "Only L2 metric supported as of now. Email " @@ -129,8 +132,8 @@ namespace diskann { const size_t nd, const size_t num_frozen_pts, const bool enable_tags, const bool store_data, const bool support_eager_delete) - : _num_frozen_pts(num_frozen_pts), _has_built(false), _width(0), - _can_delete(false), _eager_done(true), _lazy_done(true), + : _metric(m), _num_frozen_pts(num_frozen_pts), _has_built(false), + _width(0), _can_delete(false), _eager_done(true), _lazy_done(true), _compacted_order(true), _enable_tags(enable_tags), _consolidated_order(true), _support_eager_delete(support_eager_delete), _store_data(store_data) { @@ -179,7 +182,6 @@ namespace diskann { this->_distance = ::get_distance_function(m); _locks = std::vector(_max_points + _num_frozen_pts); - _width = 0; } @@ -372,7 +374,7 @@ namespace diskann { center[j] /= _nd; // compute all to one distance - float * distances = new float[_nd](); + float *distances = new float[_nd](); #pragma omp parallel for schedule(static, 65536) for (_s64 i = 0; i < (_s64) _nd; i++) { // extract point and distance reference @@ -535,7 +537,10 @@ namespace diskann { float cur_alpha = 1; while (cur_alpha <= alpha && result.size() < degree) { unsigned start = 0; - + float eps = + cur_alpha + + 0.01; // used for MIPS, where we store a value of eps in cur_alpha to + // denote pruned out entries which we can skip in later rounds. while (result.size() < degree && (start) < pool.size() && start < maxc) { auto &p = pool[start]; if (occlude_factor[start] > cur_alpha) { @@ -550,8 +555,20 @@ namespace diskann { float djk = _distance->compare( _data + _aligned_dim * (size_t) pool[t].id, _data + _aligned_dim * (size_t) p.id, (unsigned) _aligned_dim); - occlude_factor[t] = - (std::max)(occlude_factor[t], pool[t].distance / djk); + if (_metric == diskann::Metric::L2) { + occlude_factor[t] = + (std::max)(occlude_factor[t], pool[t].distance / djk); + } else if (_metric == + diskann::Metric::INNER_PRODUCT) { // stylized rules for + // inner product since + // we want max instead + // of min distance + float x = -pool[t].distance; + float y = -djk; + if (y > cur_alpha * x) { + occlude_factor[t] = (std::max)(occlude_factor[t], eps); + } + } } start++; } @@ -560,7 +577,7 @@ namespace diskann { } template - void Index::prune_neighbors(const unsigned location, + void Index::prune_neighbors(const unsigned location, std::vector &pool, const Parameters & parameter, std::vector &pruned_list) { @@ -640,7 +657,7 @@ namespace diskann { * the current node n. */ template - void Index::inter_insert(unsigned n, + void Index::inter_insert(unsigned n, std::vector &pruned_list, const Parameters & parameter, bool update_in_graph) { @@ -971,7 +988,7 @@ namespace diskann { } template - void Index::build(Parameters ¶meters, + void Index::build(Parameters & parameters, const std::vector &tags) { if (_enable_tags) { if (tags.size() != _nd) { @@ -1009,7 +1026,7 @@ namespace diskann { } template - std::pair Index::search(const T *query, + std::pair Index::search(const T * query, const size_t K, const unsigned L, unsigned * indices) { @@ -1054,6 +1071,8 @@ namespace diskann { for (auto it : best_L_nodes) { indices[pos] = it.id; distances[pos] = it.distance; + if (_metric == diskann::INNER_PRODUCT) + distances[pos] = -distances[pos]; pos++; if (pos == K) break; @@ -1064,7 +1083,7 @@ namespace diskann { template std::pair Index::search_with_tags( const T *query, const size_t K, const unsigned L, TagT *tags, - unsigned frozen_pts, unsigned *indices_buffer) { + unsigned *indices_buffer) { const bool alloc = indices_buffer == NULL; auto indices = alloc ? new unsigned[K] : indices_buffer; auto ret = search(query, K, L, indices); @@ -1273,7 +1292,7 @@ namespace diskann { } template - int Index::eager_delete(const TagT tag, + int Index::eager_delete(const TagT tag, const Parameters ¶meters) { if (_lazy_done && (!_consolidated_order)) { diskann::cout << "Lazy delete reuests issued but data not consolidated, " @@ -1749,7 +1768,7 @@ namespace diskann { template int Index::disable_delete(const Parameters ¶meters, - const bool consolidate) { + const bool consolidate) { LockGuard guard(_change_lock); if (!_can_delete) { diskann::cerr << "Delete not currently enabled" << std::endl; diff --git a/src/linux_aligned_file_reader.cpp b/src/linux_aligned_file_reader.cpp index 69eae9b60..35c8009bd 100644 --- a/src/linux_aligned_file_reader.cpp +++ b/src/linux_aligned_file_reader.cpp @@ -89,7 +89,7 @@ namespace { } std::cout << std::endl;*/ } -} +} // namespace LinuxAlignedFileReader::LinuxAlignedFileReader() { this->file_desc = -1; @@ -141,8 +141,8 @@ void LinuxAlignedFileReader::register_thread() { std::cerr << "io_setup() failed; returned " << ret << ", errno=" << errno << ":" << ::strerror(errno) << std::endl; } else { - std::cerr << "allocating ctx: " << ctx << " to thread-id:" << my_id - << std::endl; + diskann::cout << "allocating ctx: " << ctx << " to thread-id:" << my_id + << std::endl; ctx_map[my_id] = ctx; } lk.unlock(); diff --git a/src/partition_and_pq.cpp b/src/partition_and_pq.cpp index 9da49b220..044bdd70f 100644 --- a/src/partition_and_pq.cpp +++ b/src/partition_and_pq.cpp @@ -35,8 +35,11 @@ #include #endif +// block size for reading/ processing large files and matrices in blocks #define BLOCK_SIZE 5000000 +//#define SAVE_INFLATED_PQ true + template void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate) { @@ -190,7 +193,7 @@ void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, int generate_pq_pivots(const float *passed_train_data, size_t num_train, unsigned dim, unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps, - std::string pq_pivots_path) { + std::string pq_pivots_path, bool make_zero_mean) { if (num_pq_chunks > dim) { diskann::cout << " Error: number of chunks more than dimension" << std::endl; @@ -226,17 +229,23 @@ int generate_pq_pivots(const float *passed_train_data, size_t num_train, std::unique_ptr centroid = std::make_unique(dim); for (uint64_t d = 0; d < dim; d++) { centroid[d] = 0; - for (uint64_t p = 0; p < num_train; p++) { - centroid[d] += train_data[p * dim + d]; - } - centroid[d] /= num_train; } + if (make_zero_mean) { // If we use L2 distance, there is an option to + // translate all vectors to make them centered and then + // compute PQ. This needs to be set to false when using + // PQ for MIPS as such translations dont preserve inner + // products. + for (uint64_t d = 0; d < dim; d++) { + for (uint64_t p = 0; p < num_train; p++) { + centroid[d] += train_data[p * dim + d]; + } + centroid[d] /= num_train; + } - // std::memset(centroid, 0 , dim*sizeof(float)); - - for (uint64_t d = 0; d < dim; d++) { - for (uint64_t p = 0; p < num_train; p++) { - train_data[p * dim + d] -= centroid[d]; + for (uint64_t d = 0; d < dim; d++) { + for (uint64_t p = 0; p < num_train; p++) { + train_data[p * dim + d] -= centroid[d]; + } } } @@ -251,7 +260,7 @@ int generate_pq_pivots(const float *passed_train_data, size_t num_train, std::vector> bin_to_dims(num_pq_chunks); tsl::robin_map dim_to_bin; - std::vector bin_loads(num_pq_chunks, 0); + std::vector bin_loads(num_pq_chunks, 0); // Process dimensions not inserted by previous loop for (uint32_t d = 0; d < dim; d++) { @@ -376,6 +385,8 @@ int generate_pq_data_from_pivots(const std::string data_file, std::unique_ptr rearrangement; std::unique_ptr chunk_offsets; + std::string inflated_pq_file = pq_compressed_vectors_path + "_inflated.bin"; + if (!file_exists(pq_pivots_path)) { diskann::cout << "ERROR: PQ k-means pivot file not found" << std::endl; throw diskann::ANNException("PQ k-means pivot file not found", -1); @@ -436,12 +447,23 @@ int generate_pq_data_from_pivots(const std::string data_file, std::ofstream compressed_file_writer(pq_compressed_vectors_path, std::ios::binary); - _u32 num_pq_chunks_u32 = num_pq_chunks; + _u32 num_pq_chunks_u32 = num_pq_chunks; compressed_file_writer.write((char *) &num_points, sizeof(uint32_t)); compressed_file_writer.write((char *) &num_pq_chunks_u32, sizeof(uint32_t)); size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + +#ifdef SAVE_INFLATED_PQ + std::ofstream inflated_file_writer(inflated_pq_file, std::ios::binary); + inflated_file_writer.write((char *) &num_points, sizeof(uint32_t)); + inflated_file_writer.write((char *) &basedim32, sizeof(uint32_t)); + + std::unique_ptr block_inflated_base = + std::make_unique(block_size * dim); + std::memset(block_inflated_base.get(), 0, block_size * dim * sizeof(float)); +#endif + std::unique_ptr<_u32[]> block_compressed_base = std::make_unique<_u32[]>(block_size * (_u64) num_pq_chunks); std::memset(block_compressed_base.get(), 0, @@ -514,6 +536,12 @@ int generate_pq_data_from_pivots(const std::string data_file, #pragma omp parallel for schedule(static, 8192) for (int64_t j = 0; j < (_s64) cur_blk_size; j++) { block_compressed_base[j * num_pq_chunks + i] = closest_center[j]; +#ifdef SAVE_INFLATED_PQ + for (uint64_t k = 0; k < cur_chunk_size; k++) + block_inflated_base[j * dim + chunk_offsets[i] + k] = + cur_pivot_data[closest_center[j] * cur_chunk_size + k] + + centroid[chunk_offsets[i] + k]; +#endif } } @@ -530,6 +558,10 @@ int generate_pq_data_from_pivots(const std::string data_file, (char *) (pVec.get()), cur_blk_size * num_pq_chunks * sizeof(uint8_t)); } +#ifdef SAVE_INFLATED_PQ + inflated_file_writer.write((char *) (block_inflated_base.get()), + cur_blk_size * dim * sizeof(float)); +#endif diskann::cout << ".done." << std::endl; } // Gopal. Splittng diskann_dll into separate DLLs for search and build. @@ -538,38 +570,25 @@ int generate_pq_data_from_pivots(const std::string data_file, MallocExtension::instance()->ReleaseFreeMemory(); #endif compressed_file_writer.close(); +#ifdef SAVE_INFLATED_PQ + inflated_file_writer.close(); +#endif return 0; } -template -int estimate_cluster_sizes(const std::string data_file, float *pivots, - const size_t num_centers, const size_t dim, - const size_t k_base, +int estimate_cluster_sizes(float *test_data_float, size_t num_test, + float *pivots, const size_t num_centers, + const size_t test_dim, const size_t k_base, std::vector &cluster_sizes) { cluster_sizes.clear(); - size_t num_test, test_dim; - float *test_data_float; - double sampling_rate = 0.01; - - gen_random_slice(data_file, sampling_rate, test_data_float, num_test, - test_dim); - - if (test_dim != dim) { - diskann::cout << "Error. dimensions dont match for pivot set and base set" - << std::endl; - return -1; - } - size_t *shard_counts = new size_t[num_centers]; for (size_t i = 0; i < num_centers; i++) { shard_counts[i] = 0; } - size_t num_points = 0, num_dim = 0; - diskann::get_bin_metadata(data_file, num_points, num_dim); - size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + size_t block_size = num_test <= BLOCK_SIZE ? num_test : BLOCK_SIZE; _u32 * block_closest_centers = new _u32[block_size * k_base]; float *block_data_float; @@ -582,8 +601,8 @@ int estimate_cluster_sizes(const std::string data_file, float *pivots, block_data_float = test_data_float + start_id * test_dim; - math_utils::compute_closest_centers(block_data_float, cur_blk_size, dim, - pivots, num_centers, k_base, + math_utils::compute_closest_centers(block_data_float, cur_blk_size, + test_dim, pivots, num_centers, k_base, block_closest_centers); for (size_t p = 0; p < cur_blk_size; p++) { @@ -597,9 +616,8 @@ int estimate_cluster_sizes(const std::string data_file, float *pivots, diskann::cout << "Estimated cluster sizes: "; for (size_t i = 0; i < num_centers; i++) { _u32 cur_shard_count = (_u32) shard_counts[i]; - cluster_sizes.push_back( - size_t(((double) cur_shard_count) * (1.0 / sampling_rate))); - diskann::cout << cur_shard_count * (1.0 / sampling_rate) << " "; + cluster_sizes.push_back((size_t) cur_shard_count); + diskann::cout << cur_shard_count << " "; } diskann::cout << std::endl; delete[] shard_counts; @@ -706,6 +724,164 @@ int shard_data_into_clusters(const std::string data_file, float *pivots, return 0; } +// useful for partitioning large dataset. we first generate only the IDS for +// each shard, and retrieve the actual vectors on demand. +template +int shard_data_into_clusters_only_ids(const std::string data_file, + float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, + std::string prefix_path) { + _u64 read_blk_size = 64 * 1024 * 1024; + // _u64 write_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file, read_blk_size); + _u32 npts32; + _u32 basedim32; + base_reader.read((char *) &npts32, sizeof(uint32_t)); + base_reader.read((char *) &basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + if (basedim32 != dim) { + diskann::cout << "Error. dimensions dont match for train set and base set" + << std::endl; + return -1; + } + + std::unique_ptr shard_counts = + std::make_unique(num_centers); + + std::vector shard_idmap_writer(num_centers); + _u32 dummy_size = 0; + _u32 const_one = 1; + + for (size_t i = 0; i < num_centers; i++) { + std::string idmap_filename = + prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin"; + shard_idmap_writer[i] = + std::ofstream(idmap_filename.c_str(), std::ios::binary); + shard_idmap_writer[i].write((char *) &dummy_size, sizeof(uint32_t)); + shard_idmap_writer[i].write((char *) &const_one, sizeof(uint32_t)); + shard_counts[i] = 0; + } + + size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + std::unique_ptr<_u32[]> block_closest_centers = + std::make_unique<_u32[]>(block_size * k_base); + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + std::unique_ptr block_data_float = + std::make_unique(block_size * dim); + + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + + for (size_t block = 0; block < num_blocks; block++) { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; + + base_reader.read((char *) block_data_T.get(), + sizeof(T) * (cur_blk_size * dim)); + diskann::convert_types(block_data_T.get(), block_data_float.get(), + cur_blk_size, dim); + + math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, + dim, pivots, num_centers, k_base, + block_closest_centers.get()); + + for (size_t p = 0; p < cur_blk_size; p++) { + for (size_t p1 = 0; p1 < k_base; p1++) { + size_t shard_id = block_closest_centers[p * k_base + p1]; + uint32_t original_point_map_id = (uint32_t)(start_id + p); + shard_idmap_writer[shard_id].write((char *) &original_point_map_id, + sizeof(uint32_t)); + shard_counts[shard_id]++; + } + } + } + + size_t total_count = 0; + diskann::cout << "Actual shard sizes: " << std::flush; + for (size_t i = 0; i < num_centers; i++) { + _u32 cur_shard_count = (_u32) shard_counts[i]; + total_count += cur_shard_count; + diskann::cout << cur_shard_count << " "; + shard_idmap_writer[i].seekp(0); + shard_idmap_writer[i].write((char *) &cur_shard_count, sizeof(uint32_t)); + shard_idmap_writer[i].close(); + } + + diskann::cout << "\n Partitioned " << num_points + << " with replication factor " << k_base << " to get " + << total_count << " points across " << num_centers << " shards " + << std::endl; + return 0; +} + +template +int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename) { + _u64 read_blk_size = 64 * 1024 * 1024; + // _u64 write_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file, read_blk_size); + _u32 npts32; + _u32 basedim32; + base_reader.read((char *) &npts32, sizeof(uint32_t)); + base_reader.read((char *) &basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + size_t dim = basedim32; + + _u32 dummy_size = 0; + + std::ofstream shard_data_writer(data_filename.c_str(), std::ios::binary); + shard_data_writer.write((char *) &dummy_size, sizeof(uint32_t)); + shard_data_writer.write((char *) &basedim32, sizeof(uint32_t)); + + _u32 *shard_ids; + _u64 shard_size, tmp; + diskann::load_bin<_u32>(idmap_filename, shard_ids, shard_size, tmp); + + _u32 cur_pos = 0; + _u32 num_written = 0; + std::cout << "Shard has " << shard_size << " points" << std::endl; + + size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + + for (size_t block = 0; block < num_blocks; block++) { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; + + base_reader.read((char *) block_data_T.get(), + sizeof(T) * (cur_blk_size * dim)); + + for (size_t p = 0; p < cur_blk_size; p++) { + uint32_t original_point_map_id = (uint32_t)(start_id + p); + if (cur_pos == shard_size) + break; + if (original_point_map_id == shard_ids[cur_pos]) { + cur_pos++; + shard_data_writer.write((char *) (block_data_T.get() + p * dim), + sizeof(T) * dim); + num_written++; + } + } + if (cur_pos == shard_size) + break; + } + + diskann::cout << "Written file with " << num_written << " points" + << std::endl; + + shard_data_writer.seekp(0); + shard_data_writer.write((char *) &num_written, sizeof(uint32_t)); + shard_data_writer.close(); + delete[] shard_ids; + return 0; +} + // partitions a large base file into many shards using k-means hueristic // on a random sample generated using sampling_rate probability. After this, it // assignes each base point to the closest k_base nearest centers and creates @@ -751,10 +927,6 @@ int partition(const std::string data_file, const float sampling_rate, // now pivots are ready. need to stream base points and assign them to // closest clusters. - std::vector cluster_sizes; - estimate_cluster_sizes(data_file, pivot_data, num_parts, train_dim, k_base, - cluster_sizes); - shard_data_into_clusters(data_file, pivot_data, num_parts, train_dim, k_base, prefix_path); delete[] pivot_data; @@ -770,7 +942,7 @@ int partition_with_ram_budget(const std::string data_file, size_t train_dim; size_t num_train; float *train_data_float; - size_t max_k_means_reps = 20; + size_t max_k_means_reps = 10; int num_parts = 3; bool fit_in_ram = false; @@ -778,6 +950,12 @@ int partition_with_ram_budget(const std::string data_file, gen_random_slice(data_file, sampling_rate, train_data_float, num_train, train_dim); + size_t test_dim; + size_t num_test; + float *test_data_float; + gen_random_slice(data_file, sampling_rate, test_data_float, num_test, + test_dim); + float *pivot_data = nullptr; std::string cur_file = std::string(prefix_path); @@ -809,10 +987,13 @@ int partition_with_ram_budget(const std::string data_file, // closest clusters. std::vector cluster_sizes; - estimate_cluster_sizes(data_file, pivot_data, num_parts, train_dim, - k_base, cluster_sizes); + estimate_cluster_sizes(test_data_float, num_test, pivot_data, num_parts, + train_dim, k_base, cluster_sizes); for (auto &p : cluster_sizes) { + p = (_u64)(p / + sampling_rate); // to account for the fact that p is the size + // of the shard over the testing sample. double cur_shard_ram_estimate = ESTIMATE_RAM_USAGE(p, train_dim, sizeof(T), graph_degree); @@ -824,7 +1005,7 @@ int partition_with_ram_budget(const std::string data_file, << "GB, budget given is " << ram_budget << std::endl; if (max_ram_usage > 1024 * 1024 * 1024 * ram_budget) { fit_in_ram = false; - num_parts++; + num_parts += 2; } } @@ -832,27 +1013,28 @@ int partition_with_ram_budget(const std::string data_file, diskann::save_bin(output_file.c_str(), pivot_data, (size_t) num_parts, train_dim); - shard_data_into_clusters(data_file, pivot_data, num_parts, train_dim, - k_base, prefix_path); + shard_data_into_clusters_only_ids(data_file, pivot_data, num_parts, + train_dim, k_base, prefix_path); delete[] pivot_data; delete[] train_data_float; + delete[] test_data_float; return num_parts; } // Instantations of supported templates template void DISKANN_DLLEXPORT -gen_random_slice(const std::string base_file, + gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate); template void DISKANN_DLLEXPORT gen_random_slice( const std::string base_file, const std::string output_prefix, double sampling_rate); template void DISKANN_DLLEXPORT -gen_random_slice(const std::string base_file, + gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate); template void DISKANN_DLLEXPORT -gen_random_slice(const float *inputdata, size_t npts, size_t ndims, + gen_random_slice(const float *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data, size_t &slice_size); template void DISKANN_DLLEXPORT gen_random_slice( const uint8_t *inputdata, size_t npts, size_t ndims, double p_val, @@ -891,6 +1073,16 @@ template DISKANN_DLLEXPORT int partition_with_ram_budget( const std::string data_file, const double sampling_rate, double ram_budget, size_t graph_degree, const std::string prefix_path, size_t k_base); +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids( + const std::string data_file, std::string idmap_filename, + std::string data_filename); +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids( + const std::string data_file, std::string idmap_filename, + std::string data_filename); +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids( + const std::string data_file, std::string idmap_filename, + std::string data_filename); + template DISKANN_DLLEXPORT int generate_pq_data_from_pivots( const std::string data_file, unsigned num_centers, unsigned num_pq_chunks, std::string pq_pivots_path, std::string pq_compressed_vectors_path); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 21da3521d..9c1bea8ab 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -49,7 +49,7 @@ // returns region of `node_buf` containing [NNBRS][NBR_ID(_u32)] #define OFFSET_TO_NODE_NHOOD(node_buf) \ - (unsigned *) ((char *) node_buf + data_dim * sizeof(T)) + (unsigned *) ((char *) node_buf + disk_bytes_per_point) // returns region of `node_buf` containing [COORD(T)] #define OFFSET_TO_NODE_COORDS(node_buf) (T *) (node_buf) @@ -86,8 +86,8 @@ namespace { namespace diskann { template<> PQFlashIndex<_u8>::PQFlashIndex( - std::shared_ptr &fileReader) - : reader(fileReader) { + std::shared_ptr &fileReader, diskann::Metric metric) + : reader(fileReader), metric(metric) { diskann::cout << "dist_cmp function for _u8 uses slow implementation." " Please contact gopalsr@microsoft.com if you need an AVX/AVX2" @@ -106,12 +106,18 @@ namespace diskann { << std::endl; this->dist_cmp_float = new SlowDistanceL2Float(); } + if (metric != diskann::Metric::L2) { + std::cout << "Only L2 supported for byte vectors for now. Other distance " + "functions are future work. Falling back to L2 distance." + << std::endl; + this->metric = diskann::Metric::L2; + } } template<> PQFlashIndex<_s8>::PQFlashIndex( - std::shared_ptr &fileReader) - : reader(fileReader) { + std::shared_ptr &fileReader, diskann::Metric metric) + : reader(fileReader), metric(metric) { if (Avx2SupportedCPU) { diskann::cout << "Using AVX2 function for dist_cmp and dist_cmp_float" << std::endl; @@ -130,29 +136,47 @@ namespace diskann { this->dist_cmp = new SlowDistanceL2Int(); this->dist_cmp_float = new SlowDistanceL2Float(); } + if (metric != diskann::Metric::L2) { + std::cout << "Only L2 supported for byte vectors for now. Other distance " + "functions are future work. Falling back to L2 distance." + << std::endl; + this->metric = diskann::Metric::L2; + } } template<> PQFlashIndex::PQFlashIndex( - std::shared_ptr &fileReader) - : reader(fileReader) { - if (Avx2SupportedCPU) { - diskann::cout << "Using AVX2 functions for dist_cmp and dist_cmp_float" - << std::endl; - this->dist_cmp = new DistanceL2(); - this->dist_cmp_float = new DistanceL2(); - } else if (AvxSupportedCPU) { - diskann::cout << "No AVX2 support. Switching to AVX functions for " - "dist_cmp and dist_cmp_float." - << std::endl; - this->dist_cmp = new AVXDistanceL2Float(); - this->dist_cmp_float = new AVXDistanceL2Float(); + std::shared_ptr &fileReader, diskann::Metric metric) + : reader(fileReader), metric(metric) { + if (metric == diskann::Metric::L2) { + if (Avx2SupportedCPU) { + diskann::cout << "Using AVX2 functions for dist_cmp and dist_cmp_float" + << std::endl; + this->dist_cmp = new DistanceL2(); + this->dist_cmp_float = new DistanceL2(); + } else if (AvxSupportedCPU) { + diskann::cout << "No AVX2 support. Switching to AVX functions for " + "dist_cmp and dist_cmp_float." + << std::endl; + this->dist_cmp = new AVXDistanceL2Float(); + this->dist_cmp_float = new AVXDistanceL2Float(); + } else { + diskann::cout + << "No AVX/AVX2 support. Switching to slow implementations " + "for dist_cmp and dist_cmp_float" + << std::endl; + this->dist_cmp = new AVXDistanceL2Float(); + this->dist_cmp_float = new AVXDistanceL2Float(); + } + } else if (metric == diskann::Metric::INNER_PRODUCT) { + std::cout << "Using inner product distance function" << std::endl; + this->dist_cmp = new DistanceInnerProduct(); + this->dist_cmp_float = new DistanceInnerProduct(); } else { - diskann::cout << "No AVX/AVX2 support. Switching to slow implementations " - "for dist_cmp and dist_cmp_float" - << std::endl; + std::cout << "Unsupported metric type. Reverting to float." << std::endl; this->dist_cmp = new AVXDistanceL2Float(); this->dist_cmp_float = new AVXDistanceL2Float(); + this->metric = diskann::Metric::L2; } } @@ -280,7 +304,7 @@ namespace diskann { for (_u64 block = 0; block < num_blocks; block++) { _u64 start_idx = block * BLOCK_SIZE; _u64 end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); - std::vector read_reqs; + std::vector read_reqs; std::vector> nhoods; for (_u64 node_idx = start_idx; node_idx < end_idx; node_idx++) { AlignedRead read; @@ -300,7 +324,7 @@ namespace diskann { char *node_buf = OFFSET_TO_NODE(nhood.second, nhood.first); T * node_coords = OFFSET_TO_NODE_COORDS(node_buf); T * cached_coords = coord_cache_buf + node_idx * aligned_dim; - memcpy(cached_coords, node_coords, data_dim * sizeof(T)); + memcpy(cached_coords, node_coords, disk_bytes_per_point); coord_cache.insert(std::make_pair(nhood.first, cached_coords)); // insert node nhood into nhood_cache @@ -447,7 +471,7 @@ namespace diskann { size_t start = block * BLOCK_SIZE; size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); - std::vector read_reqs; + std::vector read_reqs; std::vector> nhoods; for (size_t cur_pt = start; cur_pt < end; cur_pt++) { char *buf = nullptr; @@ -542,12 +566,17 @@ namespace diskann { // add medoid coords to `coord_cache` T *medoid_coords = new T[data_dim]; T *medoid_disk_coords = OFFSET_TO_NODE_COORDS(medoid_node_buf); - memcpy(medoid_coords, medoid_disk_coords, data_dim * sizeof(T)); - - for (uint32_t i = 0; i < data_dim; i++) - centroid_data[cur_m * aligned_dim + i] = medoid_coords[i]; + memcpy(medoid_coords, medoid_disk_coords, disk_bytes_per_point); + if (!use_disk_index_pq) { + for (uint32_t i = 0; i < data_dim; i++) + centroid_data[cur_m * aligned_dim + i] = medoid_coords[i]; + } else { + disk_pq_table.inflate_vector((_u8 *) medoid_coords, + (centroid_data + cur_m * aligned_dim)); + } aligned_free(medoid_buf); + delete[] medoid_coords; } // return ctx @@ -588,6 +617,12 @@ namespace diskann { } this->data_dim = pq_file_dim; + this->disk_data_dim = + this->data_dim; // will reset later if we use PQ on disk + this->disk_bytes_per_point = + this->data_dim * + sizeof(T); // will change later if we use PQ on disk or if we are using + // inner product without PQ this->aligned_dim = ROUND_UP(pq_file_dim, 8); size_t npts_u64, nchunks_u64; @@ -614,6 +649,26 @@ namespace diskann { << " #aligned_dim: " << aligned_dim << " #chunks: " << n_chunks << std::endl; + std::string disk_pq_pivots_path = this->disk_index_file + "_pq_pivots.bin"; + if (file_exists(disk_pq_pivots_path)) { + use_disk_index_pq = true; +#ifdef EXEC_ENV_OLS + disk_pq_table.load_pq_centroid_bin( + files, disk_pq_pivots_path.c_str(), + 0); // giving 0 chunks to make the pq_table infer from the + // chunk_offsets file the correct value +#else + disk_pq_table.load_pq_centroid_bin( + disk_pq_pivots_path.c_str(), + 0); // giving 0 chunks to make the pq_table infer from the + // chunk_offsets file the correct value +#endif + disk_pq_n_chunks = disk_pq_table.get_num_chunks(); + disk_bytes_per_point = disk_pq_n_chunks * sizeof(_u8); + std::cout << "Disk index uses PQ data compressed down to " + << disk_pq_n_chunks << " bytes per point." << std::endl; + } + // read index metadata #ifdef EXEC_ENV_OLS // This is a bit tricky. We have to read the header from the @@ -658,7 +713,7 @@ namespace diskann { READ_U64(index_metadata, medoid_id_on_file); READ_U64(index_metadata, max_node_len); READ_U64(index_metadata, nnodes_per_sector); - max_degree = ((max_node_len - data_dim * sizeof(T)) / sizeof(unsigned)) - 1; + max_degree = ((max_node_len - disk_bytes_per_point) / sizeof(unsigned)) - 1; diskann::cout << "Disk-Index File Meta-data: "; diskann::cout << "# nodes per sector: " << nnodes_per_sector; @@ -738,6 +793,17 @@ namespace diskann { use_medoids_data_as_centroids(); } + std::string norm_file = std::string(disk_index_file) + "_max_base_norm.bin"; + + if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { + _u64 dumr, dumc; + float *norm_val; + diskann::load_bin(norm_file, norm_val, dumr, dumc); + this->max_base_norm = norm_val[0]; + std::cout << "Setting re-scaling factor of base vectors to " + << this->max_base_norm << std::endl; + delete[] norm_val; + } diskann::cout << "done.." << std::endl; return 0; } @@ -763,22 +829,40 @@ namespace diskann { template void PQFlashIndex::cached_beam_search(const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices, - float * distances, - const _u64 beam_width, - QueryStats * stats, - Distance *output_dist_func) { + float * distances, + const _u64 beam_width, + QueryStats *stats) { ThreadData data = this->thread_data.pop(); while (data.scratch.sector_scratch == nullptr) { this->thread_data.wait_for_push_notify(); data = this->thread_data.pop(); } +//std::cout<data_dim; i++) { data.scratch.aligned_query_float[i] = query1[i]; + data.scratch.aligned_query_T[i] = query1[i]; + query_norm += query1[i] * query1[i]; + } + + // if inner product, we laso normalize the query and set the last coordinate + // to 0 (this is the extra coordindate used to convert MIPS to L2 search) + if (metric == diskann::Metric::INNER_PRODUCT) { + query_norm = std::sqrt(query_norm); + data.scratch.aligned_query_T[this->data_dim - 1] = 0; + data.scratch.aligned_query_float[this->data_dim - 1] = 0; + for (uint32_t i = 0; i < this->data_dim - 1; i++) { + data.scratch.aligned_query_T[i] /= query_norm; + data.scratch.aligned_query_float[i] /= query_norm; + } } - memcpy(data.scratch.aligned_query_T, query1, this->data_dim * sizeof(T)); - const T * query = data.scratch.aligned_query_T; - const float *query_float = data.scratch.aligned_query_float; IOContext &ctx = data.ctx; auto query_scratch = &(data.scratch); @@ -801,15 +885,16 @@ namespace diskann { // query <-> PQ chunk centers distances float *pq_dists = query_scratch->aligned_pqtable_dist_scratch; - pq_table.populate_chunk_distances(query, pq_dists); + pq_table.populate_chunk_distances(query_float, pq_dists); // query <-> neighbor list float *dist_scratch = query_scratch->aligned_dist_scratch; _u8 * pq_coord_scratch = query_scratch->aligned_pq_coord_scratch; // lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, pq_coord_scratch, pq_dists]( - const unsigned *ids, const _u64 n_ids, float *dists_out) { + auto compute_dists = [this, pq_coord_scratch, pq_dists](const unsigned *ids, + const _u64 n_ids, + float *dists_out) { ::aggregate_coords(ids, n_ids, this->data, this->n_chunks, pq_coord_scratch); ::pq_dist_lookup(pq_coord_scratch, n_ids, this->n_chunks, pq_dists, @@ -837,6 +922,7 @@ namespace diskann { } compute_dists(&best_medoid, 1, dist_scratch); + retset[0].id = best_medoid; retset[0].distance = dist_scratch[0]; retset[0].flag = true; @@ -852,9 +938,9 @@ namespace diskann { unsigned k = 0; // cleared every iteration - std::vector frontier; + std::vector frontier; std::vector> frontier_nhoods; - std::vector frontier_read_reqs; + std::vector frontier_read_reqs; std::vector>> cached_nhoods; @@ -873,19 +959,6 @@ namespace diskann { _u32 marker = k; _u32 num_seen = 0; - /* - bool marker_set = false; - diskann::cout << "hop " << hops << ": "; - for (_u32 i = 0; i < cur_list_size; i++) { - diskann::cout << retset[i].id << "( " << retset[i].distance; - if (retset[i].flag && !marker_set) { - diskann::cout << ",*) "; - marker_set = true; - } else - diskann::cout << ") "; - } - diskann::cout << std::endl; - */ while (marker < cur_list_size && frontier.size() < beam_width && num_seen < beam_width + 2) { if (retset[marker].flag) { @@ -915,7 +988,7 @@ namespace diskann { if (stats != nullptr) stats->n_hops++; for (_u64 i = 0; i < frontier.size(); i++) { - auto id = frontier[i]; + auto id = frontier[i]; std::pair<_u32, char *> fnhood; fnhood.first = id; fnhood.second = sector_scratch + sector_scratch_idx * SECTOR_LEN; @@ -945,8 +1018,18 @@ namespace diskann { for (auto &cached_nhood : cached_nhoods) { auto global_cache_iter = coord_cache.find(cached_nhood.first); T * node_fp_coords_copy = global_cache_iter->second; - float cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, - (unsigned) aligned_dim); + float cur_expanded_dist; + if (!use_disk_index_pq) { + cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, + (unsigned) aligned_dim); + } else { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = disk_pq_table.inner_product( + query_float, (_u8 *) node_fp_coords_copy); + else + cur_expanded_dist = disk_pq_table.l2_distance( + query_float, (_u8 *) node_fp_coords_copy); + } full_retset.push_back( Neighbor((unsigned) cached_nhood.first, cur_expanded_dist, true)); @@ -1010,14 +1093,26 @@ namespace diskann { unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf); _u64 nnbrs = (_u64)(*node_buf); T * node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf); - assert(data_buf_idx < MAX_N_CMPS); +// assert(data_buf_idx < MAX_N_CMPS); + if (data_buf_idx == MAX_N_CMPS) + data_buf_idx = 0; T *node_fp_coords_copy = data_buf + (data_buf_idx * aligned_dim); data_buf_idx++; - memcpy(node_fp_coords_copy, node_fp_coords, data_dim * sizeof(T)); - - float cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, - (unsigned) aligned_dim); + memcpy(node_fp_coords_copy, node_fp_coords, disk_bytes_per_point); + + float cur_expanded_dist; + if (!use_disk_index_pq) { + cur_expanded_dist = dist_cmp->compare(query, node_fp_coords_copy, + (unsigned) aligned_dim); + } else { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = disk_pq_table.inner_product( + query_float, (_u8 *) node_fp_coords_copy); + else + cur_expanded_dist = disk_pq_table.l2_distance( + query_float, (_u8 *) node_fp_coords_copy); + } full_retset.push_back( Neighbor(frontier_nhood.first, cur_expanded_dist, true)); @@ -1077,17 +1172,34 @@ namespace diskann { hops++; } + // re-sort by distance std::sort(full_retset.begin(), full_retset.end(), [](const Neighbor &left, const Neighbor &right) { return left.distance < right.distance; }); + /* + std::cout<<"return set: \n"; + for (auto &x : full_retset) + std::cout< + _u32 PQFlashIndex::range_search(const T *query1, const double range, + const _u64 l_search, _u64* indices, float* distances, + const _u64 beam_width, + QueryStats *stats) { +_u32 res_count = 0; +this->cached_beam_search(query1, l_search, l_search, indices, distances, beam_width, stats); +for (_u32 i = 0; i < l_search; i++) { + //std::cout< (float) range) { + res_count = i; + break; + } else if (i == l_search -1) + res_count = l_search; +} +//std::cout<<"\n\n"< char *PQFlashIndex::getHeaderBytes() { diff --git a/src/windows_aligned_file_reader.cpp b/src/windows_aligned_file_reader.cpp index e8d455360..3dcb15bd3 100644 --- a/src/windows_aligned_file_reader.cpp +++ b/src/windows_aligned_file_reader.cpp @@ -67,7 +67,7 @@ IOContext& WindowsAlignedFileReader::get_ctx() { } void WindowsAlignedFileReader::read(std::vector& read_reqs, - IOContext& ctx, bool async) { + IOContext& ctx, bool async) { using namespace std::chrono_literals; // execute each request sequentially _u64 n_reqs = read_reqs.size(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fa6e8ea04..177d34f70 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,3 +41,13 @@ else() endif() +add_executable(range_search_disk_index range_search_disk_index.cpp + ${PROJECT_SOURCE_DIR}/src/aux_utils.cpp ) +if(MSVC) + target_link_options(range_search_disk_index PRIVATE /MACHINE:x64 /DEBUG:FULL) + target_link_libraries(range_search_disk_index debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib) + target_link_libraries(range_search_disk_index optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib) +else() + target_link_libraries(range_search_disk_index ${PROJECT_NAME} aio -ltcmalloc) +endif() + diff --git a/tests/build_disk_index.cpp b/tests/build_disk_index.cpp index b29c3f0d1..a83d43dd9 100644 --- a/tests/build_disk_index.cpp +++ b/tests/build_disk_index.cpp @@ -11,29 +11,38 @@ template bool build_index(const char* dataFilePath, const char* indexFilePath, - const char* indexBuildParameters) { - return diskann::build_disk_index( - dataFilePath, indexFilePath, indexBuildParameters, diskann::Metric::L2); + const char* indexBuildParameters, diskann::Metric metric) { + return diskann::build_disk_index(dataFilePath, indexFilePath, + indexBuildParameters, metric); } int main(int argc, char** argv) { - if (argc != 9) { + if (argc != 11) { std::cout << "Usage: " << argv[0] - << " [data_type] [data_file.bin] " + << " [data_type] [dist_fn: l2/mips] " + "[data_file.bin] " "[index_prefix_path] " - "[R] [L] [B] [M] [T]. See README for more information on " + "[R] [L] [B] [M] [T] [PQ_disk_bytes (for very large " + "dimensionality, use 0 for full vectors)]. See README for " + "more information on " "parameters." << std::endl; } else { - std::string params = std::string(argv[4]) + " " + std::string(argv[5]) + - " " + std::string(argv[6]) + " " + - std::string(argv[7]) + " " + std::string(argv[8]); + diskann::Metric metric = diskann::Metric::L2; + + if (std::string(argv[2]) == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + + std::string params = std::string(argv[5]) + " " + std::string(argv[6]) + + " " + std::string(argv[7]) + " " + + std::string(argv[8]) + " " + std::string(argv[9]) + + " " + std::string(argv[10]); if (std::string(argv[1]) == std::string("float")) - build_index(argv[2], argv[3], params.c_str()); + build_index(argv[3], argv[4], params.c_str(), metric); else if (std::string(argv[1]) == std::string("int8")) - build_index(argv[2], argv[3], params.c_str()); + build_index(argv[3], argv[4], params.c_str(), metric); else if (std::string(argv[1]) == std::string("uint8")) - build_index(argv[2], argv[3], params.c_str()); + build_index(argv[3], argv[4], params.c_str(), metric); else std::cout << "Error. wrong file type" << std::endl; } diff --git a/tests/build_memory_index.cpp b/tests/build_memory_index.cpp index a5a13595e..bfa50c849 100644 --- a/tests/build_memory_index.cpp +++ b/tests/build_memory_index.cpp @@ -16,7 +16,8 @@ #include "memory_mapper.h" template -int build_in_memory_index(const std::string& data_path, const unsigned R, +int build_in_memory_index(const std::string& data_path, + const diskann::Metric& metric, const unsigned R, const unsigned L, const float alpha, const std::string& save_path, const unsigned num_threads) { @@ -29,7 +30,7 @@ int build_in_memory_index(const std::string& data_path, const unsigned R, paras.Set("saturate_graph", 0); paras.Set("num_threads", num_threads); - diskann::Index index(diskann::L2, data_path.c_str()); + diskann::Index index(metric, data_path.c_str()); auto s = std::chrono::high_resolution_clock::now(); index.build(paras); std::chrono::duration diff = @@ -42,9 +43,9 @@ int build_in_memory_index(const std::string& data_path, const unsigned R, } int main(int argc, char** argv) { - if (argc != 8) { + if (argc != 9) { std::cout << "Usage: " << argv[0] - << " [data_type] [data_file.bin] " + << " [data_type] [l2/mips] [data_file.bin] " "[output_index_file] " << "[R] [L] [alpha]" << " [num_threads_to_use]. See README for more information on " @@ -53,21 +54,36 @@ int main(int argc, char** argv) { exit(-1); } - const std::string data_path(argv[2]); - const std::string save_path(argv[3]); - const unsigned R = (unsigned) atoi(argv[4]); - const unsigned L = (unsigned) atoi(argv[5]); - const float alpha = (float) atof(argv[6]); - const unsigned num_threads = (unsigned) atoi(argv[7]); + _u32 ctr = 2; + + diskann::Metric metric; + if (std::string(argv[ctr]) == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (std::string(argv[ctr]) == std::string("l2")) + metric = diskann::Metric::L2; + else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product support." + << std::endl; + return -1; + } + ctr++; + + const std::string data_path(argv[ctr++]); + const std::string save_path(argv[ctr++]); + const unsigned R = (unsigned) atoi(argv[ctr++]); + const unsigned L = (unsigned) atoi(argv[ctr++]); + const float alpha = (float) atof(argv[ctr++]); + const unsigned num_threads = (unsigned) atoi(argv[ctr++]); if (std::string(argv[1]) == std::string("int8")) - build_in_memory_index(data_path, R, L, alpha, save_path, + build_in_memory_index(data_path, metric, R, L, alpha, save_path, num_threads); else if (std::string(argv[1]) == std::string("uint8")) - build_in_memory_index(data_path, R, L, alpha, save_path, + build_in_memory_index(data_path, metric, R, L, alpha, save_path, num_threads); else if (std::string(argv[1]) == std::string("float")) - build_in_memory_index(data_path, R, L, alpha, save_path, + build_in_memory_index(data_path, metric, R, L, alpha, save_path, num_threads); else std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; diff --git a/tests/range_search_disk_index.cpp b/tests/range_search_disk_index.cpp new file mode 100644 index 000000000..a181772ba --- /dev/null +++ b/tests/range_search_disk_index.cpp @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "aux_utils.h" +#include "index.h" +#include "math_utils.h" +#include "memory_mapper.h" +#include "partition_and_pq.h" +#include "timer.h" +#include "utils.h" + +#ifndef _WINDOWS +#include +#include +#include +#include "linux_aligned_file_reader.h" +#else +#ifdef USE_BING_INFRA +#include "bing_aligned_file_reader.h" +#else +#include "windows_aligned_file_reader.h" +#endif +#endif + +#define WARMUP false + +void print_stats(std::string category, std::vector percentiles, + std::vector results) { + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; +} + +template +int search_disk_index(int argc, char** argv) { + // load query bin + T* query = nullptr; +// unsigned* gt_ids = nullptr; +// float* gt_dists = nullptr; +std::vector> groundtruth_ids; + size_t query_num, query_dim, query_aligned_dim, gt_num; + std::vector<_u64> Lvec; + + _u32 ctr = 2; + diskann::Metric metric; + + if (std::string(argv[ctr]) == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (std::string(argv[ctr]) == std::string("l2")) + metric = diskann::Metric::L2; + else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product support." + << std::endl; + return -1; + } + + if ((std::string(argv[1]) != std::string("float")) && + (metric == diskann::Metric::INNER_PRODUCT)) { + std::cout << "Currently support only floating point data for Inner Product." + << std::endl; + return -1; + } + + ctr++; + + std::string index_prefix_path(argv[ctr++]); + std::string pq_prefix = index_prefix_path + "_pq"; + std::string disk_index_file = index_prefix_path + "_disk.index"; + std::string warmup_query_file = index_prefix_path + "_sample_data.bin"; + _u64 num_nodes_to_cache = std::atoi(argv[ctr++]); + _u32 num_threads = std::atoi(argv[ctr++]); + _u32 beamwidth = std::atoi(argv[ctr++]); + std::string query_bin(argv[ctr++]); + std::string truthset_bin(argv[ctr++]); + double search_range = std::atof(argv[ctr++]); + std::string result_output_prefix(argv[ctr++]); + + bool calc_recall_flag = false; + + for (; ctr < (_u32) argc; ctr++) { + _u64 curL = std::atoi(argv[ctr]); + Lvec.push_back(curL); + } + + if (Lvec.size() == 0) { + diskann::cout + << "No valid Lsearch found." + << std::endl; + return -1; + } + + diskann::cout << "Search parameters: #threads: " << num_threads << ", "; + if (beamwidth <= 0) + diskann::cout << "beamwidth to be optimized for each L value" << std::endl; + else + diskann::cout << " beamwidth: " << beamwidth << std::endl; + + diskann::load_aligned_bin(query_bin, query, query_num, query_dim, + query_aligned_dim); + + if (file_exists(truthset_bin)) { + diskann::load_range_truthset(truthset_bin, groundtruth_ids, gt_num); // use for range search type of truthset +// diskann::prune_truthset_for_range(truthset_bin, search_range, groundtruth_ids, gt_num); // use for traditional truthset + if (gt_num != query_num) { + diskann::cout + << "Error. Mismatch in number of queries and ground truth data" + << std::endl; + } + calc_recall_flag = true; + } + + std::shared_ptr reader = nullptr; +#ifdef _WINDOWS +#ifndef USE_BING_INFRA + reader.reset(new WindowsAlignedFileReader()); +#else + reader.reset(new diskann::BingAlignedFileReader()); +#endif +#else + reader.reset(new LinuxAlignedFileReader()); +#endif + + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, pq_prefix.c_str(), + disk_index_file.c_str()); + + if (res != 0) { + return res; + } + // cache bfs levels + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache + << " BFS nodes around medoid(s)" << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries( + // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T* warmup = nullptr; + + if (WARMUP) { + if (file_exists(warmup_query_file)) { + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, + warmup_dim, warmup_aligned_dim); + } else { + warmup_num = (std::min)((_u32) 150000, (_u32) 15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void**) &warmup), + warmup_num * warmup_aligned_dim * sizeof(T), + 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) { + for (uint32_t d = 0; d < warmup_dim; d++) { + warmup[i * warmup_aligned_dim + d] = (T) dis(gen); + } + } + } + diskann::cout << "Warming up index... " << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); + +#pragma omp parallel for schedule(dynamic, 1) + for (_s64 i = 0; i < (int64_t) warmup_num; i++) { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, + warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); + } + diskann::cout << "..done" << std::endl; + } + + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); + + std::string recall_string = "Recall@rng=" + std::to_string(search_range); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" + << std::setw(16) << "QPS" << std::setw(16) << "Mean Latency" + << std::setw(16) << "99.9 Latency" << std::setw(16) + << "Mean IOs" << std::setw(16) << "CPU (s)"; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall_string << std::endl; + } else + diskann::cout << std::endl; + diskann::cout + << "===============================================================" + "===========================================" + << std::endl; + + std::vector>> query_result_ids(Lvec.size()); + std::vector<_u64> indices; + std::vector distances; + + uint32_t optimized_beamwidth = 2; + + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { + _u64 L = Lvec[test_id]; + indices.clear(); + distances.clear(); + indices.resize(L*query_num); + distances.resize(L*query_num); + + if (beamwidth <= 0) { + // diskann::cout<<"Tuning beamwidth.." << std::endl; + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, + warmup_aligned_dim, L, optimized_beamwidth); + } else + optimized_beamwidth = beamwidth; + + query_result_ids[test_id].clear(); + query_result_ids[test_id].resize(query_num); + + diskann::QueryStats* stats = new diskann::QueryStats[query_num]; + + auto s = std::chrono::high_resolution_clock::now(); +#pragma omp parallel for schedule(dynamic, 1) + for (_s64 i = 0; i < (int64_t) query_num; i++) { + _u32 res_count = + _pFlashIndex->range_search( + query + (i * query_aligned_dim), search_range, L, + indices.data() + i*L, distances.data() + i *L, + optimized_beamwidth, stats + i); + // std::cout< diff = e - s; + float qps = (1.0 * query_num) / (1.0 * diff.count()); + + float mean_latency = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats& stats) { return stats.total_us; }); + + float latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, + [](const diskann::QueryStats& stats) { return stats.total_us; }); + + float mean_ios = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats& stats) { return stats.n_ios; }); + + float mean_cpuus = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats& stats) { return stats.cpu_us; }); + + float recall = 0; + if (calc_recall_flag) { + recall = diskann::calculate_range_search_recall(query_num, groundtruth_ids, query_result_ids[test_id]); + } + + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth + << std::setw(16) << qps << std::setw(16) << mean_latency + << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall << std::endl; + } else + diskann::cout << std::endl; + } + + diskann::cout << "Done searching. " << std::endl; + + diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); + return 0; +} + +int main(int argc, char** argv) { + if (argc < 12) { + diskann::cout + << "Usage: " << argv[0] + << " [index_type] [dist_fn] " + "[index_prefix_path] " + " [num_nodes_to_cache] [num_threads] [beamwidth (use 0 to " + "optimize internally)] " + " [query_file.bin] [truthset.bin (use \"null\" for none)] " + " [range_threshold] [result_output_prefix] " + " [L1] [L2] etc. See README for more information on parameters." + << std::endl; + exit(-1); + } + if (std::string(argv[1]) == std::string("float")) + search_disk_index(argc, argv); + else if (std::string(argv[1]) == std::string("int8")) + search_disk_index(argc, argv); + else if (std::string(argv[1]) == std::string("uint8")) + search_disk_index(argc, argv); + else + diskann::cout << "Unsupported index type. Use float or int8 or uint8" + << std::endl; +} diff --git a/tests/search_disk_index.cpp b/tests/search_disk_index.cpp index 4d5c79ca8..9dacbc394 100644 --- a/tests/search_disk_index.cpp +++ b/tests/search_disk_index.cpp @@ -31,7 +31,7 @@ #endif #endif -#define WARMUP true +#define WARMUP false void print_stats(std::string category, std::vector percentiles, std::vector results) { @@ -56,21 +56,44 @@ int search_disk_index(int argc, char** argv) { size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; std::vector<_u64> Lvec; - std::string index_prefix_path(argv[2]); + _u32 ctr = 2; + diskann::Metric metric; + + if (std::string(argv[ctr]) == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (std::string(argv[ctr]) == std::string("l2")) + metric = diskann::Metric::L2; + else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product support." + << std::endl; + return -1; + } + + if ((std::string(argv[1]) != std::string("float")) && + (metric == diskann::Metric::INNER_PRODUCT)) { + std::cout << "Currently support only floating point data for Inner Product." + << std::endl; + return -1; + } + + ctr++; + + std::string index_prefix_path(argv[ctr++]); std::string pq_prefix = index_prefix_path + "_pq"; std::string disk_index_file = index_prefix_path + "_disk.index"; std::string warmup_query_file = index_prefix_path + "_sample_data.bin"; - _u64 num_nodes_to_cache = std::atoi(argv[3]); - _u32 num_threads = std::atoi(argv[4]); - _u32 beamwidth = std::atoi(argv[5]); - std::string query_bin(argv[6]); - std::string truthset_bin(argv[7]); - _u64 recall_at = std::atoi(argv[8]); - std::string result_output_prefix(argv[9]); + _u64 num_nodes_to_cache = std::atoi(argv[ctr++]); + _u32 num_threads = std::atoi(argv[ctr++]); + _u32 beamwidth = std::atoi(argv[ctr++]); + std::string query_bin(argv[ctr++]); + std::string truthset_bin(argv[ctr++]); + _u64 recall_at = std::atoi(argv[ctr++]); + std::string result_output_prefix(argv[ctr++]); bool calc_recall_flag = false; - for (int ctr = 10; ctr < argc; ctr++) { + for (; ctr < (_u32) argc; ctr++) { _u64 curL = std::atoi(argv[ctr]); if (curL >= recall_at) Lvec.push_back(curL); @@ -114,7 +137,7 @@ int search_disk_index(int argc, char** argv) { #endif std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader)); + new diskann::PQFlashIndex(reader, metric)); int res = _pFlashIndex->load(num_threads, pq_prefix.c_str(), disk_index_file.c_str()); @@ -126,9 +149,9 @@ int search_disk_index(int argc, char** argv) { std::vector node_list; diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl; - // _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); - _pFlashIndex->generate_cache_list_from_sample_queries( - warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, node_list); + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries( + // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, node_list); _pFlashIndex->load_cache_list(node_list); node_list.clear(); node_list.shrink_to_fit(); @@ -196,8 +219,6 @@ int search_disk_index(int argc, char** argv) { uint32_t optimized_beamwidth = 2; - // query_num = 1; - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { _u64 L = Lvec[test_id]; @@ -216,7 +237,7 @@ int search_disk_index(int argc, char** argv) { std::vector query_result_ids_64(recall_at * query_num); auto s = std::chrono::high_resolution_clock::now(); -#pragma omp parallel for schedule(dynamic, 1) +#pragma omp parallel for schedule(dynamic, 1) for (_s64 i = 0; i < (int64_t) query_num; i++) { _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, @@ -286,10 +307,11 @@ int search_disk_index(int argc, char** argv) { } int main(int argc, char** argv) { - if (argc < 11) { + if (argc < 12) { diskann::cout << "Usage: " << argv[0] - << " [index_type] [index_prefix_path] " + << " [index_type] [dist_fn] " + "[index_prefix_path] " " [num_nodes_to_cache] [num_threads] [beamwidth (use 0 to " "optimize internally)] " " [query_file.bin] [truthset.bin (use \"null\" for none)] " diff --git a/tests/search_memory_index.cpp b/tests/search_memory_index.cpp index 32592710e..e9951eb6d 100644 --- a/tests/search_memory_index.cpp +++ b/tests/search_memory_index.cpp @@ -27,26 +27,44 @@ int search_memory_index(int argc, char** argv) { size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; std::vector<_u64> Lvec; - std::string data_file(argv[2]); - std::string memory_index_file(argv[3]); - _u64 num_threads = std::atoi(argv[4]); - std::string query_bin(argv[5]); - std::string truthset_bin(argv[6]); - _u64 recall_at = std::atoi(argv[7]); - std::string result_output_prefix(argv[8]); - bool use_optimized_search = std::atoi(argv[9]); + _u32 ctr = 2; + diskann::Metric metric; + + if (std::string(argv[ctr]) == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (std::string(argv[ctr]) == std::string("l2")) + metric = diskann::Metric::L2; + else if (std::string(argv[ctr]) == std::string("fast_l2")) + metric = diskann::Metric::FAST_L2; + else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/FAST_L2 support." + << std::endl; + return -1; + } + ctr++; if ((std::string(argv[1]) != std::string("float")) && - (use_optimized_search == true)) { - std::cout << "Error. Optimized search currently only supported for " - "floating point datatypes. Using un-optimized search." + ((metric == diskann::Metric::INNER_PRODUCT) || + (metric == diskann::Metric::FAST_L2))) { + std::cout << "Error. Inner product and Fast_L2 search currently only " + "supported for " + "floating point datatypes." << std::endl; - use_optimized_search = false; } + std::string data_file(argv[ctr++]); + std::string memory_index_file(argv[ctr++]); + _u64 num_threads = std::atoi(argv[ctr++]); + std::string query_bin(argv[ctr++]); + std::string truthset_bin(argv[ctr++]); + _u64 recall_at = std::atoi(argv[ctr++]); + std::string result_output_prefix(argv[ctr++]); + // bool use_optimized_search = std::atoi(argv[ctr++]); + bool calc_recall_flag = false; - for (int ctr = 10; ctr < argc; ctr++) { + for (; ctr < (_u32) argc; ctr++) { _u64 curL = std::atoi(argv[ctr]); if (curL >= recall_at) Lvec.push_back(curL); @@ -73,14 +91,11 @@ int search_memory_index(int argc, char** argv) { std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); std::cout.precision(2); - auto metric = diskann::L2; - if (use_optimized_search) - metric = diskann::FAST_L2; diskann::Index index(metric, data_file.c_str()); index.load(memory_index_file.c_str()); // to load NSG std::cout << "Index loaded" << std::endl; - if (use_optimized_search) + if (metric == diskann::FAST_L2) index.optimize_graph(); diskann::Parameters paras; @@ -106,7 +121,7 @@ int search_memory_index(int argc, char** argv) { #pragma omp parallel for schedule(dynamic, 1) for (int64_t i = 0; i < (int64_t) query_num; i++) { auto qs = std::chrono::high_resolution_clock::now(); - if (use_optimized_search) { + if (metric == diskann::FAST_L2) { index.search_with_opt_graph( query + i * query_aligned_dim, recall_at, L, query_result_ids[test_id].data() + i * recall_at); @@ -160,11 +175,11 @@ int main(int argc, char** argv) { if (argc < 11) { std::cout << "Usage: " << argv[0] - << " [index_type] [data_file.bin] " + << " [index_type] [dist_fn (l2/mips/fast_l2)] " + "[data_file.bin] " "[memory_index_path] [num_threads] " "[query_file.bin] [truthset.bin (use \"null\" for none)] " - " [K] [result_output_prefix] [use_optimized_search (for small ~1M " - "data)] " + " [K] [result_output_prefix]" " [L1] [L2] etc. See README for more information on parameters. " << std::endl; exit(-1); diff --git a/tests/test_incremental_index.cpp b/tests/test_incremental_index.cpp index 8b98233b2..533ca3a16 100644 --- a/tests/test_incremental_index.cpp +++ b/tests/test_incremental_index.cpp @@ -49,7 +49,7 @@ int main(int argc, char** argv) { paras.Set("saturate_graph", false); paras.Set("num_rnds", num_rnds); - typedef int TagT; + typedef int TagT; diskann::Index index(diskann::L2, argv[1], num_points, num_points - num_incr, num_frozen, true, true, true); diff --git a/tests/utils/CMakeLists.txt b/tests/utils/CMakeLists.txt index e69722dcf..a8171e739 100644 --- a/tests/utils/CMakeLists.txt +++ b/tests/utils/CMakeLists.txt @@ -47,6 +47,16 @@ else() target_link_libraries(int8_to_float ${PROJECT_NAME}) endif() +add_executable(uint8_to_float uint8_to_float.cpp) +if(MSVC) + target_link_options(uint8_to_float PRIVATE /MACHINE:x64) + target_link_libraries(uint8_to_float debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib) + target_link_libraries(uint8_to_float optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib) +else() + target_link_libraries(uint8_to_float ${PROJECT_NAME}) +endif() + + add_executable(uint32_to_uint8 uint32_to_uint8.cpp) if(MSVC) target_link_options(uint32_to_uint8 PRIVATE /MACHINE:x64) @@ -56,6 +66,16 @@ else() target_link_libraries(uint32_to_uint8 ${PROJECT_NAME}) endif() + +add_executable(vector_analysis vector_analysis.cpp) +if(MSVC) + target_link_options(vector_analysis PRIVATE /MACHINE:x64) + target_link_libraries(vector_analysis debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib) + target_link_libraries(vector_analysis optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib) +else() + target_link_libraries(vector_analysis ${PROJECT_NAME} -ltcmalloc) +endif() + add_executable(gen_random_slice gen_random_slice.cpp) if(MSVC) target_link_options(gen_random_slice PRIVATE /MACHINE:x64) @@ -136,6 +156,6 @@ endif() # formatter if (LINUX) - add_custom_command(TARGET gen_random_slice PRE_BUILD COMMAND clang-format-4.0 -i ../../../include/*.h ../../../include/dll/*.h ../../../src/*.cpp ../../../tests/*.cpp ../../../src/dll/*.cpp ../../../tests/utils/*.cpp) + add_custom_command(TARGET gen_random_slice PRE_BUILD COMMAND clang-format -i ../../../include/*.h ../../../include/dll/*.h ../../../src/*.cpp ../../../tests/*.cpp ../../../src/dll/*.cpp ../../../tests/utils/*.cpp) endif() diff --git a/tests/utils/bin_to_tsv.cpp b/tests/utils/bin_to_tsv.cpp index 37874e243..9a7e180a4 100644 --- a/tests/utils/bin_to_tsv.cpp +++ b/tests/utils/bin_to_tsv.cpp @@ -22,7 +22,8 @@ void block_convert(std::ofstream& writer, std::ifstream& reader, T* read_buf, int main(int argc, char** argv) { if (argc != 4) { - std::cout << argv[0] << " input_bin output_tsv" << std::endl; + std::cout << argv[0] << " input_bin output_tsv" + << std::endl; exit(-1); } std::string type_string(argv[1]); @@ -50,9 +51,10 @@ int main(int argc, char** argv) { for (_u64 i = 0; i < nblks; i++) { _u64 cblk_size = std::min(npts - i * blk_size, blk_size); if (type_string == std::string("float")) - block_convert(writer, reader, (float*)read_buf, cblk_size, ndims); + block_convert(writer, reader, (float*) read_buf, cblk_size, ndims); else if (type_string == std::string("int8")) - block_convert(writer, reader, (int8_t*) read_buf, cblk_size, ndims); + block_convert(writer, reader, (int8_t*) read_buf, cblk_size, + ndims); else if (type_string == std::string("uint8")) block_convert(writer, reader, (uint8_t*) read_buf, cblk_size, ndims); diff --git a/tests/utils/compute_groundtruth.cpp b/tests/utils/compute_groundtruth.cpp index 8fef8c929..6e331530a 100644 --- a/tests/utils/compute_groundtruth.cpp +++ b/tests/utils/compute_groundtruth.cpp @@ -31,10 +31,11 @@ #define ALIGNMENT 512 void command_line_help() { - std::cerr - << " " - << std::endl; + std::cerr << "./compute_groundtruth " + " " + " " + << std::endl; } template @@ -104,6 +105,27 @@ void distsq_to_points( delete[] ones_vec; } +void inner_prod_to_points( + const size_t dim, + float * dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, + const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +{ + bool ones_vec_alloc = false; + if (ones_vec == NULL) { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float) 1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, + (float) -1.0, points, dim, queries, dim, (float) 0.0, dist_matrix, + npoints); + + if (ones_vec_alloc) + delete[] ones_vec; +} + void exact_knn(const size_t dim, const size_t k, int *const closest_points, // k * num_queries preallocated, col // major, queries columns @@ -112,14 +134,23 @@ void exact_knn(const size_t dim, const size_t k, // corresponding closes_points size_t npoints, const float *const points, // points in Col major - size_t nqueries, - const float *const queries) // queries in Col major + size_t nqueries, const float *const queries, + bool use_mip = false) // queries in Col major { float *points_l2sq = new float[npoints]; float *queries_l2sq = new float[nqueries]; compute_l2sq(points_l2sq, points, npoints, dim); compute_l2sq(queries_l2sq, queries, nqueries, dim); + std::cout << "Going to compute " << k << " NNs for " << nqueries + << " queries over " << npoints << " points in " << dim + << " dimensions using"; + if (use_mip) + std::cout << " inner product "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + size_t q_batch_size = (1 << 9); float *dist_matrix = new float[(size_t) q_batch_size * (size_t) npoints]; @@ -128,9 +159,14 @@ void exact_knn(const size_t dim, const size_t k, int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; - distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, - queries + (ptrdiff_t) q_b * (ptrdiff_t) dim, - queries_l2sq + q_b); + if (!use_mip) { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, + q_e - q_b, queries + (ptrdiff_t) q_b * (ptrdiff_t) dim, + queries_l2sq + q_b); + } else { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t) q_b * (ptrdiff_t) dim); + } std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; @@ -265,13 +301,15 @@ inline void save_groundtruth_as_one_file(const std::string filename, } template -int aux_main(int argv, char **argc) { - +int aux_main(char **argc) { size_t npoints, nqueries, dim; std::string base_file(argc[2]); std::string query_file(argc[3]); size_t k = atoi(argc[4]); + bool use_mip = false; std::string gt_file(argc[5]); + if (std::string(argc[6]) == std::string("mips")) + use_mip = true; float *base_data; float *query_data; @@ -291,7 +329,7 @@ int aux_main(int argv, char **argc) { float *dist_closest_points_part = new float[nqueries * k]; exact_knn(dim, k, closest_points_part, dist_closest_points_part, npoints, - base_data, nqueries, query_data); + base_data, nqueries, query_data, use_mip); for (_u64 i = 0; i < nqueries; i++) { for (_u64 j = 0; j < k; j++) { @@ -303,6 +341,7 @@ int aux_main(int argv, char **argc) { delete[] closest_points_part; delete[] dist_closest_points_part; + diskann::aligned_free(base_data); } @@ -311,7 +350,10 @@ int aux_main(int argv, char **argc) { std::sort(cur_res.begin(), cur_res.end(), custom_dist); for (_u64 j = 0; j < k; j++) { closest_points[i * k + j] = (int32_t) cur_res[j].first; - dist_closest_points[i * k + j] = cur_res[j].second; + if (use_mip) + dist_closest_points[i * k + j] = -cur_res[j].second; + else + dist_closest_points[i * k + j] = cur_res[j].second; } } @@ -324,15 +366,15 @@ int aux_main(int argv, char **argc) { } int main(int argc, char **argv) { - if (argc != 6) { + if (argc != 7) { command_line_help(); return -1; } if (std::string(argv[1]) == std::string("float")) - aux_main(argc, argv); + aux_main(argv); if (std::string(argv[1]) == std::string("int8")) - aux_main(argc, argv); + aux_main(argv); if (std::string(argv[1]) == std::string("uint8")) - aux_main(argc, argv); + aux_main(argv); } diff --git a/tests/utils/create_disk_layout.cpp b/tests/utils/create_disk_layout.cpp index ac378272d..d0e871283 100644 --- a/tests/utils/create_disk_layout.cpp +++ b/tests/utils/create_disk_layout.cpp @@ -13,13 +13,7 @@ #include "utils.h" template -int create_disk_layout(int argc, char **argv) { - if (argc != 5) { - std::cout << argv[0] << " data_type data_bin " - "vamana_index_file output_diskann_index_file" - << std::endl; - exit(-1); - } +int create_disk_layout(char **argv) { std::string base_file(argv[2]); std::string vamana_file(argv[3]); std::string output_file(argv[4]); @@ -28,13 +22,21 @@ int create_disk_layout(int argc, char **argv) { } int main(int argc, char **argv) { + if (argc != 5) { + std::cout << argv[0] + << " data_type data_bin " + "vamana_index_file output_diskann_index_file" + << std::endl; + exit(-1); + } + int ret_val = -1; if (std::string(argv[1]) == std::string("float")) - ret_val = create_disk_layout(argc, argv); + ret_val = create_disk_layout(argv); else if (std::string(argv[1]) == std::string("int8")) - ret_val = create_disk_layout(argc, argv); + ret_val = create_disk_layout(argv); else if (std::string(argv[1]) == std::string("uint8")) - ret_val = create_disk_layout(argc, argv); + ret_val = create_disk_layout(argv); else { std::cout << "unsupported type. use int8/uint8/float " << std::endl; ret_val = -2; diff --git a/tests/utils/float_bin_to_int8.cpp b/tests/utils/float_bin_to_int8.cpp index 4f422a233..0620730a5 100644 --- a/tests/utils/float_bin_to_int8.cpp +++ b/tests/utils/float_bin_to_int8.cpp @@ -4,7 +4,6 @@ #include #include "utils.h" - void block_convert(std::ofstream& writer, int8_t* write_buf, std::ifstream& reader, float* read_buf, _u64 npts, _u64 ndims, float bias, float scale) { diff --git a/tests/utils/gen_random_slice.cpp b/tests/utils/gen_random_slice.cpp index 0417c12c0..c91e9d44e 100644 --- a/tests/utils/gen_random_slice.cpp +++ b/tests/utils/gen_random_slice.cpp @@ -21,14 +21,7 @@ #include template -int aux_main(int argc, char** argv) { - if (argc != 5) { - std::cout << argv[0] << " data_type [fliat/int8/uint8] base_bin_file " - "sample_output_prefix sampling_probability" - << std::endl; - exit(-1); - } - +int aux_main(char** argv) { std::string base_file(argv[2]); std::string output_prefix(argv[3]); float sampling_rate = (float) (std::atof(argv[4])); @@ -37,12 +30,20 @@ int aux_main(int argc, char** argv) { } int main(int argc, char** argv) { + if (argc != 5) { + std::cout << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "sample_output_prefix sampling_probability" + << std::endl; + exit(-1); + } + if (std::string(argv[1]) == std::string("float")) { - aux_main(argc, argv); + aux_main(argv); } else if (std::string(argv[1]) == std::string("int8")) { - aux_main(argc, argv); + aux_main(argv); } else if (std::string(argv[1]) == std::string("uint8")) { - aux_main(argc, argv); + aux_main(argv); } else std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; return 0; diff --git a/tests/utils/partition_data.cpp b/tests/utils/partition_data.cpp index 63476b15e..593bcb119 100644 --- a/tests/utils/partition_data.cpp +++ b/tests/utils/partition_data.cpp @@ -11,9 +11,10 @@ int main(int argc, char** argv) { if (argc != 7) { std::cout << "Usage:\n" - << argv[0] << " datatype " - " " - " " + << argv[0] + << " datatype " + " " + " " << std::endl; exit(-1); } diff --git a/tests/utils/partition_with_ram_budget.cpp b/tests/utils/partition_with_ram_budget.cpp index fee3c84e8..6cd6a9401 100644 --- a/tests/utils/partition_with_ram_budget.cpp +++ b/tests/utils/partition_with_ram_budget.cpp @@ -11,9 +11,10 @@ int main(int argc, char** argv) { if (argc != 8) { std::cout << "Usage:\n" - << argv[0] << " datatype " - " " - " " + << argv[0] + << " datatype " + " " + " " << std::endl; exit(-1); } diff --git a/tests/utils/tsv_to_bin.cpp b/tests/utils/tsv_to_bin.cpp index 111a6bb55..776e06343 100644 --- a/tests/utils/tsv_to_bin.cpp +++ b/tests/utils/tsv_to_bin.cpp @@ -26,8 +26,9 @@ void block_convert(std::ifstream& reader, std::ofstream& writer, _u64 npts, int main(int argc, char** argv) { if (argc != 6) { std::cout << argv[0] - << " input_filename.tsv output_filename.bin dim num_pts>" - << std::endl; + << " input_filename.tsv output_filename.bin " + "dim num_pts>" + << std::endl; exit(-1); } @@ -50,7 +51,7 @@ int main(int argc, char** argv) { _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; std::ofstream writer(argv[3], std::ios::binary); - auto npts_s32 = (_u32) npts; + auto npts_s32 = (_u32) npts; auto ndims_s32 = (_u32) ndims; writer.write((char*) &npts_s32, sizeof(_u32)); writer.write((char*) &ndims_s32, sizeof(_u32)); diff --git a/tests/utils/uint8_to_float.cpp b/tests/utils/uint8_to_float.cpp new file mode 100644 index 000000000..6a6b3b2fe --- /dev/null +++ b/tests/utils/uint8_to_float.cpp @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include "utils.h" + +int main(int argc, char** argv) { + if (argc != 3) { + std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl; + exit(-1); + } + + uint8_t* input; + size_t npts, nd; + diskann::load_bin(argv[1], input, npts, nd); + float* output = new float[npts * nd]; + diskann::convert_types(input, output, npts, nd); + diskann::save_bin(argv[2], output, npts, nd); + delete[] output; + delete[] input; +} diff --git a/tests/utils/vector_analysis.cpp b/tests/utils/vector_analysis.cpp new file mode 100644 index 000000000..af75ff772 --- /dev/null +++ b/tests/utils/vector_analysis.cpp @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "partition_and_pq.h" +#include "utils.h" + +#include +#include +#include +#include + +template +int analyze_norm(std::string base_file) { + std::cout << "Analyzing data norms" << std::endl; + T* data; + _u64 npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); +#pragma omp parallel for schedule(dynamic) + for (_u32 i = 0; i < npts; i++) { + for (_u32 d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + norms[i] = std::sqrt(norms[i]); + } + std::sort(norms.begin(), norms.end()); + for (_u32 p = 0; p < 100; p += 5) + std::cout << "percentile " << p << ": " + << norms[std::floor((p / 100.0) * npts)] << std::endl; + std::cout << "percentile 100" + << ": " << norms[npts - 1] << std::endl; + delete[] data; + return 0; +} + +template +int normalize_base(std::string base_file, std::string out_file) { + std::cout << "Normalizing base" << std::endl; + T* data; + _u64 npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + // std::vector norms(npts, 0); +#pragma omp parallel for schedule(dynamic) + for (_u32 i = 0; i < npts; i++) { + float pt_norm = 0; + for (_u32 d = 0; d < ndims; d++) + pt_norm += data[i * ndims + d] * data[i * ndims + d]; + pt_norm = std::sqrt(pt_norm); + for (_u32 d = 0; d < ndims; d++) + data[i * ndims + d] = data[i * ndims + d] / pt_norm; + } + diskann::save_bin(out_file, data, npts, ndims); + delete[] data; + return 0; +} + +template +int augment_base(std::string base_file, std::string out_file, + bool prep_base = true) { + std::cout << "Analyzing data norms" << std::endl; + T* data; + _u64 npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); + float max_norm = 0; +#pragma omp parallel for schedule(dynamic) + for (_u32 i = 0; i < npts; i++) { + for (_u32 d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + max_norm = norms[i] > max_norm ? norms[i] : max_norm; + } + // std::sort(norms.begin(), norms.end()); + max_norm = std::sqrt(max_norm); + std::cout << "Max norm: " << max_norm << std::endl; + T* new_data; + _u64 newdims = ndims + 1; + new_data = new T[npts * newdims]; + for (_u64 i = 0; i < npts; i++) { + if (prep_base) { + for (_u64 j = 0; j < ndims; j++) { + new_data[i * newdims + j] = data[i * ndims + j] / max_norm; + } + float diff = 1 - (norms[i] / (max_norm * max_norm)); + diff = diff <= 0 ? 0 : std::sqrt(diff); + new_data[i * newdims + ndims] = diff; + if (diff <= 0) { + std::cout << i << " has large max norm, investigate if needed. diff = " + << diff << std::endl; + } + } else { + for (_u64 j = 0; j < ndims; j++) { + new_data[i * newdims + j] = data[i * ndims + j] / std::sqrt(norms[i]); + } + new_data[i * newdims + ndims] = 0; + } + } + diskann::save_bin(out_file, new_data, npts, newdims); + delete[] new_data; + delete[] data; + return 0; +} + +template +int aux_main(char** argv) { + std::string base_file(argv[2]); + _u32 option = atoi(argv[3]); + if (option == 1) + analyze_norm(base_file); + else if (option == 2) + augment_base(base_file, std::string(argv[4]), true); + else if (option == 3) + augment_base(base_file, std::string(argv[4]), false); + else if (option == 4) + normalize_base(base_file, std::string(argv[4])); + return 0; +} + +int main(int argc, char** argv) { + if (argc < 4) { + std::cout + << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "[option: 1-norm analysis, 2-prep_base_for_mip, " + "3-prep_query_for_mip, 4-normalize-vecs] [out_file for options 2/3/4]" + << std::endl; + exit(-1); + } + + if (std::string(argv[1]) == std::string("float")) { + aux_main(argv); + } else if (std::string(argv[1]) == std::string("int8")) { + aux_main(argv); + } else if (std::string(argv[1]) == std::string("uint8")) { + aux_main(argv); + } else + std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; + return 0; +} diff --git a/unit_tester.sh b/unit_tester.sh new file mode 100755 index 000000000..c4e987a42 --- /dev/null +++ b/unit_tester.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Performs build and search test on disk and memory indices (parameters are tuned for 100K-1M sized datasets) +# All indices and logs will be stored in working_folder after run is complete +# To run, create a catalog text file consisting of the following entries +# For each dataset, specify the following 5 lines, in a line by line format, and then move on to next dataset +# dataset_name[used for save file names] +# /path/to/base.bin +# /path/to/query.bin +# data_type[float/uint8/int8] +# metric[l2/mips] +if [ "$#" -ne "3" ]; then + echo "usage: ./unit_test.sh [build_folder_path] [catalog] [working_folder]" +else + +BUILD_FOLDER=${1} +CATALOG1=${2} +WORK_FOLDER=${3} +mkdir ${WORK_FOLDER} +CATALOG="${WORK_FOLDER}/catalog_formatted.txt" +sed -e '/^$/d' ${CATALOG1} > ${CATALOG} + +echo Running unit testing on various files, with build folder as ${BUILD_FOLDER} and working folder as ${WORK_FOLDER} +# download all unit test files + +#iterate over them and run the corresponding test + + +while IFS= read -r line; do + DATASET=${line} + read -r BASE + read -r QUERY + read -r TYPE + read -r METRIC + GT="${WORK_FOLDER}/${DATASET}_gt30_${METRIC}" + MEM="${WORK_FOLDER}/${DATASET}_mem" + DISK="${WORK_FOLDER}/${DATASET}_disk" + MBLOG="${WORK_FOLDER}/${DATASET}_mb.log" + DBLOG="${WORK_FOLDER}/${DATASET}_db.log" + MSLOG="${WORK_FOLDER}/${DATASET}_ms.log" + DSLOG="${WORK_FOLDER}/${DATASET}_ds.log" + + FILESIZE=`wc -c "${BASE}" | awk '{print $1}'` + BUDGETBUILD=`bc <<< "scale=4; 0.0001 + ${FILESIZE}/(5*1024*1024*1024)"` + BUDGETSERVE=`bc <<< "scale=4; 0.0001 + ${FILESIZE}/(10*1024*1024*1024)"` + echo "=============================================================================================================================================" + echo "Running tests on ${DATASET} dataset, ${TYPE} datatype, $METRIC metric, ${BUDGETBUILD} GiB and ${BUDGETSERVE} GiB build and serve budget" + echo "=============================================================================================================================================" + rm ${DISK}_* + + #echo "Going to run test on ${BASE} base, ${QUERY} query, ${TYPE} datatype, ${METRIC} metric, saving gt at ${GT}" + echo "Computing Groundtruth" + ${BUILD_FOLDER}/tests/utils/compute_groundtruth ${TYPE} ${BASE} ${QUERY} 30 ${GT} ${METRIC} > /dev/null + echo "Building Mem Index" + /usr/bin/time ${BUILD_FOLDER}/tests/build_memory_index ${TYPE} ${METRIC} ${BASE} ${MEM} 32 50 1.2 0 > ${MBLOG} + awk '/^Degree/' ${MBLOG} + awk '/^Indexing/' ${MBLOG} + echo "Searching Mem Index" + ${BUILD_FOLDER}/tests/search_memory_index ${TYPE} ${METRIC} ${BASE} ${MEM} 16 ${QUERY} ${GT} 10 /tmp/res 10 20 30 40 50 60 70 80 90 100 > ${MSLOG} + awk '/===/{x=NR+10}(NR<=x){print}' ${MSLOG} + echo "Building Disk Index" + ${BUILD_FOLDER}/tests/build_disk_index ${TYPE} ${METRIC} ${BASE} ${DISK} 32 50 ${BUDGETSERVE} ${BUDGETBUILD} 32 0 > ${DBLOG} + awk '/^Compressing/' ${DBLOG} + echo "#shards in disk index" + awk '/^Indexing/' ${DBLOG} + echo "Searching Disk Index" + ${BUILD_FOLDER}/tests/search_disk_index ${TYPE} ${METRIC} ${DISK} 10000 10 4 ${QUERY} ${GT} 10 /tmp/res 20 40 60 80 100 > ${DSLOG} + echo "# shards used during index construction:" + awk '/medoids/{x=NR+1}(NR<=x){print}' ${DSLOG} + awk '/===/{x=NR+10}(NR<=x){print}' ${DSLOG} +done < "${CATALOG}" +fi