-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
229 lines (191 loc) · 8.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# Standard library
import io
import json
import os
import uuid
from contextlib import asynccontextmanager
from typing import Optional

# Third-party
import boto3
import faiss
import fitz  # PyMuPDF for PDFs
import openai
import pymongo
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from langchain.docstore import InMemoryDocstore
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from pydantic import BaseModel
# Load environment variables from a local .env file (MONGO_URL, OPENAI_API_KEY, AWS creds).
load_dotenv()

# Initialize FastAPI.
# NOTE(review): this instance is replaced further down by
# `app = FastAPI(lifespan=lifespan)`, which discards the middleware added here.
app = FastAPI()

# Enable CORS Middleware — wide open (any origin/method/header); tighten for production.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)

# Connect to MongoDB: one shared client, "chat_with_doc" database.
client = pymongo.MongoClient(os.getenv("MONGO_URL"))
db = client["chat_with_doc"]
conversationcol = db["chat-history"]   # per-session conversation documents
feedback_col = db["feedback"]          # declared here; not used in this file's visible code

# AWS S3 Client (credentials/region resolved from the environment by boto3).
s3 = boto3.client("s3")

# OpenAI API Key for the chat-completion calls below.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Global FAISS database — populated by store_text_in_faiss()/load_faiss_index().
faiss_db = None
# Chat Message Model
class ChatMessage(BaseModel):
    """Request body for the /chat endpoint."""

    # BUG FIX: `str = None` is an implicit-Optional annotation; Pydantic v2
    # rejects an explicit null for a plain `str` field. Optional[str] makes
    # the "no session yet" case (a new UUID is generated in /chat) valid.
    session_id: Optional[str] = None
    # The user's question, embedded and matched against the FAISS store.
    user_input: str
    # Identifier of the data source the client is chatting against.
    data_source: str
# Function to Extract Text from PDF in S3
def extract_text_from_s3(bucket_name, s3_key):
    """Download a PDF from S3 and extract its text page by page.

    Args:
        bucket_name: Name of the S3 bucket that holds the PDF.
        s3_key: Object key of the PDF within the bucket.

    Returns:
        A list with one extracted-text string per page, or None on any
        download/parse failure (callers treat a falsy result as failure).
    """
    try:
        obj = s3.get_object(Bucket=bucket_name, Key=s3_key)
        print("✅ Successfully downloaded PDF from S3.")
        # `with` ensures the PyMuPDF document is closed (the original leaked it).
        with fitz.open(stream=io.BytesIO(obj["Body"].read()), filetype="pdf") as doc:
            text_data = [page.get_text("text") for page in doc]
        print(f"✅ Extracted {len(text_data)} pages of text from the PDF.")
        # BUG FIX: guard before indexing page 0 — a zero-page PDF raised IndexError here.
        if text_data:
            print("Sample text from the first page:", text_data[0][:500])  # Print first 500 characters
        return text_data
    except Exception as e:
        print(f"❌ PDF Processing Error: {e}")
        return None
# Function to Store Extracted Text in FAISS
def store_text_in_faiss(text_data):
    """Split text into chunks, embed them with OpenAI, and store in FAISS.

    Rebuilds the global `faiss_db` and persists it to disk.

    Args:
        text_data: Extracted PDF text — either a list of per-page strings
            or one pre-joined string.

    Returns:
        True on success, False if the input is empty or any step fails.
    """
    global faiss_db
    try:
        # Check if text_data is empty
        if not text_data:
            print("❌ No text extracted from the PDF. Check S3 and PDF contents.")
            return False
        # Print a sample of the extracted text
        print("✅ Extracted text sample:", text_data[:500])  # Print first 500 characters
        # Enhanced text splitter: markdown-aware separators first, then
        # paragraph/line/word fallbacks.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n# ", "\n## ", "\n### ", "\n\n", "\n", " ", ""],
            chunk_size=1000,   # Increased chunk size for better context
            chunk_overlap=200, # Increased overlap for better context
            length_function=len,
            is_separator_regex=False,
        )
        # Join all pages into one text if it's a list
        if isinstance(text_data, list):
            text_data = "\n\n".join(text_data)
        # Print how many characters we are processing
        print(f"Processing {len(text_data)} characters from the PDF")
        # Create chunks
        chunks = text_splitter.create_documents([text_data])
        print(f"✅ Created {len(chunks)} text chunks")
        # Use OpenAI embeddings
        embeddings = OpenAIEmbeddings()
        print(f"✅ Storing {len(chunks)} text chunks into FAISS")
        # Create FAISS index
        faiss_db = FAISS.from_documents(chunks, embeddings)
        # BUG FIX: persist the FULL store (vector index + docstore + id map).
        # `faiss.write_index` alone saved only the raw vectors, so a reload
        # could never map search hits back to the chunk text.
        faiss_db.save_local("faiss_index")
        # Keep writing the bare index file too, for backward compatibility
        # with the legacy load path.
        faiss.write_index(faiss_db.index, "faiss_index.bin")
        print("✅ FAISS storage successful!")
        return True
    except Exception as e:
        print(f"❌ FAISS Storage Error: {e}")
        return False
