-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
86 lines (65 loc) · 3.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
# Location of the saved FAISS index; passed to FAISS.load_local in get_vectorstore().
DB_FAISS_PATH = "vectorstore/db_faiss"
@st.cache_resource
def get_vectorstore():
    """Open the FAISS index saved at ``DB_FAISS_PATH``.

    The index is loaded with the ``sentence-transformers/all-MiniLM-L6-v2``
    embedding model. The function is decorated with ``st.cache_resource``.

    Returns:
        Whatever ``FAISS.load_local`` returns for the stored index.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )
    # NOTE(review): allow_dangerous_deserialization=True explicitly opts in to
    # unsafe deserialization of the on-disk index. Only point DB_FAISS_PATH at
    # an index produced by a trusted process.
    return FAISS.load_local(
        DB_FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )
def set_custom_prompt(custom_prompt_template):
    """Wrap a template string in a ``PromptTemplate``.

    Args:
        custom_prompt_template: Template text for the prompt. The prompt is
            declared with ``context`` and ``question`` as input variables.

    Returns:
        The constructed ``PromptTemplate``.
    """
    variables = ["context", "question"]
    prompt_template = PromptTemplate(
        template=custom_prompt_template,
        input_variables=variables,
    )
    return prompt_template
def load_llm(huggingface_repo_id, HF_TOKEN):
    """Build a ``HuggingFaceEndpoint`` client for a hosted model.

    Args:
        huggingface_repo_id: Hugging Face Hub repository id of the model.
        HF_TOKEN: Hugging Face API token used to authenticate requests.
            May be ``None``.
            NOTE(review): presumably the endpoint then falls back to its own
            environment lookup -- confirm against the installed version.

    Returns:
        The configured ``HuggingFaceEndpoint`` instance.
    """
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        temperature=0.5,
        # The token is a credential, not a generation parameter: pass it
        # through the dedicated field instead of model_kwargs, whose entries
        # are treated as extra model/generation arguments.
        huggingfacehub_api_token=HF_TOKEN,
        # Cap the generated length with the endpoint's own integer field
        # rather than a string-valued "max_length" entry in model_kwargs.
        max_new_tokens=512,
    )
def main():
    """Render the chat page and answer the submitted prompt, if any.

    Draws the title, replays the messages stored in ``st.session_state``,
    and reads a new prompt from the chat input. When a prompt is present it
    is echoed and stored, a ``RetrievalQA`` chain is built over the vector
    store, and the answer plus its source documents are displayed. Any
    exception raised while answering is shown with ``st.error``.
    """
    st.title("Medical Chatbot !! Made by Iqbal ")

    # Chat history is kept in session state and redrawn before new input.
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    for past in st.session_state.messages:
        st.chat_message(past['role']).markdown(past['content'])

    prompt = st.chat_input("Pass your prompt here")
    if not prompt:
        # Nothing submitted: there is nothing further to render.
        return

    st.chat_message('user').markdown(prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Same text as a triple-quoted literal, spelled out line by line.
    template_text = (
        "\n"
        "Use the pieces of information provided in the context to answer the user's question.\n"
        "If you don't know the answer, just say that you don't know. Don't try to make up an answer.\n"
        "Don't provide anything out of the given context.\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Start the answer directly. No small talk, please.\n"
    )
    repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
    token = os.environ.get("HF_TOKEN")

    try:
        store = get_vectorstore()
        if store is None:
            st.error("Failed to load the vector store")
            return

        # Build the chain pieces in the order the chain constructor uses them.
        llm = load_llm(huggingface_repo_id=repo_id, HF_TOKEN=token)
        retriever = store.as_retriever(search_kwargs={'k': 3})
        qa_prompt = set_custom_prompt(template_text)
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={'prompt': qa_prompt},
        )

        outcome = chain.invoke({'query': prompt})
        answer = outcome["result"]
        sources = outcome["source_documents"]

        st.chat_message('assistant').markdown(answer)
        st.session_state.messages.append(
            {'role': 'assistant', 'content': answer}
        )

        # Sources are displayed for this answer only; they are not stored
        # in the message history.
        st.subheader("Source Documents")
        st.write(sources)
    except Exception as exc:
        st.error("Error: " + str(exc))
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()