1
1
# Copyright (c) Microsoft. All rights reserved.
2
2
3
3
import sys
4
- from collections .abc import Mapping , Sequence
4
+ from collections .abc import AsyncIterable , Callable , Mapping , Sequence
5
5
from typing import Any , ClassVar , TypeVar
6
6
7
+ from pydantic import Field
8
+
9
+ from semantic_kernel .data .filter_clauses .any_tags_equal_to_filter_clause import AnyTagsEqualTo
10
+ from semantic_kernel .data .filter_clauses .equal_to_filter_clause import EqualTo
11
+
7
12
if sys .version_info >= (3 , 12 ):
8
13
from typing import override # pragma: no cover
9
14
else :
10
15
from typing_extensions import override # pragma: no cover
11
16
12
- from pydantic import Field
13
-
17
+ from semantic_kernel .connectors .memory .in_memory .const import DISTANCE_FUNCTION_MAP
18
+ from semantic_kernel .data .const import DistanceFunction
19
+ from semantic_kernel .data .filter_clauses .filter_clause_base import FilterClauseBase
20
+ from semantic_kernel .data .kernel_search_results import KernelSearchResults
14
21
from semantic_kernel .data .record_definition .vector_store_model_definition import VectorStoreRecordDefinition
15
- from semantic_kernel .data .vector_storage .vector_store_record_collection import VectorStoreRecordCollection
22
+ from semantic_kernel .data .record_definition .vector_store_record_fields import (
23
+ VectorStoreRecordVectorField ,
24
+ )
25
+ from semantic_kernel .data .vector_search .vector_search import VectorSearchBase
26
+ from semantic_kernel .data .vector_search .vector_search_options import VectorSearchOptions
27
+ from semantic_kernel .data .vector_search .vector_search_result import VectorSearchResult
28
+ from semantic_kernel .data .vector_search .vector_text_search import VectorTextSearchMixin
29
+ from semantic_kernel .data .vector_search .vectorized_search import VectorizedSearchMixin
30
+ from semantic_kernel .exceptions import VectorSearchExecutionException , VectorStoreModelValidationError
16
31
from semantic_kernel .kernel_types import OneOrMany
17
32
18
33
KEY_TYPES = str | int | float
19
34
20
35
TModel = TypeVar ("TModel" )
21
36
37
+ IN_MEMORY_SCORE_KEY = "in_memory_search_score"
38
+
22
39
23
- class InMemoryVectorCollection (VectorStoreRecordCollection [KEY_TYPES , TModel ]):
40
+ class InMemoryVectorCollection (
41
+ VectorSearchBase [KEY_TYPES , TModel ], VectorTextSearchMixin [TModel ], VectorizedSearchMixin [TModel ]
42
+ ):
24
43
"""In Memory Collection."""
25
44
26
45
inner_storage : dict [KEY_TYPES , dict ] = Field (default_factory = dict )
@@ -39,6 +58,12 @@ def __init__(
39
58
collection_name = collection_name ,
40
59
)
41
60
61
+ def _validate_data_model (self ):
62
+ """Check if the In Memory Score key is not used."""
63
+ super ()._validate_data_model ()
64
+ if IN_MEMORY_SCORE_KEY in self .data_model_definition .field_names :
65
+ raise VectorStoreModelValidationError (f"Field name '{ IN_MEMORY_SCORE_KEY } ' is reserved for internal use." )
66
+
42
67
@override
43
68
async def _inner_delete (self , keys : Sequence [KEY_TYPES ], ** kwargs : Any ) -> None :
44
69
for key in keys :
@@ -74,3 +99,139 @@ async def delete_collection(self, **kwargs: Any) -> None:
74
99
@override
75
100
async def does_collection_exist (self , ** kwargs : Any ) -> bool :
76
101
return True
102
+
103
+ @override
104
+ async def _inner_search (
105
+ self ,
106
+ options : VectorSearchOptions | None = None ,
107
+ search_text : str | None = None ,
108
+ vectorizable_text : str | None = None ,
109
+ vector : list [float | int ] | None = None ,
110
+ ** kwargs : Any ,
111
+ ) -> KernelSearchResults [VectorSearchResult [TModel ]]:
112
+ """Inner search method."""
113
+ if search_text :
114
+ return await self ._inner_search_text (search_text , options , ** kwargs )
115
+ if vector :
116
+ if not options :
117
+ raise VectorSearchExecutionException ("Options must be provided for vector search." )
118
+ return await self ._inner_search_vectorized (vector , options , ** kwargs )
119
+ raise VectorSearchExecutionException ("Search text or vector must be provided." )
120
+
121
+ async def _inner_search_text (
122
+ self ,
123
+ search_text : str ,
124
+ options : VectorSearchOptions | None = None ,
125
+ ** kwargs : Any ,
126
+ ) -> KernelSearchResults [VectorSearchResult [TModel ]]:
127
+ """Inner search method."""
128
+ return_records : dict [KEY_TYPES , float ] = {}
129
+ for key , record in self ._get_filtered_records (options ).items ():
130
+ if self ._should_add_text_search (search_text , record ):
131
+ return_records [key ] = 1.0
132
+ if return_records :
133
+ return KernelSearchResults (
134
+ results = self ._get_vector_search_results_from_results (
135
+ self ._generate_return_list (return_records , options )
136
+ ),
137
+ total_count = len (return_records ) if options and options .include_total_count else None ,
138
+ )
139
+ return KernelSearchResults (results = None )
140
+
141
+ async def _inner_search_vectorized (
142
+ self ,
143
+ vector : list [float | int ],
144
+ options : VectorSearchOptions ,
145
+ ** kwargs : Any ,
146
+ ) -> KernelSearchResults [VectorSearchResult [TModel ]]:
147
+ return_records : dict [KEY_TYPES , float ] = {}
148
+ if not options .vector_field_name :
149
+ raise ValueError ("Vector field name must be provided in options for vector search." )
150
+ field = options .vector_field_name
151
+ assert isinstance (self .data_model_definition .fields .get (field ), VectorStoreRecordVectorField ) # nosec
152
+ distance_metric = self .data_model_definition .fields .get (field ).distance_function or "default" # type: ignore
153
+ distance_func = DISTANCE_FUNCTION_MAP [distance_metric ]
154
+
155
+ for key , record in self ._get_filtered_records (options ).items ():
156
+ if vector and field is not None :
157
+ return_records [key ] = self ._calculate_vector_similarity (
158
+ vector ,
159
+ record [field ],
160
+ distance_func ,
161
+ invert_score = distance_metric == DistanceFunction .COSINE_SIMILARITY ,
162
+ )
163
+ if distance_metric in [DistanceFunction .COSINE_SIMILARITY , DistanceFunction .DOT_PROD ]:
164
+ sorted_records = dict (sorted (return_records .items (), key = lambda item : item [1 ], reverse = True ))
165
+ else :
166
+ sorted_records = dict (sorted (return_records .items (), key = lambda item : item [1 ]))
167
+ if sorted_records :
168
+ return KernelSearchResults (
169
+ results = self ._get_vector_search_results_from_results (
170
+ self ._generate_return_list (sorted_records , options )
171
+ ),
172
+ total_count = len (return_records ) if options and options .include_total_count else None ,
173
+ )
174
+ return KernelSearchResults (results = None )
175
+
176
+ async def _generate_return_list (
177
+ self , return_records : dict [KEY_TYPES , float ], options : VectorSearchOptions | None
178
+ ) -> AsyncIterable [dict ]:
179
+ top = 3 if not options else options .top
180
+ skip = 0 if not options else options .skip
181
+ returned = 0
182
+ for idx , key in enumerate (return_records .keys ()):
183
+ if idx >= skip :
184
+ returned += 1
185
+ rec = self .inner_storage [key ]
186
+ rec [IN_MEMORY_SCORE_KEY ] = return_records [key ]
187
+ yield rec
188
+ if returned >= top :
189
+ break
190
+
191
+ def _get_filtered_records (self , options : VectorSearchOptions | None ) -> dict [KEY_TYPES , dict ]:
192
+ if options and options .filter :
193
+ for filter in options .filter .filters :
194
+ return {key : record for key , record in self .inner_storage .items () if self ._apply_filter (record , filter )}
195
+ return self .inner_storage
196
+
197
+ def _should_add_text_search (self , search_text : str , record : dict ) -> bool :
198
+ for field in self .data_model_definition .fields .values ():
199
+ if not isinstance (field , VectorStoreRecordVectorField ) and search_text in record .get (field .name , "" ):
200
+ return True
201
+ return False
202
+
203
+ def _calculate_vector_similarity (
204
+ self ,
205
+ search_vector : list [float | int ],
206
+ record_vector : list [float | int ],
207
+ distance_func : Callable ,
208
+ invert_score : bool = False ,
209
+ ) -> float :
210
+ calc = distance_func (record_vector , search_vector )
211
+ if invert_score :
212
+ return 1.0 - float (calc )
213
+ return float (calc )
214
+
215
+ @staticmethod
216
+ def _apply_filter (record : dict [str , Any ], filter : FilterClauseBase ) -> bool :
217
+ match filter :
218
+ case EqualTo ():
219
+ value = record .get (filter .field_name )
220
+ if not value :
221
+ return False
222
+ return value .lower () == filter .value .lower ()
223
+ case AnyTagsEqualTo ():
224
+ tag_list = record .get (filter .field_name )
225
+ if not tag_list :
226
+ return False
227
+ if not isinstance (tag_list , list ):
228
+ tag_list = [tag_list ]
229
+ return filter .value in tag_list
230
+ case _:
231
+ return True
232
+
233
+ def _get_record_from_result (self , result : Any ) -> Any :
234
+ return result
235
+
236
+ def _get_score_from_result (self , result : Any ) -> float | None :
237
+ return result .get (IN_MEMORY_SCORE_KEY )
0 commit comments