# Function to Load FAISS Index
def load_faiss_index():
    """Load the persisted FAISS store from disk into the global `faiss_db`.

    Prefers the full `save_local` directory ("faiss_index/": index +
    docstore + id mapping). Falls back to the legacy bare
    "faiss_index.bin" file, whose placeholder docstore cannot recover the
    original chunk text.

    Returns:
        True if an index was loaded, False otherwise.
    """
    global faiss_db
    try:
        embeddings = OpenAIEmbeddings()
        # Preferred path: store written by FAISS.save_local(...).
        if os.path.isdir("faiss_index"):
            try:
                # Flag required by newer langchain versions (pickle under the
                # hood); the files here are our own, written by this service.
                faiss_db = FAISS.load_local(
                    "faiss_index", embeddings, allow_dangerous_deserialization=True
                )
            except TypeError:
                # Older langchain versions do not accept the flag.
                faiss_db = FAISS.load_local("faiss_index", embeddings)
            print("✅ FAISS index loaded successfully!")
            return True
        # Legacy fallback: bare index file with no docstore.
        if not os.path.exists("faiss_index.bin"):
            print("❌ FAISS index file not found. Process a PDF first.")
            return False
        index = faiss.read_index("faiss_index.bin")
        # NOTE(review): this single dummy document stands in for the real
        # docstore — similarity_search on this fallback store surfaces the
        # placeholder, not the original PDF text. Re-run /process-pdf to get
        # the full save_local store.
        documents = [Document(page_content="dummy")]
        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
        index_to_docstore_id = {str(i): str(i) for i in range(len(documents))}
        # Reinitialize the FAISS database with the loaded raw index.
        faiss_db = FAISS(
            embedding_function=embeddings.embed_query,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )
        print("✅ FAISS index loaded successfully!")
        return True
    except Exception as e:
        print(f"❌ Failed to load FAISS index: {e}")
        return False
# Lifespan Event handler (new method)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the FAISS index at startup, log shutdown."""
    print("Starting up the application and loading FAISS index...")
    loaded = load_faiss_index()
    if not loaded:
        print("❌ Failed to load FAISS index during startup.")
    yield  # When app is running, this is active
    print("Shutting down the application...")
# Add lifespan to the FastAPI app.
# BUG FIX: this re-creation replaces the first `app` instance, which silently
# dropped the CORS middleware registered above — so browser clients lost CORS.
# Re-register the middleware on the instance that is actually served.
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
# API Endpoint: Store PDF Text in FAISS
@app.post("/process-pdf")
async def process_pdf():
    """Extracts text from the S3 PDF, stores it in FAISS, and rebuilds the index."""
    global faiss_db
    # Fixed source document for this deployment.
    bucket_name = "ai-document-storage"  # S3 bucket name
    s3_key = "ai_document.pdf"  # Updated document path

    # Step 1: fetch the document from S3 and extract its text.
    pdf_text = extract_text_from_s3(bucket_name, s3_key)
    if pdf_text:
        # Step 2: rebuild the FAISS index from the extracted text.
        if store_text_in_faiss(pdf_text):
            # Both steps succeeded — report success to the caller.
            return JSONResponse({"message": "✅ PDF processed and FAISS index updated successfully!"})
        return JSONResponse({"error": "Failed to store text in FAISS"}, status_code=500)
    return JSONResponse({"error": "Failed to extract text from PDF"}, status_code=500)
# API Endpoint: Chat Using FAISS Search
@app.post("/chat")
async def chat(chat: ChatMessage):
    """Retrieves answers from FAISS-stored PDF data.

    Finds (or creates) the chat session, pulls the top-5 most similar chunks
    from the FAISS store, asks OpenAI to answer strictly from that context,
    and persists the exchange to MongoDB.
    """
    global faiss_db
    if not faiss_db:
        # If FAISS database is empty, load it from the file
        if not load_faiss_index():
            return JSONResponse({"error": "FAISS index not loaded. Process a PDF first."}, status_code=500)
    # Reuse the caller's session id, or start a fresh session.
    session_id = chat.session_id or str(uuid.uuid4())
    # History is read BEFORE this turn is stored, so the returned
    # chat_history excludes the exchange being made in this request.
    chat_history = conversationcol.find_one({"session_id": session_id}) or {"conversation": []}
    # Search FAISS for relevant text
    retrieved_docs = faiss_db.similarity_search(chat.user_input, k=5)  # Increase k to 5 or higher
    context = "\n".join([doc.page_content for doc in retrieved_docs]) if retrieved_docs else "No relevant context found."
    # Generate AI Response via OpenAI API Call
    try:
        if not retrieved_docs:
            # No supporting context: return the canned refusal without calling the API.
            ai_response = "I don't have enough information to answer that question."
        else:
            response = openai.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": """You are an AI assistant for Psymeon's ALVIE app.
Answer questions **only** based on the provided context. If the context doesn't contain enough
information to answer accurately, say, "I don't have enough information to answer that question."
Use markdown formatting for better readability."""},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {chat.user_input}"}
                ],
                temperature=0.7  # Adjust for creativity vs. accuracy
            )
            ai_response = response.choices[0].message.content
    except Exception as e:
        print(f"❌ OpenAI API Error: {e}")
        return JSONResponse({"error": "Failed to generate AI response"}, status_code=500)
    # Store chat in MongoDB
    # NOTE(review): $push appends [user_input, ai_response] as a single
    # two-element list entry inside the "conversation" array.
    conversationcol.update_one(
        {"session_id": session_id}, {"$push": {"conversation": [chat.user_input, ai_response]}}, upsert=True
    )
    return JSONResponse({"response": ai_response, "session_id": session_id, "chat_history": chat_history["conversation"]})