From 353e538f458d4f775565b82ee070cd0a01433839 Mon Sep 17 00:00:00 2001 From: Dax Pryce Date: Tue, 29 Aug 2023 15:49:30 -0700 Subject: [PATCH] Type hints and returns actually align this time. (#444) --- python/src/_dynamic_memory_index.py | 6 ++++-- python/src/_static_disk_index.py | 6 ++++-- python/src/_static_memory_index.py | 6 ++++-- python/tests/test_dynamic_memory_index.py | 9 +++++++-- python/tests/test_static_disk_index.py | 9 +++++++-- python/tests/test_static_memory_index.py | 9 +++++++-- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/python/src/_dynamic_memory_index.py b/python/src/_dynamic_memory_index.py index 9570b8345..0346a2c76 100644 --- a/python/src/_dynamic_memory_index.py +++ b/python/src/_dynamic_memory_index.py @@ -309,7 +309,8 @@ def search( f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}" ) complexity = k_neighbors - return self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + return QueryResponse(identifiers=neighbors, distances=distances) def batch_search( self, @@ -351,13 +352,14 @@ def batch_search( complexity = k_neighbors num_queries, dim = queries.shape - return self._index.batch_search( + neighbors, distances = self._index.batch_search( queries=_queries, num_queries=num_queries, knn=k_neighbors, complexity=complexity, num_threads=num_threads, ) + return QueryResponseBatch(identifiers=neighbors, distances=distances) def save(self, save_path: str, index_prefix: str = "ann"): """ diff --git a/python/src/_static_disk_index.py b/python/src/_static_disk_index.py index 1ca93c0a4..769099d8f 100644 --- a/python/src/_static_disk_index.py +++ b/python/src/_static_disk_index.py @@ -138,12 +138,13 @@ def search( ) complexity = k_neighbors - return self._index.search( + neighbors, distances = self._index.search( query=_query, knn=k_neighbors, complexity=complexity, beam_width=beam_width, ) + return QueryResponse(identifiers=neighbors, distances=distances) def batch_search( self, @@ -187,7 +188,7 @@ def batch_search( complexity = k_neighbors num_queries, dim = _queries.shape - return self._index.batch_search( + neighbors, distances = self._index.batch_search( queries=_queries, num_queries=num_queries, knn=k_neighbors, @@ -195,3 +196,4 @@ def batch_search( beam_width=beam_width, num_threads=num_threads, ) + return QueryResponseBatch(identifiers=neighbors, distances=distances) diff --git a/python/src/_static_memory_index.py b/python/src/_static_memory_index.py index 8b87cd561..b1ffb468d 100644 --- a/python/src/_static_memory_index.py +++ b/python/src/_static_memory_index.py @@ -136,7 +136,8 @@ def search( f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}" ) complexity = k_neighbors - return self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + return QueryResponse(identifiers=neighbors, distances=distances) def batch_search( self, @@ -178,10 +179,11 @@ def batch_search( complexity = k_neighbors num_queries, dim = _queries.shape - return self._index.batch_search( + neighbors, distances = self._index.batch_search( queries=_queries, num_queries=num_queries, knn=k_neighbors, complexity=complexity, num_threads=num_threads, ) + return QueryResponseBatch(identifiers=neighbors, distances=distances) diff --git a/python/tests/test_dynamic_memory_index.py b/python/tests/test_dynamic_memory_index.py index ff9c8981d..48c05443c 100644 --- a/python/tests/test_dynamic_memory_index.py +++ b/python/tests/test_dynamic_memory_index.py @@ -72,12 +72,15 @@ def test_recall_and_batch(self): ) k = 5 - diskann_neighbors, diskann_distances = index.batch_search( + batch_response = index.batch_search( query_vectors, k_neighbors=k, complexity=5, num_threads=16, ) + self.assertIsInstance(batch_response, dap.QueryResponseBatch) + + diskann_neighbors, diskann_distances = batch_response if metric == "l2" or metric == "cosine": knn = NearestNeighbors( n_neighbors=100, algorithm="auto", metric=metric @@ -115,7 +118,9 @@ def test_single(self): index.batch_insert(vectors=index_vectors, vector_ids=generated_tags) k = 5 - ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5) + response = index.search(query_vectors[0], k_neighbors=k, complexity=5) + self.assertIsInstance(response, dap.QueryResponse) + ids, dists = response self.assertEqual(ids.shape[0], k) self.assertEqual(dists.shape[0], k) diff --git a/python/tests/test_static_disk_index.py b/python/tests/test_static_disk_index.py index 4ba544106..c36c581d2 100644 --- a/python/tests/test_static_disk_index.py +++ b/python/tests/test_static_disk_index.py @@ -62,13 +62,16 @@ def test_recall_and_batch(self): ) k = 5 - diskann_neighbors, diskann_distances = index.batch_search( + batch_response = index.batch_search( query_vectors, k_neighbors=k, complexity=5, beam_width=2, num_threads=16, ) + self.assertIsInstance(batch_response, dap.QueryResponseBatch) + + diskann_neighbors, diskann_distances = batch_response if metric == "l2": knn = NearestNeighbors( n_neighbors=100, algorithm="auto", metric="l2" @@ -93,9 +96,11 @@ def test_single(self): ) k = 5 - ids, dists = index.search( + response = index.search( query_vectors[0], k_neighbors=k, complexity=5, beam_width=2 ) + self.assertIsInstance(response, dap.QueryResponse) + ids, dists = response self.assertEqual(ids.shape[0], k) self.assertEqual(dists.shape[0], k) diff --git a/python/tests/test_static_memory_index.py b/python/tests/test_static_memory_index.py index cb7f0f01d..ce12ed3bf 100644 --- a/python/tests/test_static_memory_index.py +++ b/python/tests/test_static_memory_index.py @@ -50,12 +50,15 @@ def test_recall_and_batch(self): ) k = 5 - diskann_neighbors, diskann_distances = index.batch_search( + batch_response = index.batch_search( query_vectors, k_neighbors=k, complexity=5, num_threads=16, ) + self.assertIsInstance(batch_response, dap.QueryResponseBatch) + + diskann_neighbors, diskann_distances = batch_response if metric in ["l2", "cosine"]: knn = NearestNeighbors( n_neighbors=100, algorithm="auto", metric=metric @@ -86,7 +89,9 @@ def test_single(self): ) k = 5 - ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5) + response = index.search(query_vectors[0], k_neighbors=k, complexity=5) + self.assertIsInstance(response, dap.QueryResponse) + ids, dists = response self.assertEqual(ids.shape[0], k) self.assertEqual(dists.shape[0], k)