Skip to content

Commit

Permalink
Fixes #432, bug in using openmp with gcc and omp_get_num_threads() (#445
Browse files Browse the repository at this point in the history
)

* Fixes #432, bug in using openmp with gcc and omp_get_num_threads() only reporting the number of threads collaborating on the current code region not available overall. I made this error and transitioned us from omp_get_num_procs() about 5 or 6 months ago and only with bug #432 did I really get to see how problematic my naive expectations were.

* Removed cosine distance metric from disk index until we can properly fix it in pqflashindex. Documented what distance metrics can be used with what vector dtypes in tables in the documentation.
  • Loading branch information
daxpryce authored Aug 30, 2023
1 parent fa6c279 commit a112411
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class IndexWriteParametersBuilder

IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads)
{
_num_threads = num_threads == 0 ? omp_get_num_threads() : num_threads;
_num_threads = num_threads == 0 ? omp_get_num_procs() : num_threads;
return *this;
}

Expand Down
28 changes: 28 additions & 0 deletions python/src/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def build_disk_index(
in the format DiskANN's PQ Flash Index builder requires. This temp folder is deleted upon index creation completion
or error.
## Distance Metric and Vector Datatype Restrictions
| Metric \ Datatype | np.float32 | np.uint8 | np.int8 |
|-------------------|------------|----------|---------|
| L2 | ✅ | ✅ | ✅ |
| MIPS | ✅ | ❌ | ❌ |
| Cosine [^bug-in-disk-cosine] | ❌ | ❌ | ❌ |
[^bug-in-disk-cosine]: For StaticDiskIndex, Cosine distances are not currently supported.
### Parameters
- **data**: Either a `str` representing a path to a DiskANN vector bin file, or a numpy.ndarray,
of a supported dtype, in 2 dimensions. Note that `vector_dtype` must be provided if data is a `str`
Expand Down Expand Up @@ -119,6 +128,12 @@ def build_disk_index(
vector_bin_path, vector_dtype_actual = _valid_path_and_dtype(
data, vector_dtype, index_directory, index_prefix
)
_assert(dap_metric != _native_dap.COSINE, "Cosine is currently not supported in StaticDiskIndex")
if dap_metric == _native_dap.INNER_PRODUCT:
_assert(
vector_dtype_actual == np.float32,
"Integral vector dtypes (np.uint8, np.int8) are not supported with distance metric mips"
)

num_points, dimensions = vectors_metadata_from_file(vector_bin_path)

Expand Down Expand Up @@ -176,6 +191,14 @@ def build_memory_index(
`diskannpy.DynamicMemoryIndex`, you **must** supply a valid value for the `tags` parameter. **Do not supply
tags if the index is intended to be `diskannpy.StaticMemoryIndex`**!
## Distance Metric and Vector Datatype Restrictions
| Metric \ Datatype | np.float32 | np.uint8 | np.int8 |
|-------------------|------------|----------|---------|
| L2 | ✅ | ✅ | ✅ |
| MIPS | ✅ | ❌ | ❌ |
| Cosine | ✅ | ✅ | ✅ |
### Parameters
- **data**: Either a `str` representing a path to an existing DiskANN vector bin file, or a numpy.ndarray of a
Expand Down Expand Up @@ -232,6 +255,11 @@ def build_memory_index(
vector_bin_path, vector_dtype_actual = _valid_path_and_dtype(
data, vector_dtype, index_directory, index_prefix
)
if dap_metric == _native_dap.INNER_PRODUCT:
_assert(
vector_dtype_actual == np.float32,
"Integral vector dtypes (np.uint8, np.int8) are not supported with distance metric mips"
)

num_points, dimensions = vectors_metadata_from_file(vector_bin_path)

Expand Down
3 changes: 1 addition & 2 deletions python/src/dynamic_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ diskann::Index<DT, DynamicIdType, filterT> dynamic_index_builder(const diskann::
const uint32_t initial_search_threads,
const bool concurrent_consolidation)
{
const uint32_t _initial_search_threads =
initial_search_threads != 0 ? initial_search_threads : omp_get_num_threads();
const uint32_t _initial_search_threads = initial_search_threads != 0 ? initial_search_threads : omp_get_num_procs();

auto index_search_params = diskann::IndexSearchParams(initial_search_complexity, _initial_search_threads);
return diskann::Index<DT, DynamicIdType, filterT>(
Expand Down
5 changes: 3 additions & 2 deletions python/src/static_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ StaticDiskIndex<DT>::StaticDiskIndex(const diskann::Metric metric, const std::st
const uint32_t cache_mechanism)
: _reader(std::make_shared<PlatformSpecificAlignedFileReader>()), _index(_reader, metric)
{
int load_success = _index.load(num_threads, index_path_prefix.c_str());
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_procs();
int load_success = _index.load(_num_threads, index_path_prefix.c_str());
if (load_success != 0)
{
throw std::runtime_error("index load failed.");
}
if (cache_mechanism == 1)
{
std::string sample_file = index_path_prefix + std::string("_sample_data.bin");
cache_sample_paths(num_nodes_to_cache, sample_file, num_threads);
cache_sample_paths(num_nodes_to_cache, sample_file, _num_threads);
}
else if (cache_mechanism == 2)
{
Expand Down
6 changes: 3 additions & 3 deletions python/src/static_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ diskann::Index<DT, StaticIdType, filterT> static_index_builder(const diskann::Me
{
throw std::runtime_error("initial_search_complexity must be a positive uint32_t");
}
auto index_search_params = diskann::IndexSearchParams(initial_search_complexity, omp_get_num_threads());
auto index_search_params = diskann::IndexSearchParams(initial_search_complexity, omp_get_num_procs());
return diskann::Index<DT>(m, dimensions, num_points,
nullptr, // index write params
std::make_shared<diskann::IndexSearchParams>(index_search_params), // index search params
Expand All @@ -36,7 +36,7 @@ StaticMemoryIndex<DT>::StaticMemoryIndex(const diskann::Metric m, const std::str
const uint32_t initial_search_complexity)
: _index(static_index_builder<DT>(m, num_points, dimensions, initial_search_complexity))
{
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_procs();
_index.load(index_prefix.c_str(), _num_threads, initial_search_complexity);
}

Expand All @@ -56,7 +56,7 @@ NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::batch_search(
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries, const uint64_t knn,
const uint64_t complexity, const uint32_t num_threads)
{
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_procs();
py::array_t<StaticIdType> ids({num_queries, knn});
py::array_t<float> dists({num_queries, knn});
std::vector<DT *> empty_vector;
Expand Down
24 changes: 24 additions & 0 deletions python/tests/test_dynamic_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def setUpClass(cls) -> None:
build_random_vectors_and_memory_index(np.float32, "cosine", with_tags=True),
build_random_vectors_and_memory_index(np.uint8, "cosine", with_tags=True),
build_random_vectors_and_memory_index(np.int8, "cosine", with_tags=True),
build_random_vectors_and_memory_index(np.float32, "mips", with_tags=True),
]
cls._example_ann_dir = cls._test_matrix[0][4]

Expand Down Expand Up @@ -442,4 +443,27 @@ def _tiny_index():
warnings.simplefilter("error") # turns warnings into raised exceptions
index.batch_insert(rng.random((2, 10), dtype=np.float32), np.array([15, 25], dtype=np.uint32))

def test_zero_threads(self):
for (
metric,
dtype,
query_vectors,
index_vectors,
ann_dir,
vector_bin_file,
generated_tags,
) in self._test_matrix:
with self.subTest(msg=f"Testing dtype {dtype}"):
index = dap.DynamicMemoryIndex(
distance_metric="l2",
vector_dtype=dtype,
dimensions=10,
max_vectors=11_000,
complexity=64,
graph_degree=32,
num_threads=0, # explicitly asking it to use all available threads.
)
index.batch_insert(vectors=index_vectors, vector_ids=generated_tags, num_threads=0)

k = 5
ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0)
19 changes: 18 additions & 1 deletion python/tests/test_static_disk_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _build_random_vectors_and_index(dtype, metric):
complexity=32,
search_memory_maximum=0.00003,
build_memory_maximum=1,
num_threads=1,
num_threads=0,
pq_disk_bytes=0,
)
return metric, dtype, query_vectors, index_vectors, ann_dir
Expand All @@ -38,6 +38,7 @@ def setUpClass(cls) -> None:
_build_random_vectors_and_index(np.float32, "l2"),
_build_random_vectors_and_index(np.uint8, "l2"),
_build_random_vectors_and_index(np.int8, "l2"),
_build_random_vectors_and_index(np.float32, "mips"),
]
cls._example_ann_dir = cls._test_matrix[0][4]

Expand Down Expand Up @@ -149,3 +150,19 @@ def test_value_ranges_batch_search(self):
index.batch_search(
queries=np.array([[]], dtype=np.single), **kwargs
)

def test_zero_threads(self):
for metric, dtype, query_vectors, index_vectors, ann_dir in self._test_matrix:
with self.subTest(msg=f"Testing dtype {dtype}"):
index = dap.StaticDiskIndex(
distance_metric="l2",
vector_dtype=dtype,
index_directory=ann_dir,
num_threads=0, # Issue #432
num_nodes_to_cache=10,
)

k = 5
ids, dists = index.batch_search(
query_vectors, k_neighbors=k, complexity=5, beam_width=2, num_threads=0
)
21 changes: 21 additions & 0 deletions python/tests/test_static_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def setUpClass(cls) -> None:
build_random_vectors_and_memory_index(np.float32, "cosine"),
build_random_vectors_and_memory_index(np.uint8, "cosine"),
build_random_vectors_and_memory_index(np.int8, "cosine"),
build_random_vectors_and_memory_index(np.float32, "mips"),
]
cls._example_ann_dir = cls._test_matrix[0][4]

Expand Down Expand Up @@ -165,3 +166,23 @@ def test_value_ranges_batch_search(self):
index.batch_search(
queries=np.array([[]], dtype=np.single), **kwargs
)

def test_zero_threads(self):
for (
metric,
dtype,
query_vectors,
index_vectors,
ann_dir,
vector_bin_file,
_,
) in self._test_matrix:
with self.subTest(msg=f"Testing dtype {dtype}"):
index = dap.StaticMemoryIndex(
index_directory=ann_dir,
num_threads=0,
initial_search_complexity=32,
)

k = 5
ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0)
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2370,7 +2370,7 @@ consolidation_report Index<T, TagT, LabelT>::consolidate_deletes(const IndexWrit
const uint32_t range = params.max_degree;
const uint32_t maxc = params.max_occlusion_size;
const float alpha = params.alpha;
const uint32_t num_threads = params.num_threads == 0 ? omp_get_num_threads() : params.num_threads;
const uint32_t num_threads = params.num_threads == 0 ? omp_get_num_procs() : params.num_threads;

uint32_t num_calls_to_process_delete = 0;
diskann::Timer timer;
Expand Down

0 comments on commit a112411

Please sign in to comment.