DeepakKolhe1995 committed on
Commit
1490204
·
verified ·
1 Parent(s): b30095c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -74
app.py CHANGED
@@ -14,8 +14,10 @@ from langchain_core.prompts import ChatPromptTemplate
14
  from langchain_core.runnables import RunnablePassthrough
15
  from langchain_core.output_parsers import StrOutputParser
16
  from qdrant_client import QdrantClient
 
17
  import re
18
  import json
 
19
 
20
  # Set up logging
21
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -23,14 +25,19 @@ logger = logging.getLogger(__name__)
23
 
24
  def load_environment():
25
  """Load and validate environment variables."""
26
- load_dotenv()
27
- required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
28
- missing_vars = [var for var in required_vars if not os.getenv(var)]
29
- if missing_vars:
30
- logger.error(f"Missing environment variables: {missing_vars}")
31
- st.error(f"Missing environment variables: {missing_vars}")
32
- raise ValueError(f"Missing environment variables: {missing_vars}")
33
- logger.info("Environment variables loaded successfully")
 
 
 
 
 
34
 
35
  @st.cache_resource
36
  def load_wikipedia_documents():
@@ -42,6 +49,10 @@ def load_wikipedia_documents():
42
  )
43
  documents = [Document(page_content=item["text"]) for item in dataset]
44
  logger.info(f"Loaded {len(documents)} Wikipedia documents")
 
 
 
 
45
  return documents
46
  except Exception as e:
47
  logger.error(f"Error loading dataset: {e}")
@@ -52,9 +63,17 @@ def load_wikipedia_documents():
52
  def split_documents(_documents):
53
  """Split documents into chunks."""
54
  try:
 
 
 
 
55
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
56
  chunks = splitter.split_documents(_documents)
57
  logger.info(f"Split into {len(chunks)} chunks")
 
 
 
 
58
  return chunks
59
  except Exception as e:
60
  logger.error(f"Error splitting documents: {e}")
@@ -76,47 +95,79 @@ def initialize_embeddings():
76
  st.error(f"Failed to initialize embeddings: {e}")
77
  return None
78
 
79
- @st.cache_resource
80
  def store_in_qdrant(_chunks, _embeddings):
81
- """Store document chunks in a hosted Qdrant instance after deleting existing collection."""
82
  try:
83
  # Initialize Qdrant client
84
  client = QdrantClient(
85
  url=os.getenv("QDRANT_URL"),
86
- api_key=os.getenv("QDRANT_API_KEY")
 
87
  )
88
 
89
- # Delete existing collection if it exists
90
- collection_name = "wikipedia_chunks"
91
  try:
92
- client.delete_collection(collection_name)
93
- logger.info(f"Deleted existing Qdrant collection: {collection_name}")
94
  except Exception as e:
95
- logger.warning(f"No existing collection {collection_name} to delete or error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # Create and populate new collection
98
- vector_store = Qdrant.from_documents(
99
- documents=_chunks,
100
- embedding=_embeddings,
101
- url=os.getenv("QDRANT_URL"),
102
- api_key=os.getenv("QDRANT_API_KEY"),
103
- collection_name=collection_name
104
- )
105
-
106
- # Verify storage by checking collection size
107
- collection_info = client.get_collection(collection_name)
108
- stored_points = collection_info.points_count
109
- logger.info(f"Stored {stored_points} chunks in Qdrant at {os.getenv('QDRANT_URL')}")
110
- if stored_points == 0:
111
- logger.error("No documents stored in Qdrant collection")
112
- st.error("No documents stored in Qdrant collection")
113
  return None
114
- if stored_points != len(_chunks):
115
- logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
116
-
117
- return vector_store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  except Exception as e:
119
- logger.error(f"Error storing in Qdrant: {e}")
120
  st.error(f"Failed to store documents in Qdrant: {e}")
121
  return None
122
 
@@ -138,62 +189,78 @@ def initialize_llm():
138
 
139
  def extract_score_from_text(text):
140
  """Extract the first float number between 0 and 1 from the text using regex."""
141
- matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text)
142
- if not matches:
143
- logger.warning("No score found in text.")
144
- return None
145
  try:
 
 
 
 
