Skip to content

Commit

Permalink
Log PQ Training time as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
Suryansh Gupta committed Jan 29, 2025
1 parent cbc76f8 commit 88743fd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
std::string recall_string = "Recall@" + std::to_string(recall_at);
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
<< "Mean IO (us)" << std::setw(16) << "CPU (s)" << std::setw(16)<< "PQ Training(s)";
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall_string << std::endl;
Expand Down Expand Up @@ -272,6 +272,8 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre

auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.io_us; });
auto mean_pq_training_us = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.pq_training_us; });

double recall = 0;
if (calc_recall_flag)
Expand All @@ -283,7 +285,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre

diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus << std::setw(16) << mean_pq_training_us;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << std::endl;
Expand Down
1 change: 1 addition & 0 deletions include/percentile_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct QueryStats
float total_us = 0; // total time to process query in micros
float io_us = 0; // total time spent in IO
float cpu_us = 0; // total time spent in CPU
// Time spent on per-query PQ setup: cached_beam_search fills this from a timer
// that spans _pq_table.preprocess_query() and _pq_table.populate_chunk_distances().
// NOTE(review): despite the name, this is measured at search time, not while
// training the PQ codebook; units presumably match the other *_us fields - confirm.
float pq_training_us = 0; // total time spent in PQ training

unsigned n_4k = 0; // # of 4kB reads
unsigned n_8k = 0; // # of 8kB reads
Expand Down
4 changes: 3 additions & 1 deletion src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__,
__LINE__);
Timer query_timer, io_timer, cpu_timer;
Timer query_timer, io_timer, cpu_timer, pq_training_timer;
ScratchStoreManager<SSDThreadData<T>> manager(this->_thread_data);
auto data = manager.scratch_space();
IOContext &ctx = data->ctx;
Expand Down Expand Up @@ -1334,13 +1334,15 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
_nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);

// Both timers are restarted at the same point, so they measure the same span:
// query preprocessing plus population of the PQ chunk-distance table.
cpu_timer.reset();
pq_training_timer.reset();
// query <-> PQ chunk centers distances
_pq_table.preprocess_query(query_rotated); // center the query and rotate if
// we have a rotation matrix
float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch;
_pq_table.populate_chunk_distances(query_rotated, pq_dists);
// stats is optional; timings are recorded only when the caller supplied it.
if (stats != nullptr)
{
// pq_training_us is overwritten (=) while cpu_us is accumulated (+=); the
// same interval is therefore also counted inside cpu_us.
// NOTE(review): no PQ training runs in this span - the name refers to the
// per-query distance-table setup above.
stats->pq_training_us = (float)pq_training_timer.elapsed();
stats->cpu_us += (float)cpu_timer.elapsed();
}

Expand Down

0 comments on commit 88743fd

Please sign in to comment.