Skip to content

Commit

Permalink
Type hints and returns actually align this time. (#444)
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpryce authored Aug 29, 2023
1 parent 8afb38a commit 353e538
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 12 deletions.
6 changes: 4 additions & 2 deletions python/src/_dynamic_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def search(
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
return self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
return QueryResponse(identifiers=neighbors, distances=distances)

def batch_search(
self,
Expand Down Expand Up @@ -351,13 +352,14 @@ def batch_search(
complexity = k_neighbors

num_queries, dim = queries.shape
return self._index.batch_search(
neighbors, distances = self._index.batch_search(
queries=_queries,
num_queries=num_queries,
knn=k_neighbors,
complexity=complexity,
num_threads=num_threads,
)
return QueryResponseBatch(identifiers=neighbors, distances=distances)

def save(self, save_path: str, index_prefix: str = "ann"):
"""
Expand Down
6 changes: 4 additions & 2 deletions python/src/_static_disk_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,13 @@ def search(
)
complexity = k_neighbors

return self._index.search(
neighbors, distances = self._index.search(
query=_query,
knn=k_neighbors,
complexity=complexity,
beam_width=beam_width,
)
return QueryResponse(identifiers=neighbors, distances=distances)

def batch_search(
self,
Expand Down Expand Up @@ -187,11 +188,12 @@ def batch_search(
complexity = k_neighbors

num_queries, dim = _queries.shape
return self._index.batch_search(
neighbors, distances = self._index.batch_search(
queries=_queries,
num_queries=num_queries,
knn=k_neighbors,
complexity=complexity,
beam_width=beam_width,
num_threads=num_threads,
)
return QueryResponseBatch(identifiers=neighbors, distances=distances)
6 changes: 4 additions & 2 deletions python/src/_static_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def search(
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
return self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
return QueryResponse(identifiers=neighbors, distances=distances)

def batch_search(
self,
Expand Down Expand Up @@ -178,10 +179,11 @@ def batch_search(
complexity = k_neighbors

num_queries, dim = _queries.shape
return self._index.batch_search(
neighbors, distances = self._index.batch_search(
queries=_queries,
num_queries=num_queries,
knn=k_neighbors,
complexity=complexity,
num_threads=num_threads,
)
return QueryResponseBatch(identifiers=neighbors, distances=distances)
9 changes: 7 additions & 2 deletions python/tests/test_dynamic_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,15 @@ def test_recall_and_batch(self):
)

k = 5
diskann_neighbors, diskann_distances = index.batch_search(
batch_response = index.batch_search(
query_vectors,
k_neighbors=k,
complexity=5,
num_threads=16,
)
self.assertIsInstance(batch_response, dap.QueryResponseBatch)

diskann_neighbors, diskann_distances = batch_response
if metric == "l2" or metric == "cosine":
knn = NearestNeighbors(
n_neighbors=100, algorithm="auto", metric=metric
Expand Down Expand Up @@ -115,7 +118,9 @@ def test_single(self):
index.batch_insert(vectors=index_vectors, vector_ids=generated_tags)

k = 5
ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5)
response = index.search(query_vectors[0], k_neighbors=k, complexity=5)
self.assertIsInstance(response, dap.QueryResponse)
ids, dists = response
self.assertEqual(ids.shape[0], k)
self.assertEqual(dists.shape[0], k)

Expand Down
9 changes: 7 additions & 2 deletions python/tests/test_static_disk_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,16 @@ def test_recall_and_batch(self):
)

k = 5
diskann_neighbors, diskann_distances = index.batch_search(
batch_response = index.batch_search(
query_vectors,
k_neighbors=k,
complexity=5,
beam_width=2,
num_threads=16,
)
self.assertIsInstance(batch_response, dap.QueryResponseBatch)

diskann_neighbors, diskann_distances = batch_response
if metric == "l2":
knn = NearestNeighbors(
n_neighbors=100, algorithm="auto", metric="l2"
Expand All @@ -93,9 +96,11 @@ def test_single(self):
)

k = 5
ids, dists = index.search(
response = index.search(
query_vectors[0], k_neighbors=k, complexity=5, beam_width=2
)
self.assertIsInstance(response, dap.QueryResponse)
ids, dists = response
self.assertEqual(ids.shape[0], k)
self.assertEqual(dists.shape[0], k)

Expand Down
9 changes: 7 additions & 2 deletions python/tests/test_static_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@ def test_recall_and_batch(self):
)

k = 5
diskann_neighbors, diskann_distances = index.batch_search(
batch_response = index.batch_search(
query_vectors,
k_neighbors=k,
complexity=5,
num_threads=16,
)
self.assertIsInstance(batch_response, dap.QueryResponseBatch)

diskann_neighbors, diskann_distances = batch_response
if metric in ["l2", "cosine"]:
knn = NearestNeighbors(
n_neighbors=100, algorithm="auto", metric=metric
Expand Down Expand Up @@ -86,7 +89,9 @@ def test_single(self):
)

k = 5
ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5)
response = index.search(query_vectors[0], k_neighbors=k, complexity=5)
self.assertIsInstance(response, dap.QueryResponse)
ids, dists = response
self.assertEqual(ids.shape[0], k)
self.assertEqual(dists.shape[0], k)

Expand Down

0 comments on commit 353e538

Please sign in to comment.