146
  score = float(matches[0])
147
  if 0.0 <= score <= 1.0:
148
  return score
149
- else:
150
- logger.warning(f"Score {score} out of expected range 0-1.")
151
- return None
152
- except ValueError:
153
- logger.warning(f"Cannot convert match to float: {matches[0]}")
154
  return None
155
 
156
  def claude_rerank(docs, query, llm, top_n=5):
157
  """Rerank documents based on relevance using the LLM."""
158
- rerank_prompt = ChatPromptTemplate.from_template(
159
- """
 
160
  Given the query: "{query}" and the document chunk: "{chunk}", please rate
161
  the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).
162
 
163
  Respond with a number only, like: 0.8
164
  """
165
- )
166
- scored_docs = []
167
- for idx, doc in enumerate(docs):
168
- prompt = rerank_prompt.format(query=query, chunk=doc.page_content)
169
- response = llm.invoke(prompt)
170
- text = response.content.strip()
171
- logger.info(f"Doc {idx} rerank raw output: {text}")
172
- score = extract_score_from_text(text)
173
- if score is None:
174
- logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
175
- score = 0.0
176
- scored_docs.append((doc, score))
177
- scored_docs.sort(key=lambda x: x[1], reverse=True)
178
- logger.info(f"Reranked top {top_n} docs based on scores.")
179
- return [doc for doc, _ in scored_docs[:top_n]]
 
 
 
 
180
 
181
  def create_rag_chain(vector_store, llm, use_rerank=False):
182
  """Create a RAG chain with or without reranking."""
183
- prompt_template = ChatPromptTemplate.from_template(
184
- """You are a helpful assistant. Use the following context to answer the question concisely.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"""
185
- )
186
- retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})
 
187
 
188
- def rerank_context(inputs):
189
- docs = retriever.invoke(inputs["question"])
190
- if use_rerank:
191
- docs = claude_rerank(docs, inputs["question"], llm)
192
- return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
 
 
 
 
 
 
 
193
 
194
- chain = rerank_context | prompt_template | llm | StrOutputParser()
195
- logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
196
- return chain
 
 
 
 
197
 
198
  def main():
199
  st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
@@ -208,22 +275,33 @@ def main():
208
  # Initialize components
209
  documents = load_wikipedia_documents()
210
  if not documents:
 
211
  return
212
  chunks = split_documents(documents)
213
  if not chunks:
 
214
  return
215
  embeddings = initialize_embeddings()
216
  if embeddings is None:
 
217
  return
218
  vector_store = store_in_qdrant(chunks, embeddings)
219
  if vector_store is None:
 
220
  return
221
  llm = initialize_llm()
222
  if llm is None:
 
223
  return
224
 
225
  baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
 
 
 
226
  rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
 
 
 
227
 
228
  # Streamlit input
229
  query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")
 
14
  from langchain_core.runnables import RunnablePassthrough
15
  from langchain_core.output_parsers import StrOutputParser
16
  from qdrant_client import QdrantClient
17
+ from qdrant_client.models import Distance, VectorParams
18
  import re
19
  import json
20
+ from urllib.error import URLError
21
 
22
  # Set up logging
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
25
 
26
  def load_environment():
27
  """Load and validate environment variables."""
28
+ try:
29
+ load_dotenv()
30
+ required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
31
+ missing_vars = [var for var in required_vars if not os.getenv(var)]
32
+ if missing_vars:
33
+ logger.error(f"Missing environment variables: {missing_vars}")
34
+ st.error(f"Missing environment variables: {missing_vars}")
35
+ raise ValueError(f"Missing environment variables: {missing_vars}")
36
+ logger.info("Environment variables loaded successfully")
37
+ except Exception as e:
38
+ logger.error(f"Error loading environment variables: {e}")
39
+ st.error(f"Error loading environment variables: {e}")
40
+ raise
41
 
42
  @st.cache_resource
43
  def load_wikipedia_documents():
 
49
  )
50
  documents = [Document(page_content=item["text"]) for item in dataset]
