From d4ee77a909479bc6b2f88b985ddb6bf32e6f8aac Mon Sep 17 00:00:00 2001 From: chasingegg Date: Mon, 27 May 2024 14:19:25 +0800 Subject: [PATCH] Fix kmeans init endless loop when dist is inf Signed-off-by: chasingegg --- src/math_utils.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/math_utils.cpp b/src/math_utils.cpp index 7481da848..9fb82c845 100644 --- a/src/math_utils.cpp +++ b/src/math_utils.cpp @@ -397,10 +397,7 @@ void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float std::uniform_real_distribution<> distribution(0, 1); std::uniform_int_distribution int_dist(0, num_points - 1); size_t init_id = int_dist(generator); - size_t num_picked = 1; - picked.push_back(init_id); - std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float)); float *dist = new float[num_points]; @@ -410,6 +407,22 @@ void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float dist[i] = math_utils::calc_distance(data + i * dim, data + init_id * dim, dim); } + for (int64_t i = 0; i < (int64_t)num_points; i++) + { + if (std::isif(dist[i])) { + diskann::cout << "dist is inf, falling back to random pivot"; + << std::endl; + delete[] dist; + selecting_pivots(data, num_points, dim, pivot_data, num_centers); + return; + } + } + + size_t num_picked = 1; + + picked.push_back(init_id); + std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float)); + double dart_val; size_t tmp_pivot; bool sum_flag = false;