import os
from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
import chromadb
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from typing import Optional, List

# Disable ChromaDB's anonymized telemetry to avoid noisy telemetry errors on startup
os.environ["ANONYMIZED_TELEMETRY"] = "False"
# Server host/port only apply in client/server mode; the PersistentClient
# used below does not need them
os.environ["CHROMA_SERVER_HOST"] = "localhost"
os.environ["CHROMA_SERVER_HTTP_PORT"] = "8000"


class ImprovedTFIDFEmbeddings(Embeddings):
    """TF-IDF embeddings with word n-grams, accent stripping, and L2
    normalization, padded or truncated to a fixed 512 dimensions."""
    
    def __init__(self):
        self.vectorizer = TfidfVectorizer(
            max_features=5000,
            stop_words='english',
            ngram_range=(1, 3),
            min_df=1,
            max_df=0.85,
            lowercase=True,
            strip_accents='unicode',
            analyzer='word'
        )
        self.fitted = False
        self.documents = []
    
    def embed_documents(self, texts):
        """Create embeddings for a list of texts."""
        if not self.fitted:
            self.documents = texts
            self.vectorizer.fit(texts)
            self.fitted = True
        
        # Transform texts to TF-IDF vectors
        tfidf_matrix = self.vectorizer.transform(texts)
        
        # Convert to dense arrays and normalize
        embeddings = []
        for i in range(tfidf_matrix.shape[0]):
            embedding = tfidf_matrix[i].toarray().flatten()
            # Normalize the embedding
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
            # Pad or truncate to a fixed 512 dimensions; truncation discards
            # higher-index TF-IDF features, trading fidelity for a fixed size
            if len(embedding) < 512:
                embedding = np.pad(embedding, (0, 512 - len(embedding)))
            else:
                embedding = embedding[:512]
            embeddings.append(embedding.tolist())
        
        return embeddings
    
    def embed_query(self, text):
        """Create embedding for a single query text."""
        if not self.fitted:
            # Fallback guard only: fitting on a single query yields a degenerate
            # vocabulary, so embed_documents should normally be called first
            self.vectorizer.fit([text])
            self.fitted = True
        
        # Transform query to TF-IDF vector
        tfidf_matrix = self.vectorizer.transform([text])
        embedding = tfidf_matrix[0].toarray().flatten()
        
        # Normalize the embedding
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        
        # Pad or truncate to 512 dimensions
        if len(embedding) < 512:
            embedding = np.pad(embedding, (0, 512 - len(embedding)))
        else:
            embedding = embedding[:512]
        
        return embedding.tolist()

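# Example usage (illustrative sketch; the texts below are hypothetical):
#
#     emb = ImprovedTFIDFEmbeddings()
#     doc_vecs = emb.embed_documents(["What is anxiety?", "How can I sleep better?"])
#     query_vec = emb.embed_query("trouble sleeping")
#     # Vectors are L2-normalized before padding, so cosine similarity is
#     # approximately a dot product (exact only when no truncation occurs):
#     score = sum(a * b for a, b in zip(doc_vecs[1], query_vec))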

class SmartFAQRetriever(BaseRetriever):
    """Retriever for FAQ datasets that ranks documents by TF-IDF cosine
    similarity against the question text."""
    
    def __init__(self, documents: List[Document], k: int = 4):
        super().__init__()
        # BaseRetriever is a Pydantic model; underscore-prefixed attributes
        # bypass its field validation
        self._documents = documents
        self._k = k
        self._vectorizer = None
    
    @property
    def documents(self):
        return self._documents
    
    @property
    def k(self):
        return self._k
    
    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Retrieve the k most similar documents by TF-IDF cosine similarity."""
        # Extract the question part of each FAQ document once; it is used
        # both to fit the vectorizer and to score against the query
        questions = []
        for doc in self._documents:
            if "QUESTION:" in doc.page_content:
                question_part = doc.page_content.split("ANSWER:")[0]
                questions.append(question_part.replace("QUESTION:", "").strip())
            else:
                questions.append(doc.page_content)
        
        # Lazily fit the vectorizer on the first call
        if self._vectorizer is None or not getattr(self._vectorizer, 'vocabulary_', None):
            print("[SmartFAQRetriever] Fitting vectorizer...")
            self._vectorizer = TfidfVectorizer(
                max_features=3000,
                stop_words='english',
                ngram_range=(1, 2),
                min_df=1,
                max_df=0.9
            )
            self._vectorizer.fit(questions)
        
        # Transform query and questions to TF-IDF vectors
        query_vector = self._vectorizer.transform([query.lower().strip()])
        question_vectors = self._vectorizer.transform(questions)
        
        # Calculate cosine similarities
        similarities = cosine_similarity(query_vector, question_vectors).flatten()
        
        # Get top k documents
        top_indices = similarities.argsort()[-self._k:][::-1]
        
        # Return documents with highest similarity scores
        relevant_docs = [self._documents[i] for i in top_indices if similarities[i] > 0.1]
        
        if not relevant_docs:
            # Fallback to first k documents if no good matches
            relevant_docs = self._documents[:self._k]
        
        return relevant_docs
    
    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Async version of get_relevant_documents."""
        return self._get_relevant_documents(query)

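# Example usage (illustrative; the FAQ content below is hypothetical):
#
#     docs = [
#         Document(page_content="QUESTION: What is depression?\nANSWER: A mood disorder..."),
#         Document(page_content="QUESTION: How common is anxiety?\nANSWER: Very common..."),
#     ]
#     retriever = SmartFAQRetriever(docs, k=1)
#     hits = retriever.invoke("what exactly is depression")  # -> best-matching FAQ docs
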
def setup_retriever(use_kaggle_data: bool = False, kaggle_dataset: Optional[str] = None, 
                   kaggle_username: Optional[str] = None, kaggle_key: Optional[str] = None,
                   use_local_mental_health_data: bool = False) -> BaseRetriever:
    """
    Creates a vector store with documents from test data, Kaggle datasets, or local mental health data.
    
    Args:
        use_kaggle_data: Whether to load Kaggle data instead of test documents
        kaggle_dataset: Kaggle dataset name (e.g., 'username/dataset-name')
        kaggle_username: Your Kaggle username (optional if using kaggle.json)
        kaggle_key: Your Kaggle API key (optional if using kaggle.json)
        use_local_mental_health_data: Whether to load local mental health FAQ data
    """
    print("Setting up the retriever...")
    
    if use_local_mental_health_data:
        try:
            print("Loading mental health FAQ data from local file...")
            mental_health_file = "data/Mental_Health_FAQ.csv"
            
            if not os.path.exists(mental_health_file):
                print(f"Mental health FAQ file not found: {mental_health_file}")
                use_local_mental_health_data = False
            else:
                # Load mental health FAQ data
                df = pd.read_csv(mental_health_file)
                documents = []
                
                for _, row in df.iterrows():
                    question = row['Questions']
                    answer = row['Answers']
                    # Create document in FAQ format
                    content = f"QUESTION: {question}\nANSWER: {answer}"
                    documents.append(Document(page_content=content))
                
                print(f"Loaded {len(documents)} mental health FAQ documents")
                for i, doc in enumerate(documents[:3]):
                    print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")
                
        except Exception as e:
            print(f"Error loading mental health data: {e}")
            use_local_mental_health_data = False
    
    if use_kaggle_data and kaggle_dataset:
        try:
            from src.kaggle_loader import KaggleDataLoader
            
            print(f"Loading Kaggle dataset: {kaggle_dataset}")
            # Create loader without parameters - it will auto-load from kaggle.json
            loader = KaggleDataLoader()
            
            # Download the dataset
            dataset_path = loader.download_dataset(kaggle_dataset)
            
            # Load documents based on file type from this dataset's directory
            documents = []
            print(f"Processing files in dataset directory: {dataset_path}")
            
            for file in os.listdir(dataset_path):
                file_path = os.path.join(dataset_path, file)
                
                if file.endswith('.csv'):
                    print(f"Loading CSV file: {file}")
                    # For FAQ datasets, use the improved loading method
                    if 'faq' in file.lower() or 'mental' in file.lower():
                        documents.extend(loader.load_csv_dataset(file_path, [], chunk_size=50))
                    else:
                        # For other CSV files, use first few columns as text
                        df = pd.read_csv(file_path)
                        text_columns = df.columns[:3].tolist()  # Use first 3 columns
                        documents.extend(loader.load_csv_dataset(file_path, text_columns, chunk_size=50))
                
                elif file.endswith('.json'):
                    print(f"Loading JSON file: {file}")
                    documents.extend(loader.load_json_dataset(file_path))
                
                elif file.endswith('.txt'):
                    print(f"Loading text file: {file}")
                    documents.extend(loader.load_text_dataset(file_path))
            
            print(f"Loaded {len(documents)} documents from Kaggle dataset")
            for i, doc in enumerate(documents[:3]):
                print(f"Sample doc {i+1}: {doc.page_content[:200]}")
            
        except Exception as e:
            print(f"Error loading Kaggle data: {e}")
            print("Falling back to test documents...")
            use_kaggle_data = False
    
    if not use_kaggle_data and not use_local_mental_health_data:
        # Default path: use the local mental health FAQ data
        print("No data source specified, loading mental health FAQ data as default...")
        try:
            mental_health_file = "data/Mental_Health_FAQ.csv"
            
            if not os.path.exists(mental_health_file):
                raise FileNotFoundError(f"Mental health FAQ file not found: {mental_health_file}")
            
            # Load mental health FAQ data
            df = pd.read_csv(mental_health_file)
            documents = []
            
            for _, row in df.iterrows():
                question = row['Questions']
                answer = row['Answers']
                # Create document in FAQ format
                content = f"QUESTION: {question}\nANSWER: {answer}"
                documents.append(Document(page_content=content))
            
            print(f"Loaded {len(documents)} mental health FAQ documents")
            for i, doc in enumerate(documents[:3]):
                print(f"Sample FAQ {i+1}: {doc.page_content[:200]}...")
                
        except Exception as e:
            print(f"Error loading mental health data: {e}")
            raise Exception("No valid data source available. Please ensure mental health FAQ data is present or provide Kaggle credentials.")

    print("Creating TF-IDF embeddings...")
    embeddings = ImprovedTFIDFEmbeddings()

    print("Creating ChromaDB vector store...")
    client = chromadb.PersistentClient(path="./src/chroma_db")
    
    # Clear existing collections to prevent mixing old and new data
    try:
        collections = client.list_collections()
        for collection in collections:
            print(f"Deleting existing collection: {collection.name}")
            client.delete_collection(collection.name)
    except Exception as e:
        print(f"Warning: Could not clear existing collections: {e}")
    
    print(f"Processing {len(documents)} documents...")
    
    # Check if this is a FAQ dataset and use smart retriever
    if any("QUESTION:" in doc.page_content for doc in documents):
        print("Using SmartFAQRetriever for better semantic matching...")
        return SmartFAQRetriever(documents, k=4)
    else:
        # Use a Chroma vector store for non-FAQ datasets
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            client=client,
        )
        print("Retriever setup complete.")
        # as_retriever takes retrieval parameters via search_kwargs, not k directly
        return vectorstore.as_retriever(search_kwargs={"k": 4})
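
# Example invocations (illustrative; the dataset name is a placeholder):
#
#     retriever = setup_retriever()  # defaults to data/Mental_Health_FAQ.csv
#     retriever = setup_retriever(use_local_mental_health_data=True)
#     retriever = setup_retriever(use_kaggle_data=True,
#                                 kaggle_dataset="username/dataset-name")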

def setup_rag_chain() -> Runnable:
    """Sets up the RAG chain with a prompt template and an LLM."""
    # Define the prompt template for the LLM
    prompt = PromptTemplate(
        template="""You are an assistant for question-answering tasks.
Use the following documents to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise:
Question: {question}
Documents: {documents}
Answer:
""",
        input_variables=["question", "documents"],
    )

    # Initialize the LLM with dolphin-llama3:8b model
    # Note: This requires the Ollama server to be running with the specified model
    llm = ChatOllama(
        model="dolphin-llama3:8b",
        temperature=0,
    )

    # Create a chain combining the prompt template and LLM
    return prompt | llm | StrOutputParser()
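
# The chain is a plain Runnable, so it can also be invoked directly with a
# dict matching the prompt's input_variables (requires a running Ollama
# server with dolphin-llama3:8b pulled; the values below are illustrative):
#
#     chain = setup_rag_chain()
#     print(chain.invoke({
#         "question": "What is anxiety?",
#         "documents": "QUESTION: What is anxiety?\nANSWER: A feeling of worry...",
#     }))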


# Define the RAG application class
class RAGApplication:
    def __init__(self, retriever: BaseRetriever, rag_chain: Runnable):
        self.retriever = retriever
        self.rag_chain = rag_chain

    def run(self, question: str) -> str:
        """Runs the RAG pipeline for a given question."""
        # Retrieve relevant documents
        documents = self.retriever.invoke(question)
        
        # Debug: Print retrieved documents
        print(f"\nDEBUG: Retrieved {len(documents)} documents for question: '{question}'")
        for i, doc in enumerate(documents):
            print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
        
        # Extract content from retrieved documents
        doc_texts = "\n\n".join([doc.page_content for doc in documents])
        
        # Debug: Print the combined document text
        print(f"DEBUG: Combined document text: {doc_texts[:300]}...")
        
        # Get the answer from the language model
        answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
        return answer

# Main execution block
if __name__ == "__main__":
    load_dotenv()

    # 1. Setup the components
    retriever = setup_retriever()
    rag_chain = setup_rag_chain()

    # 2. Initialize the RAG application
    rag_application = RAGApplication(retriever, rag_chain)

    # 3. Run an example query
    question = "What is prompt engineering"
    print("\n--- Running RAG Application ---")
    print(f"Question: {question}")
    answer = rag_application.run(question)
    print(f"Answer: {answer}")