Skip to content

Commit

Permalink
add16bytes tag type (#506)
Browse files Browse the repository at this point in the history
* add 16 bytes tag type

* clean up code

* format doc

* fix compile issue

* fix compile issue

* revert change

* format doc

* separate static search and streaming search 

* clean up code

* resolve comment

* format doc

* fix test

* resolve comment
  • Loading branch information
Sanhaoji2 authored Feb 6, 2024
1 parent 5cf0360 commit 58de98d
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 79 deletions.
17 changes: 14 additions & 3 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
auto qs = std::chrono::high_resolution_clock::now();
if (filtered_search)
if (filtered_search && !tags)
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];

Expand All @@ -179,8 +179,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
}
else if (tags)
{
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res);
if (!filtered_search)
{
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res);
}
else
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];

index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
}

for (int64_t r = 0; r < (int64_t)recall_at; r++)
{
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
Expand Down
6 changes: 4 additions & 2 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class AbstractIndex
// Initialize space for res_vectors before calling.
template <typename data_type, typename tag_type>
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors);
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");

// Added search overload that takes L as parameter, so that we
// can customize L on a per-query basis without tampering with "Parameters"
Expand Down Expand Up @@ -120,7 +121,8 @@ class AbstractIndex
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors) = 0;
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") = 0;
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
virtual void _set_universal_label(const LabelType universal_label) = 0;
};
Expand Down
6 changes: 4 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

// Initialize space for res_vectors before calling.
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors);
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");

// Filter support search
template <typename IndexType>
Expand Down Expand Up @@ -226,7 +227,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;

virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors) override;
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") override;

virtual void _set_universal_label(const LabelType universal_label) override;

Expand Down
3 changes: 0 additions & 3 deletions include/natural_number_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ template <typename Key, typename Value> class natural_number_map
{
public:
static_assert(std::is_trivial<Key>::value, "Key must be a trivial type");
// Some of the class member prototypes are done with this assumption to
// minimize verbosity since it's the only use case.
static_assert(std::is_trivial<Value>::value, "Value must be a trivial type");

// Represents a reference to a element in the map. Used while iterating
// over map entries.
Expand Down
68 changes: 68 additions & 0 deletions include/tag_uint128.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once
#include <cstdint>
#include <type_traits>

namespace diskann
{
#pragma pack(push, 1)

struct tag_uint128
{
std::uint64_t _data1 = 0;
std::uint64_t _data2 = 0;

bool operator==(const tag_uint128 &other) const
{
return _data1 == other._data1 && _data2 == other._data2;
}

bool operator==(std::uint64_t other) const
{
return _data1 == other && _data2 == 0;
}

tag_uint128 &operator=(const tag_uint128 &other)
{
_data1 = other._data1;
_data2 = other._data2;

return *this;
}

tag_uint128 &operator=(std::uint64_t other)
{
_data1 = other;
_data2 = 0;

return *this;
}
};

#pragma pack(pop)
} // namespace diskann

namespace std
{
// Hash 128 input bits down to 64 bits of output.
// This is intended to be a reasonably good hash function.
inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high)
{
// Murmur-inspired hashing.
const std::uint64_t kMul = 0x9ddfea08eb382d69ULL;
std::uint64_t a = (low ^ high) * kMul;
a ^= (a >> 47);
std::uint64_t b = (high ^ a) * kMul;
b ^= (b >> 47);
b *= kMul;
return b;
}

template <> struct hash<diskann::tag_uint128>
{
size_t operator()(const diskann::tag_uint128 &key) const noexcept
{
return Hash128to64(key._data1, key._data2); // map -0 to 0
}
};

} // namespace std
12 changes: 12 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ typedef int FileHandle;
#include "windows_customizations.h"
#include "tsl/robin_set.h"
#include "types.h"
#include "tag_uint128.h"
#include <any>

#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -1007,6 +1008,17 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf,

DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName);

inline std::string get_tag_string(std::uint64_t tag)
{
return std::to_string(tag);
}

inline std::string get_tag_string(const tag_uint128 &tag)
{
std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1);
return str;
}

}; // namespace diskann

struct PivotContainer
Expand Down
77 changes: 35 additions & 42 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, cons

template <typename data_type, typename tag_type>
size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors)
float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
const std::string filter_label)
{
auto any_query = std::any(query);
auto any_tags = std::any(tags);
auto any_res_vectors = DataVector(res_vectors);
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors);
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label);
}

template <typename IndexType>
Expand Down Expand Up @@ -162,61 +163,53 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_w
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(const float *query, const uint64_t K,
const uint32_t L, int32_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(
const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t
AbstractIndex::search_with_tags<uint8_t, int32_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
int32_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int32_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
int32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(
const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(const float *query, const uint64_t K,
const uint32_t L, uint32_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(
const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint32_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors);
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
uint32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(
const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(const float *query, const uint64_t K,
const uint32_t L, int64_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(
const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t
AbstractIndex::search_with_tags<uint8_t, int64_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
int64_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int64_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
int64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(
const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(const float *query, const uint64_t K,
const uint32_t L, uint64_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(
const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint64_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors);
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
uint64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(
const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<float>(const float *query, size_t K,
size_t L, uint32_t *indices);
Expand Down
Loading

0 comments on commit 58de98d

Please sign in to comment.