Skip to content

Commit 734ae05

Browse files
committed
Adding pytorch model to cloud and other minor changes
1 parent 00e8881 commit 734ae05

File tree

6 files changed

+28
-18
lines changed

6 files changed

+28
-18
lines changed

.gitignore

+3
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,9 @@ Icon
1414
# Keys
1515
keyfile.json
1616

17+
# Pytorch Model
18+
models/*.bin
19+
1720
# Thumbnails
1821
._*
1922

app.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from search import *
44
from build import *
55
from utils import docs2text, id2details
6-
from cloud_storage import test_file_exists, download_blob, upload_blob, pull_indices
6+
from cloud_storage import test_file_exists, download_blob, upload_blob, pull_indices, download_pytorch_model
77

88
# from deeppavlov import build_model
99
# dp_model = build_model('models/squad_torch_bert.json', download=True)
@@ -16,7 +16,7 @@
1616
# GCP test connections
1717
test_file_exists()
1818
download_blob("symptomizer_indices_bucket-1", "hello.txt", "test.txt")
19-
19+
download_pytorch_model()
2020
pull_indices()
2121

2222
bert_model = load_bert_model()

build.py

+8-7
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,12 @@
103103
# else:
104104
# yield collection.find(query,projection)[chunks[i-1]:chunks.stop]
105105

106+
def upload_indices_and_vectors():
107+
upload_blob("symptomizer_indices_bucket-1", "models/bert.index", "bert.index")
108+
upload_blob("symptomizer_indices_bucket-1", "models/tfidf.index", "tfidf.index")
109+
upload_blob("symptomizer_indices_bucket-1", "models/ids.joblib", "ids.joblib")
110+
upload_blob("symptomizer_indices_bucket-1", "models/tfidf_model.joblib", "tfidf_model.joblib")
111+
print("Completed Uploading indices to bucket")
106112

107113
def build_tfidf_model(num_docs=2000, max_features=1500):
108114
print("Building TF-IDF model...")
@@ -125,16 +131,11 @@ def load_bert_model():
125131
print("Completed BERT model.")
126132
return model
127133

128-
def upload_faiss():
129-
upload_blob("symptomizer_indices_bucket-1", "models/bert.index", "bert.index")
130-
upload_blob("symptomizer_indices_bucket-1", "models/tfidf.index", "tfidf.index")
131-
upload_blob("symptomizer_indices_bucket-1", "models/ids.joblib", "ids.joblib")
132-
print("Completed Uploading indices to bucket")
133134

134135
def build_faiss(tfidf_model, bert_model):
135136
tr = tracker.SummaryTracker()
136137
print(f"Building indices ...")
137-
c = collection.find().count()
138+
c = 1500
138139
batch_size = 500
139140
encoder = None
140141
bert_index = None
@@ -181,7 +182,7 @@ def build_faiss(tfidf_model, bert_model):
181182
faiss.write_index(tfidf_index,f"models/tfidf.index")
182183
dump(ids,'models/ids.joblib')
183184
print(f"Completed indices.")
184-
upload_faiss()
185+
upload_indices_and_vectors()
185186
return [tfidf_index, bert_index]
186187

187188
def load_faiss(tfidf_model, bert_model):

cloud_storage.py

+13-5
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,16 @@
22
import os.path
33
import os
44
import json
5+
import wget
6+
7+
def download_pytorch_model():
8+
if(not os.path.exists("models/pytorch_model.bin")):
9+
print("Pytorch QA model not detected. Downloading...")
10+
url = 'https://storage.googleapis.com/symptomizer_model_bucket/pytorch_model.bin'
11+
wget.download(url, 'models/pytorch_model.bin')
12+
print("Download Complete.")
13+
else:
14+
print("Pytorch model exists. Skipping download.")
515

616
def fix_json():
717
with open('keyfile.json') as json_file:
@@ -11,11 +21,7 @@ def fix_json():
1121

1222

1323
def test_file_exists():
14-
print("Key File Exists")
15-
print(os.path.isfile("keyfile.json"))
16-
1724
f = open("keyfile.json", "r")
18-
print(f.read())
1925
fix_json()
2026

2127
def check_if_exists(bucket_name, file_name):
@@ -42,7 +48,7 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
4248
)
4349
)
4450

45-
# upload_blob("symptomizer_indices_bucket-1", "hello.txt", "hello.txt")
51+
# upload_blob("symptomizer_model_bucket", "models/pytorch_model.bin", "pytorch_model.bin")
4652
# print(check_if_exists("symptomizer_indices_bucket-1", "hello.txt"))
4753
# download_blob("symptomizer_indices_bucket-1", "hello.txt", "downloaded.txt")
4854

@@ -76,5 +82,7 @@ def pull_indices():
7682
download_blob("symptomizer_indices_bucket-1", "tfidf.index", "models/tfidf.index")
7783
download_blob("symptomizer_indices_bucket-1", "bert.index", "models/bert.index")
7884
download_blob("symptomizer_indices_bucket-1", "ids.joblib", "models/ids.joblib")
85+
download_blob("symptomizer_indices_bucket-1", "tfidf_model.joblib", "models/tfidf_model.joblib")
7986
else:
8087
print("No PULL_INDS env found. Rebuilding indices")
88+

models/pytorch_model.bin

-3
This file was deleted.

requirements.txt

+2-1
Original file line number | Diff line number | Diff line change
@@ -13,4 +13,5 @@ sentence-transformers
1313
pytorch_transformers
1414
dask[complete]
1515
Pympler
16-
google-cloud-storage
16+
google-cloud-storage
17+
wget

0 commit comments

Comments
 (0)