|
#include <cstring>
#include <iostream>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../hnswlib/hnswlib.h"
| 2 | + |
// Type used for the document id attached to every indexed vector.
using docidtype = unsigned int;
// Distance type used by the index and by the search results.
using dist_t = float;
| 5 | + |
| 6 | +int main() { |
| 7 | + int dim = 16; // Dimension of the elements |
| 8 | + int max_elements = 10000; // Maximum number of elements, should be known beforehand |
| 9 | + int M = 16; // Tightly connected with internal dimensionality of the data |
| 10 | + // strongly affects the memory consumption |
| 11 | + int ef_construction = 200; // Controls index search speed/build speed tradeoff |
| 12 | + |
| 13 | + int num_queries = 5; |
| 14 | + int num_docs = 5; // Number of documents to search |
| 15 | + int ef_collection = 6; // Number of candidate documents during the search |
| 16 | + // Controlls the recall: higher ef leads to better accuracy, but slower search |
| 17 | + docidtype min_doc_id = 0; |
| 18 | + docidtype max_doc_id = 9; |
| 19 | + |
| 20 | + // Initing index |
| 21 | + hnswlib::MultiVectorL2Space<docidtype> space(dim); |
| 22 | + hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction); |
| 23 | + |
| 24 | + // Generate random data |
| 25 | + std::mt19937 rng; |
| 26 | + rng.seed(47); |
| 27 | + std::uniform_real_distribution<> distrib_real; |
| 28 | + std::uniform_int_distribution<docidtype> distrib_docid(min_doc_id, max_doc_id); |
| 29 | + |
| 30 | + size_t data_point_size = space.get_data_size(); |
| 31 | + char* data = new char[data_point_size * max_elements]; |
| 32 | + for (int i = 0; i < max_elements; i++) { |
| 33 | + // set vector value |
| 34 | + char* point_data = data + i * data_point_size; |
| 35 | + for (int j = 0; j < dim; j++) { |
| 36 | + char* vec_data = point_data + j * sizeof(float); |
| 37 | + float value = distrib_real(rng); |
| 38 | + *(float*)vec_data = value; |
| 39 | + } |
| 40 | + // set document id |
| 41 | + docidtype doc_id = distrib_docid(rng); |
| 42 | + space.set_doc_id(point_data, doc_id); |
| 43 | + } |
| 44 | + |
| 45 | + // Add data to index |
| 46 | + std::unordered_map<hnswlib::labeltype, docidtype> label_docid_lookup; |
| 47 | + for (int i = 0; i < max_elements; i++) { |
| 48 | + hnswlib::labeltype label = i; |
| 49 | + char* point_data = data + i * data_point_size; |
| 50 | + alg_hnsw->addPoint(point_data, label); |
| 51 | + label_docid_lookup[label] = space.get_doc_id(point_data); |
| 52 | + } |
| 53 | + |
| 54 | + // Query random vectors |
| 55 | + size_t query_size = dim * sizeof(float); |
| 56 | + for (int i = 0; i < num_queries; i++) { |
| 57 | + char* query_data = new char[query_size]; |
| 58 | + for (int j = 0; j < dim; j++) { |
| 59 | + size_t offset = j * sizeof(float); |
| 60 | + char* vec_data = query_data + offset; |
| 61 | + float value = distrib_real(rng); |
| 62 | + *(float*)vec_data = value; |
| 63 | + } |
| 64 | + std::cout << "Query #" << i << "\n"; |
| 65 | + hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection); |
| 66 | + std::vector<std::pair<float, hnswlib::labeltype>> result = |
| 67 | + alg_hnsw->searchStopConditionClosest(query_data, stop_condition); |
| 68 | + size_t num_vectors = result.size(); |
| 69 | + |
| 70 | + std::unordered_map<docidtype, size_t> doc_counter; |
| 71 | + for (auto pair: result) { |
| 72 | + hnswlib::labeltype label = pair.second; |
| 73 | + docidtype doc_id = label_docid_lookup[label]; |
| 74 | + doc_counter[doc_id] += 1; |
| 75 | + } |
| 76 | + std::cout << "Found " << doc_counter.size() << " documents, " << num_vectors << " vectors\n"; |
| 77 | + delete[] query_data; |
| 78 | + } |
| 79 | + |
| 80 | + delete[] data; |
| 81 | + delete alg_hnsw; |
| 82 | + return 0; |
| 83 | +} |
0 commit comments