11
11
12
12
import openai
13
13
14
- from typing import Dict , Union , Optional
14
+ from typing import Dict , List , Union , Optional
15
15
from collections import OrderedDict
16
16
from flask import Flask , request , jsonify , abort
17
17
from sentence_transformers import SentenceTransformer
@@ -85,16 +85,26 @@ def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str]
85
85
86
86
if openai_key is not None :
87
87
openai .api_key = self .openai_key
88
- logger .info ('enabled model: text-embedding-ada-002' )
88
+ try :
89
+ openai .Model .list ()
90
+ logger .info ('enabled model: text-embedding-ada-002' )
91
+ except Exception as err :
92
+ logger .error (f'Failed to connect to OpenAI API; disabling OpenAI model: { err } ' )
89
93
90
- def generate (self , text : str , model_type : str ) -> Dict [str , Union [str , float , list ]]:
94
+ def generate (self , text_batch : List [ str ] , model_type : str ) -> Dict [str , Union [str , float , list ]]:
91
95
start_time = time .time ()
92
- result = {'status' : 'success' }
96
+ result = {
97
+ 'status' : 'success' ,
98
+ 'message' : '' ,
99
+ 'model' : '' ,
100
+ 'elapsed' : 0 ,
101
+ 'embeddings' : []
102
+ }
93
103
94
104
if model_type == 'openai' :
95
105
try :
96
- response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
97
- result ['embedding ' ] = response [ ' data' ][ 0 ][ 'embedding' ]
106
+ response = openai .Embedding .create (input = text_batch , model = 'text-embedding-ada-002' )
107
+ result ['embeddings ' ] = [ data [ 'embedding' ] for data in response [ 'data' ] ]
98
108
result ['model' ] = 'text-embedding-ada-002'
99
109
except Exception as err :
100
110
logger .error (f'Failed to get OpenAI embeddings: { err } ' )
@@ -103,8 +113,8 @@ def generate(self, text: str, model_type: str) -> Dict[str, Union[str, float, li
103
113
104
114
else :
105
115
try :
106
- embedding = self .model .encode (text ).tolist ()
107
- result ['embedding ' ] = embedding
116
+ embedding = self .model .encode (text_batch , batch_size = len ( text_batch ), device = 'cuda' ).tolist ()
117
+ result ['embeddings ' ] = embedding
108
118
result ['model' ] = self .sbert_model
109
119
except Exception as err :
110
120
logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
@@ -145,33 +155,18 @@ def submit_text():
145
155
if text_data is None :
146
156
abort (400 , 'Missing text data to embed' )
147
157
148
- if model_type not in ['local' , 'openai' ]:
149
- abort (400 , 'model field must be one of: local, openai' )
150
-
151
- if isinstance (text_data , str ):
152
- text_data = [text_data ]
153
-
154
158
if not all (isinstance (text , str ) for text in text_data ):
155
159
abort (400 , 'all data must be text strings' )
156
160
157
161
results = []
158
- for text in text_data :
159
- result = None
160
-
161
- if embedding_cache :
162
- result = embedding_cache .get (text , model_type )
163
- if result :
164
- logger .info ('found embedding in cache!' )
165
- result = {'embedding' : result , 'cache' : True , "status" : 'success' }
166
-
167
- if result is None :
168
- result = embedding_generator .generate (text , model_type )
162
+ result = embedding_generator .generate (text_data , model_type )
169
163
170
- if embedding_cache and result ['status' ] == 'success' :
171
- embedding_cache .set (text , model_type , result ['embedding' ])
172
- logger .info ('added to cache' )
164
+ if embedding_cache and result ['status' ] == 'success' :
165
+ for text , embedding in zip (text_data , result ['embeddings' ]):
166
+ embedding_cache .set (text , model_type , embedding )
167
+ logger .info ('added to cache' )
173
168
174
- results .append (result )
169
+ results .append (result )
175
170
176
171
return jsonify (results )
177
172
0 commit comments