brendon-ai committed on
Commit 05d39f8 · verified · 1 Parent(s): b2f9df1

Create RAGSample.py

Files changed (1)
  1. src/RAGSample.py +388 -0
src/RAGSample.py ADDED
@@ -0,0 +1,388 @@
+ import os
+ from dotenv import load_dotenv
+ from langchain_community.vectorstores import Chroma
+ from langchain_ollama import ChatOllama
+ from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.retrievers import BaseRetriever
+ from langchain_core.runnables import Runnable
+ from langchain_core.documents import Document
+ from langchain_core.embeddings import Embeddings
+ import chromadb
+ import numpy as np
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import pandas as pd
+ from typing import Optional, List
+
+ # Disable ChromaDB telemetry (avoids telemetry errors on startup)
+ os.environ["ANONYMIZED_TELEMETRY"] = "False"
+ os.environ["CHROMA_SERVER_HOST"] = "localhost"
+ os.environ["CHROMA_SERVER_HTTP_PORT"] = "8000"
+
+
+ class ImprovedTFIDFEmbeddings(Embeddings):
+     """TF-IDF based embedding function with improved text preprocessing."""
+
+     def __init__(self):
+         self.vectorizer = TfidfVectorizer(
+             max_features=5000,
+             stop_words='english',
+             ngram_range=(1, 3),   # unigrams, bigrams, and trigrams
+             min_df=1,
+             max_df=0.85,          # ignore terms that appear in >85% of documents
+             lowercase=True,
+             strip_accents='unicode',
+             analyzer='word'
+         )
+         self.fitted = False
+         self.documents = []
+
+     def embed_documents(self, texts):
+         """Create embeddings for a list of texts."""
+         if not self.fitted:
+             self.documents = texts
+             self.vectorizer.fit(texts)
+             self.fitted = True
+
+         # Transform texts to TF-IDF vectors
+         tfidf_matrix = self.vectorizer.transform(texts)
+
+         # Convert to dense arrays and normalize
+         embeddings = []
+         for i in range(tfidf_matrix.shape[0]):
+             embedding = tfidf_matrix[i].toarray().flatten()
+             # Normalize the embedding
+             norm = np.linalg.norm(embedding)
+             if norm > 0:
+                 embedding = embedding / norm
+             # Pad or truncate to 512 dimensions
+             if len(embedding) < 512:
+                 embedding = np.pad(embedding, (0, 512 - len(embedding)))
+             else:
+                 embedding = embedding[:512]
+             embeddings.append(embedding.tolist())
+
+         return embeddings
+
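+     # Note: truncation happens after L2 normalization, so documents with more
+     # than 512 active TF-IDF features lose mass and are no longer exactly unit-norm.
+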
+     def embed_query(self, text):
+         """Create an embedding for a single query text."""
+         if not self.fitted:
+             # If not fitted yet, fit with just this text (degenerate fallback)
+             self.vectorizer.fit([text])
+             self.fitted = True
+
+         # Transform the query to a TF-IDF vector
+         tfidf_matrix = self.vectorizer.transform([text])
+         embedding = tfidf_matrix[0].toarray().flatten()
+
+         # Normalize the embedding
+         norm = np.linalg.norm(embedding)
+         if norm > 0:
+             embedding = embedding / norm
+
+         # Pad or truncate to 512 dimensions
+         if len(embedding) < 512:
+             embedding = np.pad(embedding, (0, 512 - len(embedding)))
+         else:
+             embedding = embedding[:512]
+
+         return embedding.tolist()
+
+
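+ # Usage sketch (illustrative; not executed by this module): the class implements
+ # LangChain's Embeddings interface, so it can be exercised directly:
+ #   emb = ImprovedTFIDFEmbeddings()
+ #   doc_vecs = emb.embed_documents(["What is anxiety?", "How can I sleep better?"])
+ #   query_vec = emb.embed_query("anxiety")  # 512-dimensional list of floats
+
+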
+ class SmartFAQRetriever(BaseRetriever):
+     """Retriever for FAQ datasets, using TF-IDF cosine similarity against the stored questions."""
+
+     def __init__(self, documents: List[Document], k: int = 4):
+         super().__init__()
+         self._documents = documents
+         self._k = k
+         self._vectorizer = None  # private attribute; fitted lazily on first retrieval
+
+     @property
+     def documents(self):
+         return self._documents
+
+     @property
+     def k(self):
+         return self._k
+
+     def _get_relevant_documents(self, query: str) -> List[Document]:
+         """Retrieve the documents whose questions best match the query."""
+         # Extract the question text from each FAQ document
+         questions = []
+         for doc in self._documents:
+             if "QUESTION:" in doc.page_content:
+                 question_part = doc.page_content.split("ANSWER:")[0]
+                 questions.append(question_part.replace("QUESTION:", "").strip())
+             else:
+                 questions.append(doc.page_content)
+
+         # Fit the vectorizer lazily on the first call
+         if self._vectorizer is None or not getattr(self._vectorizer, 'vocabulary_', None):
+             print("[SmartFAQRetriever] Fitting vectorizer...")
+             self._vectorizer = TfidfVectorizer(
+                 max_features=3000,
+                 stop_words='english',
+                 ngram_range=(1, 2),
+                 min_df=1,
+                 max_df=0.9
+             )
+             self._vectorizer.fit(questions)
+
+         # Transform the query and the questions to TF-IDF vectors
+         query_vector = self._vectorizer.transform([query.lower().strip()])
+         question_vectors = self._vectorizer.transform(questions)
+
+         # Calculate cosine similarities and take the top k matches
+         similarities = cosine_similarity(query_vector, question_vectors).flatten()
+         top_indices = similarities.argsort()[-self._k:][::-1]
+
+         # Keep only matches above a minimal similarity threshold
+         relevant_docs = [self._documents[i] for i in top_indices if similarities[i] > 0.1]
+
+         if not relevant_docs:
+             # Fall back to the first k documents if there are no good matches
+             relevant_docs = self._documents[:self._k]
+
+         return relevant_docs
+
+     async def _aget_relevant_documents(self, query: str) -> List[Document]:
+         """Async version of _get_relevant_documents."""
+         return self._get_relevant_documents(query)
+
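+ # Usage sketch (illustrative): like any LangChain retriever, SmartFAQRetriever
+ # is invoked with a query string:
+ #   faq = [Document(page_content="QUESTION: What is anxiety?\nANSWER: ...")]
+ #   docs = SmartFAQRetriever(faq, k=2).invoke("anxiety symptoms")
+
+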
+ def setup_retriever(use_kaggle_data: bool = False, kaggle_dataset: Optional[str] = None,
+                     kaggle_username: Optional[str] = None, kaggle_key: Optional[str] = None,
+                     use_local_mental_health_data: bool = False) -> BaseRetriever:
+     """
+     Creates a retriever over documents from a Kaggle dataset or the local
+     mental health FAQ data (used as the default when no source is specified).
+
+     Args:
+         use_kaggle_data: Whether to load Kaggle data instead of the local FAQ data
+         kaggle_dataset: Kaggle dataset name (e.g., 'username/dataset-name')
+         kaggle_username: Your Kaggle username (optional if using kaggle.json)
+         kaggle_key: Your Kaggle API key (optional if using kaggle.json)
+         use_local_mental_health_data: Whether to load the local mental health FAQ data
+     """
+     print("Setting up the retriever...")
+     documents: List[Document] = []  # filled by whichever data source succeeds
+
+     if use_local_mental_health_data:
+         try:
+             print("Loading mental health FAQ data from local file...")
+             mental_health_file = "data/Mental_Health_FAQ.csv"
+
+             if not os.path.exists(mental_health_file):
+                 print(f"Mental health FAQ file not found: {mental_health_file}")
+                 use_local_mental_health_data = False
+             else:
+                 # Load the mental health FAQ data (expects 'Questions'/'Answers' columns)
+                 df = pd.read_csv(mental_health_file)
+                 for _, row in df.iterrows():
+                     question = row['Questions']
+                     answer = row['Answers']
+                     # Create a document in FAQ format
+                     content = f"QUESTION: {question}\nANSWER: {answer}"
+                     documents.append(Document(page_content=content))
+
+                 print(f"Loaded {len(documents)} mental health FAQ documents")
+                 for i, doc in enumerate(documents[:3]):
+                     print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")
+
+         except Exception as e:
+             print(f"Error loading mental health data: {e}")
+             use_local_mental_health_data = False
+
+     if use_kaggle_data and kaggle_dataset:
+         try:
+             from src.kaggle_loader import KaggleDataLoader
+
+             print(f"Loading Kaggle dataset: {kaggle_dataset}")
+             # Create the loader without parameters - it auto-loads credentials from kaggle.json
+             loader = KaggleDataLoader()
+
+             # Download the dataset
+             dataset_path = loader.download_dataset(kaggle_dataset)
+
+             # Load documents based on file type - only files from this dataset's directory
+             documents = []  # Kaggle data replaces anything loaded above
+             print(f"Processing files in dataset directory: {dataset_path}")
+
+             for file in os.listdir(dataset_path):
+                 file_path = os.path.join(dataset_path, file)
+
+                 if file.endswith('.csv'):
+                     print(f"Loading CSV file: {file}")
+                     # For FAQ datasets, use the FAQ-aware loading method
+                     if 'faq' in file.lower() or 'mental' in file.lower():
+                         documents.extend(loader.load_csv_dataset(file_path, [], chunk_size=50))
+                     else:
+                         # For other CSV files, use the first three columns as text
+                         df = pd.read_csv(file_path)
+                         text_columns = df.columns[:3].tolist()
+                         documents.extend(loader.load_csv_dataset(file_path, text_columns, chunk_size=50))
+
+                 elif file.endswith('.json'):
+                     print(f"Loading JSON file: {file}")
+                     documents.extend(loader.load_json_dataset(file_path))
+
+                 elif file.endswith('.txt'):
+                     print(f"Loading text file: {file}")
+                     documents.extend(loader.load_text_dataset(file_path))
+
+             print(f"Loaded {len(documents)} documents from Kaggle dataset")
+             for i, doc in enumerate(documents[:3]):
+                 print(f"Sample doc {i+1}: {doc.page_content[:200]}")
+
+         except Exception as e:
+             print(f"Error loading Kaggle data: {e}")
+             print("Falling back to the default mental health FAQ data...")
+             documents = []  # discard any partially loaded documents
+             use_kaggle_data = False
+
+     if not documents:
+         # Nothing loaded yet - fall back to the local mental health FAQ data
+         print("No specific data source specified, loading mental health FAQ data as default...")
+         try:
+             mental_health_file = "data/Mental_Health_FAQ.csv"
+
+             if not os.path.exists(mental_health_file):
+                 raise FileNotFoundError(f"Mental health FAQ file not found: {mental_health_file}")
+
+             # Load the mental health FAQ data
+             df = pd.read_csv(mental_health_file)
+             for _, row in df.iterrows():
+                 question = row['Questions']
+                 answer = row['Answers']
+                 # Create a document in FAQ format
+                 content = f"QUESTION: {question}\nANSWER: {answer}"
+                 documents.append(Document(page_content=content))
+
+             print(f"Loaded {len(documents)} mental health FAQ documents")
+             for i, doc in enumerate(documents[:3]):
+                 print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")
+
+         except Exception as e:
+             print(f"Error loading mental health data: {e}")
+             raise RuntimeError("No valid data source available. Please ensure the mental health FAQ data is present or provide Kaggle credentials.") from e
+
+     print(f"Processing {len(documents)} documents...")
+
+     # FAQ-style datasets bypass the vector store and use the smart retriever directly
+     if any("QUESTION:" in doc.page_content for doc in documents):
+         print("Using SmartFAQRetriever for better question matching...")
+         return SmartFAQRetriever(documents, k=4)
+
+     # For non-FAQ datasets, build a Chroma vector store over the TF-IDF embeddings
+     print("Creating TF-IDF embeddings...")
+     embeddings = ImprovedTFIDFEmbeddings()
+
+     print("Creating ChromaDB vector store...")
+     client = chromadb.PersistentClient(path="./src/chroma_db")
+
+     # Clear existing collections to prevent mixing old and new data
+     try:
+         for collection in client.list_collections():
+             print(f"Deleting existing collection: {collection.name}")
+             client.delete_collection(collection.name)
+     except Exception as e:
+         print(f"Warning: Could not clear existing collections: {e}")
+
+     vectorstore = Chroma.from_documents(
+         documents=documents,
+         embedding=embeddings,
+         client=client
+     )
+     print("Retriever setup complete.")
+     # Note: as_retriever takes search parameters via search_kwargs, not a bare k
+     return vectorstore.as_retriever(search_kwargs={"k": 4})
+
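+ # Note: the FAQ path above returns before any Chroma state is created, so the
+ # persistent store under ./src/chroma_db is only populated for free-form text.
+
+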
+ def setup_rag_chain() -> Runnable:
+     """Sets up the RAG chain with a prompt template and an LLM."""
+     # Define the prompt template for the LLM
+     prompt = PromptTemplate(
+         template="""You are an assistant for question-answering tasks.
+ Use the following documents to answer the question.
+ If you don't know the answer, just say that you don't know.
+ Use three sentences maximum and keep the answer concise:
+ Question: {question}
+ Documents: {documents}
+ Answer:
+ """,
+         input_variables=["question", "documents"],
+     )
+
+     # Initialize the LLM with the dolphin-llama3:8b model
+     # Note: this requires a running Ollama server with the model already pulled
+     llm = ChatOllama(
+         model="dolphin-llama3:8b",
+         temperature=0,
+     )
+
+     # Create a chain combining the prompt template, LLM, and output parser
+     return prompt | llm | StrOutputParser()
+
+
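+ # Usage sketch (illustrative): the chain takes a dict with both template
+ # variables and returns a plain string:
+ #   chain = setup_rag_chain()
+ #   chain.invoke({"question": "What is anxiety?", "documents": "QUESTION: ..."})
+
+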
+ # Define the RAG application class
+ class RAGApplication:
+     def __init__(self, retriever: BaseRetriever, rag_chain: Runnable):
+         self.retriever = retriever
+         self.rag_chain = rag_chain
+
+     def run(self, question: str) -> str:
+         """Runs the RAG pipeline for a given question."""
+         # Retrieve relevant documents
+         documents = self.retriever.invoke(question)
+
+         # Debug: print the retrieved documents
+         print(f"\nDEBUG: Retrieved {len(documents)} documents for question: '{question}'")
+         for i, doc in enumerate(documents):
+             print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
+
+         # Extract content from the retrieved documents
+         doc_texts = "\n\n".join([doc.page_content for doc in documents])
+
+         # Debug: print the combined document text
+         print(f"DEBUG: Combined document text: {doc_texts[:300]}...")
+
+         # Get the answer from the language model
+         answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
+         return answer
+
+ # Main execution block
+ if __name__ == "__main__":
+     load_dotenv()
+
+     # 1. Set up the components
+     retriever = setup_retriever()
+     rag_chain = setup_rag_chain()
+
+     # 2. Initialize the RAG application
+     rag_application = RAGApplication(retriever, rag_chain)
+
+     # 3. Run an example query
+     # (note: this query is off-topic for the FAQ corpus, so the prompt's
+     # "say that you don't know" instruction is expected to kick in)
+     question = "What is prompt engineering?"
+     print("\n--- Running RAG Application ---")
+     print(f"Question: {question}")
+     answer = rag_application.run(question)
+     print(f"Answer: {answer}")
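+
+ # Example invocations (sketch; the dataset slug is a placeholder):
+ #   retriever = setup_retriever(use_local_mental_health_data=True)
+ #   retriever = setup_retriever(use_kaggle_data=True,
+ #                               kaggle_dataset="<username>/<dataset-name>")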