51
  logger.info(f"Loaded {len(documents)} Wikipedia documents")
52
+ if not documents:
53
+ logger.error("No documents loaded from dataset")
54
+ st.error("No documents loaded from dataset")
55
+ return []
56
  return documents
57
  except Exception as e:
58
  logger.error(f"Error loading dataset: {e}")
 
63
  def split_documents(_documents):
64
  """Split documents into chunks."""
65
  try:
66
+ if not _documents:
67
+ logger.error("No documents provided for splitting")
68
+ st.error("No documents provided for splitting")
69
+ return []
70
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
71
  chunks = splitter.split_documents(_documents)
72
  logger.info(f"Split into {len(chunks)} chunks")
73
+ if not chunks:
74
+ logger.error("No chunks created from documents")
75
+ st.error("No chunks created from documents")
76
+ return []
77
  return chunks
78
  except Exception as e:
79
  logger.error(f"Error splitting documents: {e}")
 
95
  st.error(f"Failed to initialize embeddings: {e}")
96
  return None
97
 
 
98
  def store_in_qdrant(_chunks, _embeddings):
99
+ """Store document chunks in a hosted Qdrant instance after deleting all collections."""
100
  try:
101
  # Initialize Qdrant client
102
  client = QdrantClient(
103
  url=os.getenv("QDRANT_URL"),
104
+ api_key=os.getenv("QDRANT_API_KEY"),
105
+ timeout=30
106
  )
107
 
108
+ # Test Qdrant connection
 
109
  try:
110
+ client.get_collections()
111
+ logger.info("Successfully connected to Qdrant at %s", os.getenv("QDRANT_URL"))
112
  except Exception as e:
113
+ logger.error("Failed to connect to Qdrant: %s", e)
114
+ st.error(f"Failed to connect to Qdrant: {e}")
115
+ return None
116
+
117
+ # Delete all existing collections
118
+ try:
119
+ collections = client.get_collections().collections
120
+ for collection in collections:
121
+ client.delete_collection(collection.name)
122
+ logger.info(f"Deleted Qdrant collection: {collection.name}")
123
+ logger.info("All Qdrant collections deleted")
124
+ except Exception as e:
125
+ logger.warning(f"Error deleting collections: {e}")
126
+ st.warning(f"Error deleting collections: {e}")
127
+
128
+ # Validate input chunks
129
+ if not _chunks:
130
+ logger.error("No chunks provided for Qdrant storage")
131
+ st.error("No chunks provided for Qdrant storage")
132
+ return None
133
 
134
  # Create and populate new collection
135
+ collection_name = "wikipedia_chunks"
136
+ try:
137
+ vector_store = Qdrant.from_documents(
138
+ documents=_chunks,
139
+ embedding=_embeddings,
140
+ url=os.getenv("QDRANT_URL"),
141
+ api_key=os.getenv("QDRANT_API_KEY"),
142
+ collection_name=collection_name,
143
+ force_recreate=True # Ensure fresh collection
144
+ )
145
+ logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks")
146
+ except Exception as e:
147
+ logger.error(f"Error creating Qdrant collection: {e}")
148
+ st.error(f"Failed to create Qdrant collection: {e}")
 
149
  return None
150
+
151
+ # Verify storage
152
+ try:
153
+ collection_info = client.get_collection(collection_name)
154
+ stored_points = collection_info.points_count
155
+ logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}")
156
+ if stored_points == 0:
157
+ logger.error("No documents stored in Qdrant collection")
158
+ st.error("No documents stored in Qdrant collection")
159
+ return None
160
+ if stored_points != len(_chunks):
161
+ logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
162
+ st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
163
+ return vector_store
164
+ except Exception as e:
165
+ logger.error(f"Error verifying Qdrant storage: {e}")
166
+ st.error(f"Failed to verify Qdrant storage: {e}")
167
+ return None
168
+
169
  except Exception as e:
170
+ logger.error(f"Error in Qdrant storage process: {e}")
171
  st.error(f"Failed to store documents in Qdrant: {e}")
172
  return None
173
 
 
189
 
190
  def extract_score_from_text(text):
