-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
265 lines (220 loc) · 8.34 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# -*- coding: utf-8 -*-
from fastapi import FastAPI, responses, status, Request
from pydantic import BaseModel
from typing import List, Optional, Dict, Any, Union
from neural_search import util
from submodules.model.business_objects import (
general,
playground_question as playground_question_db_bo,
)
from submodules.model import session
# The ASGI application instance. Every middleware and route defined below is
# registered on this object through its decorators.
app = FastAPI()
@app.middleware("http")
async def handle_db_session(request: Request, call_next):
    """Scope a database session context to a single HTTP request.

    A context token is obtained from ``general.get_ctx_token()`` before the
    request is dispatched. ``general.remove_and_refresh_session`` is always
    called with that token afterwards, including when the downstream
    handler raises; in that case the exception still propagates.
    """
    ctx_token = general.get_ctx_token()
    try:
        return await call_next(request)
    finally:
        # Executed on both the success path and the error path.
        general.remove_and_refresh_session(ctx_token)
@app.post("/most_similar")
def most_similar(
    project_id: str,
    embedding_id: str,
    record_id: str,
    limit: int = 100,
    att_filter: Optional[List[Dict[str, Any]]] = None,
    record_sub_key: Optional[int] = None,
) -> responses.JSONResponse:
    """Find the records most similar to a given record.

    Args:
        project_id (str): Project id.
        embedding_id (str): Embedding id.
        record_id (str): The record for which similar records are searched.
        limit (int): Maximum number of returned records (default 100).
        att_filter (Optional[List[Dict[str, Any]]]): Attribute filter for the
            search, given as a list of dict objects, e.g.
            [
                {"key": "name", "value": ["John", "Doe"]},  -> name IN ("John", "Doe")
                {"key": "age", "value": 42},                 -> age = 42
                {"key": "age", "value": [35, 40]},           -> age IN (35, 40)
                {"key": "age", "value": [35, 40], "type": "between"}
                    -> age BETWEEN 35 AND 40 (includes 35 and 40)
            ]
        record_sub_key (Optional[int]): Forwarded unchanged to
            ``util.most_similar``. Presumably it selects a sub-entry of the
            record (TODO confirm).

    Returns:
        JSONResponse: HTTP 200 with the similar records as its content.
    """
    payload = util.most_similar(
        project_id,
        embedding_id,
        record_id,
        limit,
        att_filter,
        record_sub_key,
    )
    return responses.JSONResponse(content=payload, status_code=status.HTTP_200_OK)
class MostSimilarByEmbeddingRequest(BaseModel):
    # Request body for POST /most_similar_by_embedding.
    project_id: str
    embedding_id: str
    # Query vector that the search is run against.
    embedding_tensor: List[float]
    # Maximum number of returned records.
    limit: int = 5
    # Attribute filter as a list of dicts, e.g.
    # {"key": "age", "value": [35, 40], "type": "between"}.
    att_filter: Optional[List[Dict[str, Any]]] = None
    # None = calculated DB threshold, -9999 = no threshold, any other value is
    # used as given. The endpoint converts int values to float.
    threshold: Optional[Union[float, int]] = None
    # When non-empty, the endpoint stores it via playground_question_db_bo.create
    # after the search has run.
    question: Optional[str] = None
@app.post("/most_similar_by_embedding")
def most_similar_by_embedding(
    request: MostSimilarByEmbeddingRequest,
    include_scores: bool = False,
) -> responses.JSONResponse:
    """Find the records most similar to the embedding tensor in ``request``.

    Args:
        request (MostSimilarByEmbeddingRequest): Search parameters:
            project_id, embedding_id, embedding_tensor (the query vector),
            limit (maximum number of returned records), att_filter,
            threshold and an optional question.
            - att_filter: list of dict objects, e.g.
                {"key": "name", "value": ["John", "Doe"]}  -> name IN ("John", "Doe")
                {"key": "age", "value": 42}                -> age = 42
                {"key": "age", "value": [35, 40]}          -> age IN (35, 40)
                {"key": "age", "value": [35, 40], "type": "between"}
                    -> age BETWEEN 35 AND 40 (includes 35 and 40)
            - threshold: None = calculated DB threshold, -9999 = no threshold,
              any other value is used as given (ints are converted to float).
            - question: if non-empty, it is persisted with an immediate commit
              after the search has completed.
        include_scores (bool): Forwarded to ``util.most_similar_by_embedding``.
            Presumably it adds similarity scores to the result (TODO confirm).

    Returns:
        JSONResponse: HTTP 200 with the similar records as its content.
    """
    threshold = request.threshold
    if isinstance(threshold, int):
        # Integer thresholds (e.g. the -9999 sentinel) are handed on as float.
        threshold = float(threshold)

    matches = util.most_similar_by_embedding(
        request.project_id,
        request.embedding_id,
        request.embedding_tensor,
        request.limit,
        request.att_filter,
        threshold,
        include_scores,
    )

    if request.question:
        playground_question_db_bo.create(
            request.project_id, request.question, with_commit=True
        )

    return responses.JSONResponse(content=matches, status_code=status.HTTP_200_OK)
@app.post("/recreate_collection")
def recreate_collection(
    project_id: str, embedding_id: str
) -> responses.PlainTextResponse:
    """Create the collection in Qdrant for the given embedding.

    Args:
        project_id (str): Project id.
        embedding_id (str): Embedding id.

    Returns:
        PlainTextResponse: A response with no content, whose HTTP status code
        is the value returned by ``util.recreate_collection``.
    """
    return responses.PlainTextResponse(
        status_code=util.recreate_collection(project_id, embedding_id)
    )
