brendon-ai committed
Commit ab66284 · verified · 1 Parent(s): 05d39f8

Delete RAGSample.py

Files changed (1)
  1. RAGSample.py +0 -388
RAGSample.py DELETED
@@ -1,388 +0,0 @@
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from langchain_ollama import ChatOllama
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
import chromadb
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from typing import Optional, List

# Disable ChromaDB telemetry to avoid noisy telemetry errors on startup
os.environ["ANONYMIZED_TELEMETRY"] = "False"
os.environ["CHROMA_SERVER_HOST"] = "localhost"
os.environ["CHROMA_SERVER_HTTP_PORT"] = "8000"


class ImprovedTFIDFEmbeddings(Embeddings):
    """Improved TF-IDF based embedding function with better preprocessing."""

    def __init__(self):
        self.vectorizer = TfidfVectorizer(
            max_features=5000,
            stop_words='english',
            ngram_range=(1, 3),
            min_df=1,
            max_df=0.85,
            lowercase=True,
            strip_accents='unicode',
            analyzer='word'
        )
        self.fitted = False
        self.documents = []

    def embed_documents(self, texts):
        """Create embeddings for a list of texts."""
        if not self.fitted:
            self.documents = texts
            self.vectorizer.fit(texts)
            self.fitted = True

        # Transform texts to TF-IDF vectors
        tfidf_matrix = self.vectorizer.transform(texts)

        # Convert to dense arrays and normalize
        embeddings = []
        for i in range(tfidf_matrix.shape[0]):
            embedding = tfidf_matrix[i].toarray().flatten()
            # Normalize the embedding
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
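            # Chroma collections expect a consistent embedding dimensionality,
            # hence the fixed 512-dim pad/truncate below; note that truncation
            # can discard informative TF-IDF features once the fitted
            # vocabulary exceeds 512 terms.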
            # Pad or truncate to 512 dimensions
            if len(embedding) < 512:
                embedding = np.pad(embedding, (0, 512 - len(embedding)))
            else:
                embedding = embedding[:512]
            embeddings.append(embedding.tolist())

        return embeddings

    def embed_query(self, text):
        """Create embedding for a single query text."""
        if not self.fitted:
            # If not fitted, fit with just this text
            self.vectorizer.fit([text])
            self.fitted = True
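            # Caveat: fitting on a single query yields a near-degenerate
            # vocabulary; embed_documents should normally run first so the
            # vectorizer reflects the whole corpus.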

        # Transform query to TF-IDF vector
        tfidf_matrix = self.vectorizer.transform([text])
        embedding = tfidf_matrix[0].toarray().flatten()

        # Normalize the embedding
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm

        # Pad or truncate to 512 dimensions
        if len(embedding) < 512:
            embedding = np.pad(embedding, (0, 512 - len(embedding)))
        else:
            embedding = embedding[:512]

        return embedding.tolist()


class SmartFAQRetriever(BaseRetriever):
    """Smart retriever optimized for FAQ datasets with semantic similarity."""

    def __init__(self, documents: List[Document], k: int = 4):
        super().__init__()
        self._documents = documents
        self._k = k
        self._vectorizer = None  # Use private attribute
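        # Note: BaseRetriever is a Pydantic model; the underscore-prefixed
        # attributes here hold runtime-only state outside its declared fields.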

    @property
    def documents(self):
        return self._documents

    @property
    def k(self):
        return self._k

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Retrieve documents based on semantic similarity."""
        # Extract the question half of each FAQ document; matching the query
        # against questions only gives cleaner TF-IDF similarities
        questions = []
        for doc in self._documents:
            if "QUESTION:" in doc.page_content:
                question_part = doc.page_content.split("ANSWER:")[0]
                question = question_part.replace("QUESTION:", "").strip()
                questions.append(question)
            else:
                questions.append(doc.page_content)

        # Ensure vectorizer is fitted
        if self._vectorizer is None or not getattr(self._vectorizer, 'vocabulary_', None):
            print("[SmartFAQRetriever] Fitting vectorizer...")
            self._vectorizer = TfidfVectorizer(
                max_features=3000,
                stop_words='english',
                ngram_range=(1, 2),
                min_df=1,
                max_df=0.9
            )
            self._vectorizer.fit(questions)

        query_lower = query.lower().strip()

        # Transform query and questions to TF-IDF vectors
        query_vector = self._vectorizer.transform([query_lower])
        question_vectors = self._vectorizer.transform(questions)

        # Calculate cosine similarities
        similarities = cosine_similarity(query_vector, question_vectors).flatten()
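        # (TfidfVectorizer L2-normalizes rows by default, so the cosine
        # similarity above is effectively a sparse dot product.)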

        # Get top k documents
        top_indices = similarities.argsort()[-self._k:][::-1]

        # Return documents with highest similarity scores
        relevant_docs = [self._documents[i] for i in top_indices if similarities[i] > 0.1]

        if not relevant_docs:
            # Fallback to first k documents if no good matches
            relevant_docs = self._documents[:self._k]

        return relevant_docs

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Async version of get_relevant_documents."""
        return self._get_relevant_documents(query)

def setup_retriever(use_kaggle_data: bool = False, kaggle_dataset: Optional[str] = None,
                    kaggle_username: Optional[str] = None, kaggle_key: Optional[str] = None,
                    use_local_mental_health_data: bool = False) -> BaseRetriever:
    """
    Creates a retriever backed by documents from Kaggle datasets or local mental health FAQ data.

    Args:
        use_kaggle_data: Whether to load Kaggle data instead of the default local data
        kaggle_dataset: Kaggle dataset name (e.g., 'username/dataset-name')
        kaggle_username: Your Kaggle username (optional if using kaggle.json)
        kaggle_key: Your Kaggle API key (optional if using kaggle.json)
        use_local_mental_health_data: Whether to load local mental health FAQ data
    """
    print("Setting up the retriever...")

    if use_local_mental_health_data:
        try:
            print("Loading mental health FAQ data from local file...")
            mental_health_file = "data/Mental_Health_FAQ.csv"

            if not os.path.exists(mental_health_file):
                print(f"Mental health FAQ file not found: {mental_health_file}")
                use_local_mental_health_data = False
            else:
                # Load mental health FAQ data
                df = pd.read_csv(mental_health_file)
                documents = []

                for _, row in df.iterrows():
                    question = row['Questions']
                    answer = row['Answers']
                    # Create document in FAQ format
                    content = f"QUESTION: {question}\nANSWER: {answer}"
                    documents.append(Document(page_content=content))

                print(f"Loaded {len(documents)} mental health FAQ documents")
                for i, doc in enumerate(documents[:3]):
                    print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")

        except Exception as e:
            print(f"Error loading mental health data: {e}")
            use_local_mental_health_data = False

    if use_kaggle_data and kaggle_dataset:
        try:
            from src.kaggle_loader import KaggleDataLoader

            print(f"Loading Kaggle dataset: {kaggle_dataset}")
            # Create loader without parameters - it will auto-load credentials from kaggle.json
            loader = KaggleDataLoader()

            # Download the dataset
            dataset_path = loader.download_dataset(kaggle_dataset)

            # Load documents based on file type - only process files from this specific dataset
            documents = []

            print(f"Processing files in dataset directory: {dataset_path}")

            for file in os.listdir(dataset_path):
                file_path = os.path.join(dataset_path, file)

                if file.endswith('.csv'):
                    print(f"Loading CSV file: {file}")
                    # For FAQ datasets, use the improved loading method
                    if 'faq' in file.lower() or 'mental' in file.lower():
                        documents.extend(loader.load_csv_dataset(file_path, [], chunk_size=50))
                    else:
                        # For other CSV files, use the first three columns as text
                        df = pd.read_csv(file_path)
                        text_columns = df.columns[:3].tolist()
                        documents.extend(loader.load_csv_dataset(file_path, text_columns, chunk_size=50))

                elif file.endswith('.json'):
                    print(f"Loading JSON file: {file}")
                    documents.extend(loader.load_json_dataset(file_path))

                elif file.endswith('.txt'):
                    print(f"Loading text file: {file}")
                    documents.extend(loader.load_text_dataset(file_path))

            print(f"Loaded {len(documents)} documents from Kaggle dataset")
            for i, doc in enumerate(documents[:3]):
                print(f"Sample doc {i+1}: {doc.page_content[:200]}")

        except Exception as e:
            print(f"Error loading Kaggle data: {e}")
            print("Falling back to the default local FAQ data...")
            use_kaggle_data = False

    if not use_kaggle_data and not use_local_mental_health_data:
        # Default path: load the local mental health FAQ data
        print("No specific data source specified, loading mental health FAQ data as default...")
        try:
            mental_health_file = "data/Mental_Health_FAQ.csv"

            if not os.path.exists(mental_health_file):
                raise FileNotFoundError(f"Mental health FAQ file not found: {mental_health_file}")

            # Load mental health FAQ data
            df = pd.read_csv(mental_health_file)
            documents = []

            for _, row in df.iterrows():
                question = row['Questions']
                answer = row['Answers']
                # Create document in FAQ format
                content = f"QUESTION: {question}\nANSWER: {answer}"
                documents.append(Document(page_content=content))

            print(f"Loaded {len(documents)} mental health FAQ documents")
            for i, doc in enumerate(documents[:3]):
                print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")

        except Exception as e:
            print(f"Error loading mental health data: {e}")
            raise RuntimeError("No valid data source available. Please ensure mental health FAQ data is present or provide Kaggle credentials.") from e

    print("Creating TF-IDF embeddings...")
    embeddings = ImprovedTFIDFEmbeddings()

    print("Creating ChromaDB vector store...")
    client = chromadb.PersistentClient(path="./src/chroma_db")
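    # PersistentClient stores the collection on local disk; the CHROMA_SERVER_*
    # environment variables set above only matter for a client/server deployment.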

    # Clear existing collections to prevent mixing old and new data
    try:
        collections = client.list_collections()
        for collection in collections:
            print(f"Deleting existing collection: {collection.name}")
            client.delete_collection(collection.name)
    except Exception as e:
        print(f"Warning: Could not clear existing collections: {e}")

    print(f"Processing {len(documents)} documents...")

    # Check if this is a FAQ dataset and use the smart retriever
    if any("QUESTION:" in doc.page_content for doc in documents):
        print("Using SmartFAQRetriever for better semantic matching...")
        return SmartFAQRetriever(documents, k=4)
    else:
        # Use a vector store for non-FAQ datasets
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            client=client
        )
        print("Retriever setup complete.")
        # as_retriever expects retrieval options via search_kwargs, not a bare k
        return vectorstore.as_retriever(search_kwargs={"k": 4})

def setup_rag_chain() -> Runnable:
    """Sets up the RAG chain with a prompt template and an LLM."""
    # Define the prompt template for the LLM
    prompt = PromptTemplate(
        template="""You are an assistant for question-answering tasks.
        Use the following documents to answer the question.
        If you don't know the answer, just say that you don't know.
        Use three sentences maximum and keep the answer concise.
        Question: {question}
        Documents: {documents}
        Answer:
        """,
        input_variables=["question", "documents"],
    )

    # Initialize the LLM with the dolphin-llama3:8b model
    # Note: this requires an Ollama server to be running with the model pulled
    llm = ChatOllama(
        model="dolphin-llama3:8b",
        temperature=0,
    )

    # Compose prompt, model, and output parser into a single LCEL chain
    return prompt | llm | StrOutputParser()


# Define the RAG application class
class RAGApplication:
    def __init__(self, retriever: BaseRetriever, rag_chain: Runnable):
        self.retriever = retriever
        self.rag_chain = rag_chain

    def run(self, question: str) -> str:
        """Runs the RAG pipeline for a given question."""
        # Retrieve relevant documents
        documents = self.retriever.invoke(question)
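        # invoke() comes from the Runnable interface that BaseRetriever
        # implements; it dispatches to _get_relevant_documents internally.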

        # Debug: print retrieved documents
        print(f"\nDEBUG: Retrieved {len(documents)} documents for question: '{question}'")
        for i, doc in enumerate(documents):
            print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")

        # Extract content from retrieved documents
        doc_texts = "\n\n".join([doc.page_content for doc in documents])

        # Debug: print the combined document text
        print(f"DEBUG: Combined document text: {doc_texts[:300]}...")

        # Get the answer from the language model
        answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
        return answer


# Main execution block
if __name__ == "__main__":
    load_dotenv()

    # 1. Set up the components
    retriever = setup_retriever()
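    # With no arguments, setup_retriever() falls back to the local
    # Mental_Health_FAQ.csv corpus; the sample question below is
    # general-purpose, so it may exercise the low-similarity fallback path.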
    rag_chain = setup_rag_chain()

    # 2. Initialize the RAG application
    rag_application = RAGApplication(retriever, rag_chain)

    # 3. Run an example query
    question = "What is prompt engineering?"
    print("\n--- Running RAG Application ---")
    print(f"Question: {question}")
    answer = rag_application.run(question)
    print(f"Answer: {answer}")