191
  """Extract the first float number between 0 and 1 from the text using regex."""
 
 
 
 
192
  try:
193
+ matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text)
194
+ if not matches:
195
+ logger.warning("No score found in text")
196
+ return None
197
  score = float(matches[0])
198
  if 0.0 <= score <= 1.0:
199
  return score
200
+ logger.warning(f"Score {score} out of expected range 0-1")
201
+ return None
202
+ except ValueError as e:
203
+ logger.warning(f"Cannot convert match to float: {e}")
 
204
  return None
205
 
206
  def claude_rerank(docs, query, llm, top_n=5):
207
  """Rerank documents based on relevance using the LLM."""
208
+ try:
209
+ rerank_prompt = ChatPromptTemplate.from_template(
210
+ """
211
  Given the query: "{query}" and the document chunk: "{chunk}", please rate
212
  the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).
213
 
214
  Respond with a number only, like: 0.8
215
  """
216
+ )
217
+ scored_docs = []
218
+ for idx, doc in enumerate(docs):
219
+ prompt = rerank_prompt.format(query=query, chunk=doc.page_content)
220
+ response = llm.invoke(prompt)
221
+ text = response.content.strip()
222
+ logger.info(f"Doc {idx} rerank raw output: {text}")
223
+ score = extract_score_from_text(text)
224
+ if score is None:
225
+ logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
226
+ score = 0.0
227
+ scored_docs.append((doc, score))
228
+ scored_docs.sort(key=lambda x: x[1], reverse=True)
229
+ logger.info(f"Reranked top {top_n} docs based on scores")
230
+ return [doc for doc, _ in scored_docs[:top_n]]
231
+ except Exception as e:
232
+ logger.error(f"Error in reranking: {e}")
233
+ st.error(f"Error in reranking: {e}")
234
+ return docs[:top_n] # Fallback to original docs
235
 
236
  def create_rag_chain(vector_store, llm, use_rerank=False):
237
  """Create a RAG chain with or without reranking."""
238
+ try:
239
+ prompt_template = ChatPromptTemplate.from_template(
240
+ """You are a helpful assistant. Use the following context to answer the question concisely.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"""
241
+ )
242
+ retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})
243
 
244
+ def rerank_context(inputs):
245
+ try:
246
+ docs = retriever.invoke(inputs["question"])
247
+ if not docs:
248
+ logger.warning("No documents retrieved for query")
249
+ return {"context": "", "question": inputs["question"]}
250
+ if use_rerank:
251
+ docs = claude_rerank(docs, inputs["question"], llm)
252
+ return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
253
+ except Exception as e:
254
+ logger.error(f"Error in rerank_context: {e}")
255
+ return {"context": "", "question": inputs["question"]}
256
 
257
+ chain = rerank_context | prompt_template | llm | StrOutputParser()
258
+ logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
259
+ return chain
260
+ except Exception as e:
261
+ logger.error(f"Error creating RAG chain: {e}")
262
+ st.error(f"Failed to create RAG chain: {e}")
263
+ return None
264
 
265
  def main():
266
  st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
 
275
  # Initialize components
276
  documents = load_wikipedia_documents()
277
  if not documents:
278
+ st.error("Cannot proceed without documents")
279
  return
280
  chunks = split_documents(documents)
281
  if not chunks:
282
+ st.error("Cannot proceed without document chunks")
283
  return
284
  embeddings = initialize_embeddings()
285
  if embeddings is None:
286
+ st.error("Cannot proceed without embeddings")
287
  return
288
  vector_store = store_in_qdrant(chunks, embeddings)
289
  if vector_store is None:
290
+ st.error("Cannot proceed without vector store")
291
  return
292
  llm = initialize_llm()
293
  if llm is None:
294
+ st.error("Cannot proceed without LLM")
295
  return
296
 
297
  baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
298
+ if baseline_chain is None:
299
+ st.error("Cannot proceed without baseline chain")
300
+ return
301
  rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
302
+ if rerank_chain is None:
303
+ st.error("Cannot proceed without rerank chain")
304
+ return
305
 
306
  # Streamlit input
307
  query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")