Skip to content

Commit 92cb1dc

Browse files
committed
Trying to make pulling indices
1 parent 4dd655f commit 92cb1dc

File tree

4 files changed

+40
-9
lines changed

4 files changed

+40
-9
lines changed

app.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@
2525
ids = load('models/ids.joblib')
2626
qa_model = QA('models')
2727

28-
@strawberry.type
29-
class User:
30-
name: str
31-
age: int
32-
3328

3429
@strawberry.type
3530
class Document:
@@ -54,6 +49,17 @@ class QAResult:
5449
answer: str
5550
confidence : float
5651

52+
53+
@strawberry.type
54+
class MetaData:
55+
tf_idf_len_diff: int
56+
bert_len_diff: int
57+
58+
@strawberry.type
59+
class IndexingResult:
60+
status: str
61+
metadata: MetaData
62+
5763
@strawberry.type
5864
class Query:
5965
@strawberry.field
@@ -80,4 +86,27 @@ def qa(self, q: str) -> QAResult:
8086
answer = qa_model.predict(" ".join(reference),q)
8187
return QAResult(answer = answer['answer'], confidence = answer['confidence'], )
8288

89+
@strawberry.field
90+
def pull_updates_from_index_cloud(self) -> IndexingResult:
91+
global tfidf_faiss, bert_faiss, ids, qa_model
92+
tf_idf_prev_len = tfidf_faiss.ntotal
93+
bert_prev_len = bert_faiss.ntotal
94+
95+
print("Previous TFIDF length: {}".format(tf_idf_prev_len))
96+
pull_indices(True)
97+
98+
tfidf_faiss, bert_faiss = load_faiss(tfidf_model, bert_model)
99+
ids = load('models/ids.joblib')
100+
qa_model = QA('models')
101+
102+
metadata = {'tfidf_len_diff': tfidf_faiss.ntotal - tf_idf_prev_len, 'bert_len_diff': bert_faiss.ntotal - bert_prev_len}
103+
return IndexingResult(status = "Success", metadata = metadata)
104+
105+
# @strawberry.field
106+
# def reindex(self) -> IndexingResult:
107+
# D, I = vector_search(q, bert_model, bert_faiss)
108+
# reference = [x["description"] for x in collection.find({'_id': {'$in': (np.array(ids)[I[0][:2]]).tolist()}})]
109+
# answer = qa_model.predict(" ".join(reference),q)
110+
# return QAResult(answer = answer['answer'], confidence = answer['confidence'], )
111+
83112
schema = strawberry.Schema(query=Query)

build.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def load_bert_model():
135135
def build_faiss(tfidf_model, bert_model):
136136
tr = tracker.SummaryTracker()
137137
print(f"Building indices ...")
138-
c = collection.find().count()
138+
# c = collection.find().count()
139+
c = 1000
139140
batch_size = 500
140141
encoder = None
141142
bert_index = None
@@ -182,7 +183,7 @@ def build_faiss(tfidf_model, bert_model):
182183
faiss.write_index(tfidf_index,f"models/tfidf.index")
183184
dump(ids,'models/ids.joblib')
184185
print(f"Completed indices.")
185-
upload_indices_and_vectors()
186+
# upload_indices_and_vectors()
186187
return [tfidf_index, bert_index]
187188

188189
def load_faiss(tfidf_model, bert_model):

cloud_storage.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def download_blob(bucket_name, source_blob_name, destination_file_name):
7676
)
7777
)
7878

79-
def pull_indices():
79+
def pull_indices(run_manually=False):
8080
# Checks if PULL_INDS environment variable is present, and calls pull function
81-
if os.environ.get('PULL_INDS') != None:
81+
if run_manually or os.environ.get('PULL_INDS') != None:
8282
download_blob("symptomizer_indices_bucket-1", "tfidf.index", "models/tfidf.index")
8383
download_blob("symptomizer_indices_bucket-1", "bert.index", "models/bert.index")
8484
download_blob("symptomizer_indices_bucket-1", "ids.joblib", "models/ids.joblib")

test.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bye fsh

0 commit comments

Comments
 (0)