@app.get("/collections")
def get_collections() -> responses.JSONResponse:
    """List the existing collections.

    Returns:
        JSONResponse: HTTP 200 whose content is the list of collection names
        returned by ``util.get_collections``.
    """
    return responses.JSONResponse(
        content=util.get_collections(), status_code=status.HTTP_200_OK
    )
@app.get("/collection/exist")
def collection_exists(
    project_id: str, embedding_id: str, include_db_check: bool = False
) -> responses.JSONResponse:
    """Check whether a collection exists in Qdrant, optionally also in the database.

    Args:
        project_id (str): Project id.
        embedding_id (str): Embedding id.
        include_db_check (bool): Forwarded to ``util.collection_exists`` to
            also check the database.

    Returns:
        JSONResponse: HTTP 200 with content ``{"exists": <bool>}``.
    """
    exists = util.collection_exists(project_id, embedding_id, include_db_check)
    body = {"exists": exists}
    return responses.JSONResponse(content=body, status_code=status.HTTP_200_OK)
@app.put("/create_missing_collections")
def create_missing_collections() -> responses.JSONResponse:
    """Create Qdrant collections for embeddings that do not have one yet.

    Returns:
        JSONResponse: Status code and content exactly as returned by
        ``util.create_missing_collections``.
    """
    code, body = util.create_missing_collections()
    return responses.JSONResponse(content=body, status_code=code)
class UpdateAttributePayloadsRequest(BaseModel):
    # Request body for POST /update_attribute_payloads.
    project_id: str
    embedding_id: str
    # Forwarded unchanged to util.update_attribute_payloads. None presumably
    # means "no restriction to specific records" (TODO confirm).
    record_ids: Optional[List[str]] = None
@app.post("/update_attribute_payloads")
def update_attribute_payloads(
    request: UpdateAttributePayloadsRequest,
) -> responses.PlainTextResponse:
    """Run ``util.update_attribute_payloads`` for one embedding.

    Args:
        request (UpdateAttributePayloadsRequest): Project id, embedding id
            and an optional list of record ids, all forwarded unchanged.

    Returns:
        PlainTextResponse: HTTP 200 on success, or HTTP 500 if the update
        raised any exception. The exception is logged, not re-raised.
    """
    import logging  # local import keeps this edit self-contained

    try:
        util.update_attribute_payloads(
            request.project_id,
            request.embedding_id,
            request.record_ids,
        )
    except Exception:
        # Callers rely on a plain 500 here, so the exception is not re-raised.
        # It is logged with its traceback instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "update_attribute_payloads failed (project_id=%s, embedding_id=%s)",
            request.project_id,
            request.embedding_id,
        )
        return responses.PlainTextResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
    return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
class UpdateLabelPayloadsRequest(BaseModel):
    # Request body for POST /update_label_payloads.
    project_id: str
    # Several embeddings can be updated in one call (unlike the attribute variant).
    embedding_ids: List[str]
    # Forwarded unchanged to util.update_label_payloads. None presumably means
    # "no restriction to specific records" (TODO confirm).
    record_ids: Optional[List[str]] = None
@app.post("/update_label_payloads")
def update_label_payloads(
    request: UpdateLabelPayloadsRequest,
) -> responses.PlainTextResponse:
    """Run ``util.update_label_payloads`` for one or more embeddings.

    Args:
        request (UpdateLabelPayloadsRequest): Project id, embedding ids and
            an optional list of record ids, all forwarded unchanged.

    Returns:
        PlainTextResponse: HTTP 200 once the helper returns. Exceptions from
        the helper are not handled here and propagate to the framework.
    """
    project_id, embedding_ids, record_ids = (
        request.project_id,
        request.embedding_ids,
        request.record_ids,
    )
    util.update_label_payloads(project_id, embedding_ids, record_ids)
    return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
@app.put("/delete_collection")
def delete_collection(embedding_id: str) -> responses.PlainTextResponse:
    """Delete the Qdrant collection for the given embedding.

    Args:
        embedding_id (str): Embedding id.

    Returns:
        PlainTextResponse: HTTP 200 once ``util.delete_collection`` returns.
    """
    util.delete_collection(embedding_id)
    ok_code = status.HTTP_200_OK
    return responses.PlainTextResponse(status_code=ok_code)
@app.get("/detect_outliers")
def detect_outliers(
    project_id: str, embedding_id: str, limit: int = 100
) -> responses.JSONResponse:
    """Detect outliers among unlabeled records relative to the labeled ones.

    The unlabeled record ids are returned sorted, starting with the most
    outlying.

    Args:
        project_id (str): Project id.
        embedding_id (str): Embedding id.
        limit (int): Maximum number of returned records (default 100).

    Returns:
        JSONResponse: Status code and content exactly as returned by
        ``util.detect_outliers``.
    """
    code, body = util.detect_outliers(project_id, embedding_id, limit)
    return responses.JSONResponse(content=body, status_code=code)
@app.get("/healthcheck")
def healthcheck() -> responses.PlainTextResponse:
    """Report service health based on a database connection test.

    Returns:
        PlainTextResponse: ``OK`` with HTTP 200 when
        ``general.test_database_connection()`` reports success. Otherwise
        ``database_error:<error>:`` with HTTP 500.
    """
    db_result = general.test_database_connection()
    if db_result.get("success"):
        return responses.PlainTextResponse("OK", status_code=status.HTTP_200_OK)

    error_name = db_result.get("error")
    return responses.PlainTextResponse(
        f"database_error:{error_name}:",
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
# Import-time side effect: this runs once, when the module is loaded. Judging
# by the name it starts a background session-cleanup thread in
# submodules.model.session (TODO confirm).
session.start_session_cleanup_thread()