
Commit fa75bab

better batch processing
1 parent 9b9b5e8 commit fa75bab

File tree

1 file changed (+24 -29 lines)


Diff for: server.py (+24 -29)
@@ -11,7 +11,7 @@

 import openai

-from typing import Dict, Union, Optional
+from typing import Dict, List, Union, Optional
 from collections import OrderedDict
 from flask import Flask, request, jsonify, abort
 from sentence_transformers import SentenceTransformer

@@ -85,16 +85,26 @@ def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str]

         if openai_key is not None:
             openai.api_key = self.openai_key
-            logger.info('enabled model: text-embedding-ada-002')
+            try:
+                openai.Model.list()
+                logger.info('enabled model: text-embedding-ada-002')
+            except Exception as err:
+                logger.error(f'Failed to connect to OpenAI API; disabling OpenAI model: {err}')

-    def generate(self, text: str, model_type: str) -> Dict[str, Union[str, float, list]]:
+    def generate(self, text_batch: List[str], model_type: str) -> Dict[str, Union[str, float, list]]:
         start_time = time.time()
-        result = {'status': 'success'}
+        result = {
+            'status': 'success',
+            'message': '',
+            'model': '',
+            'elapsed': 0,
+            'embeddings': []
+        }

         if model_type == 'openai':
             try:
-                response = openai.Embedding.create(input=text, model='text-embedding-ada-002')
-                result['embedding'] = response['data'][0]['embedding']
+                response = openai.Embedding.create(input=text_batch, model='text-embedding-ada-002')
+                result['embeddings'] = [data['embedding'] for data in response['data']]
                 result['model'] = 'text-embedding-ada-002'
             except Exception as err:
                 logger.error(f'Failed to get OpenAI embeddings: {err}')

@@ -103,8 +113,8 @@ def generate(self, text: str, model_type: str) -> Dict[str, Union[str, float, list]]:

         else:
             try:
-                embedding = self.model.encode(text).tolist()
-                result['embedding'] = embedding
+                embedding = self.model.encode(text_batch, batch_size=len(text_batch), device='cuda').tolist()
+                result['embeddings'] = embedding
                 result['model'] = self.sbert_model
             except Exception as err:
                 logger.error(f'Failed to get sentence-transformers embeddings: {err}')

@@ -145,33 +155,18 @@ def submit_text():
     if text_data is None:
         abort(400, 'Missing text data to embed')

-    if model_type not in ['local', 'openai']:
-        abort(400, 'model field must be one of: local, openai')
-
-    if isinstance(text_data, str):
-        text_data = [text_data]
-
     if not all(isinstance(text, str) for text in text_data):
         abort(400, 'all data must be text strings')

     results = []
-    for text in text_data:
-        result = None
-
-        if embedding_cache:
-            result = embedding_cache.get(text, model_type)
-            if result:
-                logger.info('found embedding in cache!')
-                result = {'embedding': result, 'cache': True, "status": 'success'}
-
-        if result is None:
-            result = embedding_generator.generate(text, model_type)
+    result = embedding_generator.generate(text_data, model_type)

-        if embedding_cache and result['status'] == 'success':
-            embedding_cache.set(text, model_type, result['embedding'])
-            logger.info('added to cache')
+    if embedding_cache and result['status'] == 'success':
+        for text, embedding in zip(text_data, result['embeddings']):
+            embedding_cache.set(text, model_type, embedding)
+            logger.info('added to cache')

-        results.append(result)
+    results.append(result)

     return jsonify(results)
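For reference, a minimal client-side sketch of how the batched endpoint might be called after this change. The route path, port, and JSON field names below are assumptions (the diff does not show how the Flask handler parses the request); the key point is that the request now carries a list of strings and the response is a list holding one result dict whose 'embeddings' field has one vector per input text.

import requests

# Hypothetical route and field names -- adjust to the real API.
payload = {
    'text': ['first document to embed', 'second document to embed'],  # a batch, not a single string
    'model': 'local',  # or 'openai'
}

resp = requests.post('http://localhost:5000/submit', json=payload, timeout=30)
resp.raise_for_status()

# One result dict covers the whole batch; embeddings come back in request order.
for result in resp.json():
    print(result['model'], len(result['embeddings']))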