Skip to content

Commit 1db12c4

Browse files
Python: Adding Vector Search to the In Memory collection (#9574)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Adds vectorized search and text search to the In Memory connector. ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 --------- Co-authored-by: Tao Chen <[email protected]>
1 parent daafde4 commit 1db12c4

File tree

12 files changed

+267
-76
lines changed

12 files changed

+267
-76
lines changed

.github/workflows/python-lint.yml

+3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ jobs:
3333
- name: Install the project
3434
run: uv sync --all-extras --dev
3535
- uses: pre-commit/action@v3.0.1
36+
name: Run Pre-Commit Hooks
3637
with:
3738
extra_args: --config python/.pre-commit-config.yaml --all-files
39+
- name: Run Mypy
40+
run: uv run mypy -p semantic_kernel --config-file mypy.ini
3841
- name: Minimize uv cache
3942
run: uv cache prune --ci

python/.pre-commit-config.yaml

+1-10
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,10 @@ repos:
3939
- id: ruff-format
4040
- repo: https://github.com/astral-sh/uv-pre-commit
4141
# uv version.
42-
rev: 0.4.29
42+
rev: 0.4.30
4343
hooks:
4444
# Update the uv lockfile
4545
- id: uv-lock
46-
- repo: local
47-
hooks:
48-
- id: mypy
49-
files: ^python/semantic_kernel/
50-
name: mypy
51-
entry: bash -c 'cd python && uv run mypy -p semantic_kernel --config-file mypy.ini'
52-
language: system
53-
types: [python]
54-
pass_filenames: true
5546
- repo: https://github.com/PyCQA/bandit
5647
rev: 1.7.8
5748
hooks:

python/.vscode/tasks.json

+5-4
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@
6464
"command": "uv",
6565
"args": [
6666
"run",
67-
"pre-commit",
68-
"run",
69-
"-a",
70-
"mypy"
67+
"mypy",
68+
"-p",
69+
"semantic_kernel",
70+
"--config-file",
71+
"mypy.ini"
7172
],
7273
"problemMatcher": {
7374
"owner": "python",

python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py

-34
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
3737
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
3838
from semantic_kernel.exceptions import MemoryConnectorException, MemoryConnectorInitializationError
39-
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
4039
from semantic_kernel.utils.experimental_decorator import experimental_class
4140

4241
logger: logging.Logger = logging.getLogger(__name__)
@@ -304,39 +303,6 @@ def _build_filter_string(self, search_filter: VectorSearchFilter) -> str:
304303
filter_string = filter_string[:-5]
305304
return filter_string
306305

307-
@staticmethod
308-
def _default_parameter_metadata() -> list[KernelParameterMetadata]:
309-
"""Default parameter metadata for text search functions.
310-
311-
This function should be overridden when necessary.
312-
"""
313-
return [
314-
KernelParameterMetadata(
315-
name="query",
316-
description="What to search for.",
317-
type="str",
318-
is_required=False,
319-
default_value="*",
320-
type_object=str,
321-
),
322-
KernelParameterMetadata(
323-
name="count",
324-
description="Number of results to return.",
325-
type="int",
326-
is_required=False,
327-
default_value=2,
328-
type_object=int,
329-
),
330-
KernelParameterMetadata(
331-
name="skip",
332-
description="Number of results to skip.",
333-
type="int",
334-
is_required=False,
335-
default_value=0,
336-
type_object=int,
337-
),
338-
]
339-
340306
@override
341307
def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]:
342308
return result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
4+
from collections.abc import Callable
5+
from typing import Any
6+
7+
from numpy import dot
8+
from scipy.spatial.distance import cityblock, cosine, euclidean, hamming, sqeuclidean
9+
10+
from semantic_kernel.data.const import DistanceFunction
11+
12+
# Lookup table from the distance function configured on a vector field to the
# callable used to score a stored vector against the query vector.
DISTANCE_FUNCTION_MAP: dict[DistanceFunction | str, Callable[..., Any]] = {
    DistanceFunction.COSINE_DISTANCE: cosine,
    # Similarity reuses scipy's cosine *distance*; the in-memory collection's
    # vectorized search passes invert_score=True for this metric to turn the
    # distance into `1 - distance`.
    DistanceFunction.COSINE_SIMILARITY: cosine,
    DistanceFunction.EUCLIDEAN_DISTANCE: euclidean,
    DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: sqeuclidean,
    DistanceFunction.MANHATTAN: cityblock,
    DistanceFunction.HAMMING: hamming,
    DistanceFunction.DOT_PROD: dot,
    # Looked up when a vector field does not declare a distance function.
    "default": cosine,
}

python/semantic_kernel/connectors/memory/in_memory/in_memory_collection.py

+166-5
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,45 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import sys
4-
from collections.abc import Mapping, Sequence
4+
from collections.abc import AsyncIterable, Callable, Mapping, Sequence
55
from typing import Any, ClassVar, TypeVar
66

7+
from pydantic import Field
8+
9+
from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
10+
from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo
11+
712
if sys.version_info >= (3, 12):
813
from typing import override # pragma: no cover
914
else:
1015
from typing_extensions import override # pragma: no cover
1116

12-
from pydantic import Field
13-
17+
from semantic_kernel.connectors.memory.in_memory.const import DISTANCE_FUNCTION_MAP
18+
from semantic_kernel.data.const import DistanceFunction
19+
from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
20+
from semantic_kernel.data.kernel_search_results import KernelSearchResults
1421
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
15-
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
22+
from semantic_kernel.data.record_definition.vector_store_record_fields import (
23+
VectorStoreRecordVectorField,
24+
)
25+
from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
26+
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
27+
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
28+
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
29+
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
30+
from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreModelValidationError
1631
from semantic_kernel.kernel_types import OneOrMany
1732

1833
KEY_TYPES = str | int | float
1934

2035
TModel = TypeVar("TModel")
2136

37+
IN_MEMORY_SCORE_KEY = "in_memory_search_score"
38+
2239

23-
class InMemoryVectorCollection(VectorStoreRecordCollection[KEY_TYPES, TModel]):
40+
class InMemoryVectorCollection(
41+
VectorSearchBase[KEY_TYPES, TModel], VectorTextSearchMixin[TModel], VectorizedSearchMixin[TModel]
42+
):
2443
"""In Memory Collection."""
2544

2645
inner_storage: dict[KEY_TYPES, dict] = Field(default_factory=dict)
@@ -39,6 +58,12 @@ def __init__(
3958
collection_name=collection_name,
4059
)
4160

61+
def _validate_data_model(self):
    """Run the base validation and reject models that use the reserved score field name.

    Raises:
        VectorStoreModelValidationError: If the data model defines a field named
            like the key this collection uses to attach search scores to results.
    """
    super()._validate_data_model()
    reserved_name_in_use = IN_MEMORY_SCORE_KEY in self.data_model_definition.field_names
    if not reserved_name_in_use:
        return
    raise VectorStoreModelValidationError(f"Field name '{IN_MEMORY_SCORE_KEY}' is reserved for internal use.")
66+
4267
@override
4368
async def _inner_delete(self, keys: Sequence[KEY_TYPES], **kwargs: Any) -> None:
4469
for key in keys:
@@ -74,3 +99,139 @@ async def delete_collection(self, **kwargs: Any) -> None:
7499
@override
75100
async def does_collection_exist(self, **kwargs: Any) -> bool:
76101
return True
102+
103+
@override
async def _inner_search(
    self,
    options: VectorSearchOptions | None = None,
    search_text: str | None = None,
    vectorizable_text: str | None = None,
    vector: list[float | int] | None = None,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Route a search request to the text or the vectorized implementation.

    A non-empty ``search_text`` takes precedence over ``vector``.
    ``vectorizable_text`` is accepted for interface compatibility but is not
    used by this collection.

    Args:
        options: The search options; required for a vector search.
        search_text: The text to look for in the stored records.
        vectorizable_text: Not used.
        vector: The query vector to compare stored vectors against.
        **kwargs: Forwarded to the selected search implementation.

    Returns:
        The search results produced by the selected implementation.

    Raises:
        VectorSearchExecutionException: If neither search text nor a vector is
            given, or if a vector is given without options.
    """
    if search_text:
        return await self._inner_search_text(search_text, options, **kwargs)
    if not vector:
        raise VectorSearchExecutionException("Search text or vector must be provided.")
    if not options:
        raise VectorSearchExecutionException("Options must be provided for vector search.")
    return await self._inner_search_vectorized(vector, options, **kwargs)
120+
121+
async def _inner_search_text(
    self,
    search_text: str,
    options: VectorSearchOptions | None = None,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Search the filtered records for the given text.

    Every matching record gets the constant score 1.0; no ranking is applied.

    Args:
        search_text: The text to look for.
        options: Optional search options (filter, paging, total count).
        **kwargs: Unused.

    Returns:
        The matching records, or a result object with ``results=None`` when
        nothing matched.
    """
    matches: dict[KEY_TYPES, float] = {
        key: 1.0
        for key, record in self._get_filtered_records(options).items()
        if self._should_add_text_search(search_text, record)
    }
    if not matches:
        return KernelSearchResults(results=None)
    paged_records = self._generate_return_list(matches, options)
    return KernelSearchResults(
        results=self._get_vector_search_results_from_results(paged_records),
        total_count=len(matches) if options and options.include_total_count else None,
    )
140+
141+
async def _inner_search_vectorized(
    self,
    vector: list[float | int],
    options: VectorSearchOptions,
    **kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
    """Score the filtered records against a query vector and return them best-first.

    The distance function is taken from the definition of the vector field named
    in ``options.vector_field_name``; when the field declares none, the
    ``"default"`` entry of ``DISTANCE_FUNCTION_MAP`` is used.

    Args:
        vector: The query vector.
        options: The search options; ``vector_field_name`` must be set.
        **kwargs: Unused.

    Returns:
        The scored records in ranking order, or a result object with
        ``results=None`` when there is nothing to return.

    Raises:
        ValueError: If ``options.vector_field_name`` is not set.
        VectorSearchExecutionException: If the named field is not a vector field
            or its distance function is not supported.
    """
    field = options.vector_field_name
    if not field:
        raise ValueError("Vector field name must be provided in options for vector search.")

    # Validate explicitly: an `assert` would be stripped when running with `python -O`.
    field_definition = self.data_model_definition.fields.get(field)
    if not isinstance(field_definition, VectorStoreRecordVectorField):
        raise VectorSearchExecutionException(f"Field '{field}' is not a vector field of this collection.")

    distance_metric = field_definition.distance_function or "default"
    distance_func = DISTANCE_FUNCTION_MAP.get(distance_metric)
    if distance_func is None:
        raise VectorSearchExecutionException(
            f"Distance function '{distance_metric}' is not supported by the in-memory collection."
        )

    # Without a query vector there is nothing to score.
    if not vector:
        return KernelSearchResults(results=None)

    # Cosine similarity is derived from the cosine distance as `1 - distance`.
    invert_score = distance_metric == DistanceFunction.COSINE_SIMILARITY
    # For similarity-style metrics a larger score is a better match; for distances a smaller one is.
    higher_is_better = distance_metric in (DistanceFunction.COSINE_SIMILARITY, DistanceFunction.DOT_PROD)

    return_records: dict[KEY_TYPES, float] = {
        key: self._calculate_vector_similarity(vector, record[field], distance_func, invert_score=invert_score)
        for key, record in self._get_filtered_records(options).items()
    }
    if not return_records:
        return KernelSearchResults(results=None)

    sorted_records = dict(sorted(return_records.items(), key=lambda item: item[1], reverse=higher_is_better))
    return KernelSearchResults(
        results=self._get_vector_search_results_from_results(self._generate_return_list(sorted_records, options)),
        total_count=len(return_records) if options.include_total_count else None,
    )
175+
176+
async def _generate_return_list(
    self, return_records: dict[KEY_TYPES, float], options: VectorSearchOptions | None
) -> AsyncIterable[dict]:
    """Yield the stored records for the given keys, paged and annotated with their score.

    Args:
        return_records: Mapping of record key to score, already in the order in
            which the records should be returned.
        options: Optional search options supplying ``top`` and ``skip``;
            without options the first 3 records are returned.

    Yields:
        A copy of each selected record with the score stored under
        ``IN_MEMORY_SCORE_KEY``.
    """
    top = 3 if not options else options.top
    skip = 0 if not options else options.skip
    returned = 0
    for idx, (key, score) in enumerate(return_records.items()):
        if idx < skip:
            continue
        returned += 1
        # Yield a shallow copy: writing the score into the stored dict would leave
        # a stale score field on the record in `inner_storage` after the search.
        yield {**self.inner_storage[key], IN_MEMORY_SCORE_KEY: score}
        if returned >= top:
            break
190+
191+
def _get_filtered_records(self, options: VectorSearchOptions | None) -> dict[KEY_TYPES, dict]:
192+
if options and options.filter:
193+
for filter in options.filter.filters:
194+
return {key: record for key, record in self.inner_storage.items() if self._apply_filter(record, filter)}
195+
return self.inner_storage
196+
197+
def _should_add_text_search(self, search_text: str, record: dict) -> bool:
198+
for field in self.data_model_definition.fields.values():
199+
if not isinstance(field, VectorStoreRecordVectorField) and search_text in record.get(field.name, ""):
200+
return True
201+
return False
202+
203+
def _calculate_vector_similarity(
204+
self,
205+
search_vector: list[float | int],
206+
record_vector: list[float | int],
207+
distance_func: Callable,
208+
invert_score: bool = False,
209+
) -> float:
210+
calc = distance_func(record_vector, search_vector)
211+
if invert_score:
212+
return 1.0 - float(calc)
213+
return float(calc)
214+
215+
@staticmethod
216+
def _apply_filter(record: dict[str, Any], filter: FilterClauseBase) -> bool:
217+
match filter:
218+
case EqualTo():
219+
value = record.get(filter.field_name)
220+
if not value:
221+
return False
222+
return value.lower() == filter.value.lower()
223+
case AnyTagsEqualTo():
224+
tag_list = record.get(filter.field_name)
225+
if not tag_list:
226+
return False
227+
if not isinstance(tag_list, list):
228+
tag_list = [tag_list]
229+
return filter.value in tag_list
230+
case _:
231+
return True
232+
233+
def _get_record_from_result(self, result: Any) -> Any:
234+
return result
235+
236+
def _get_score_from_result(self, result: Any) -> float | None:
    """Read the search score from a raw search result.

    Returns:
        The value stored under ``IN_MEMORY_SCORE_KEY``, or None when the result
        carries no score.
    """
    score = result.get(IN_MEMORY_SCORE_KEY)
    return score

python/semantic_kernel/data/filter_clauses/any_tags_equal_to_filter_clause.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
from pydantic import Field
55

66
from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
7-
from semantic_kernel.kernel_pydantic import KernelBaseModel
87
from semantic_kernel.utils.experimental_decorator import experimental_class
98

109

1110
@experimental_class
12-
class AnyTagsEqualTo(FilterClauseBase, KernelBaseModel):
13-
"""A filter clause for a any tags equals comparison."""
11+
class AnyTagsEqualTo(FilterClauseBase):
12+
"""A filter clause for a any tags equals comparison.
1413
15-
filter_clause_type: str = Field("any_tags_equal_to", init=False) # type: ignore
14+
Args:
15+
field_name: The name of the field containing the list of tags.
16+
value: The value to compare against the list of tags.
17+
"""
1618

17-
field_name: str
18-
value: str
19+
filter_clause_type: str = Field("any_tags_equal_to", init=False) # type: ignore

python/semantic_kernel/data/filter_clauses/equal_to_filter_clause.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
from pydantic import Field
44

55
from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
6-
from semantic_kernel.kernel_pydantic import KernelBaseModel
76
from semantic_kernel.utils.experimental_decorator import experimental_class
87

98

109
@experimental_class
11-
class EqualTo(FilterClauseBase, KernelBaseModel):
12-
"""A filter clause for an equals comparison."""
10+
class EqualTo(FilterClauseBase):
11+
"""A filter clause for an equals comparison.
1312
14-
filter_clause_type: str = Field("equal_to", init=False) # type: ignore
13+
Args:
14+
field_name: The name of the field to compare.
15+
value: The value to compare against the field.
16+
17+
"""
1518

16-
field_name: str
17-
value: str
19+
filter_clause_type: str = Field("equal_to", init=False) # type: ignore

python/semantic_kernel/data/filter_clauses/filter_clause_base.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
from abc import ABC
5+
from typing import Any
56

67
from pydantic import Field
78

@@ -14,3 +15,5 @@ class FilterClauseBase(ABC, KernelBaseModel):
1415
"""A base for all filter clauses."""
1516

1617
filter_clause_type: str = Field("FilterClauseBase", init=False) # type: ignore
18+
field_name: str
19+
value: Any

python/semantic_kernel/search/__init__.py

Whitespace-only changes.

python/semantic_kernel/search/const.py

-11
This file was deleted.

0 commit comments

Comments
 (0)