Update src/RAGSample.py
Browse files- src/RAGSample.py +93 -27
src/RAGSample.py
CHANGED
@@ -117,6 +117,38 @@ class SmartFAQRetriever(BaseRetriever):
|
|
117 |
@property
def k(self):
    """Number of top documents this retriever returns (read-only view of `_k`)."""
    return self._k
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
def _get_relevant_documents(self, query: str) -> List[Document]:
|
122 |
"""Retrieve documents based on semantic similarity."""
|
@@ -447,48 +479,82 @@ class RAGApplication:
|
|
447 |
self.retriever = retriever
|
448 |
self.rag_chain = rag_chain
|
449 |
|
450 |
-
def run(self, question: str) -> str:
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
|
457 |
-
|
458 |
-
|
459 |
|
460 |
-
|
461 |
-
|
462 |
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
|
468 |
-
|
469 |
-
|
470 |
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
|
477 |
-
|
478 |
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
|
484 |
-
|
485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
except Exception as e:
|
487 |
print(f"Error in RAG application: {str(e)}")
|
488 |
import traceback
|
489 |
traceback.print_exc()
|
490 |
return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
|
491 |
|
|
|
492 |
# Main execution block
|
493 |
if __name__ == "__main__":
|
494 |
load_dotenv()
|
|
|
117 |
@property
|
118 |
def k(self):
|
119 |
return self._k
|
120 |
+
|
121 |
+
def get_documents_with_confidence(self, query: str) -> List[dict]:
    """Return the top documents for *query* together with similarity scores.

    Each entry is a dict mapping "document" to the raw page content and
    "confidence" to the cosine-similarity score rounded to 3 decimals.
    """
    scored = self._get_relevant_documents_with_scores(query)
    payload = []
    for doc, score in scored:
        payload.append({
            "document": doc.page_content,
            "confidence": round(score, 3),
        })
    return payload
|
125 |
+
|
126 |
+
|
127 |
+
def _get_relevant_documents_with_scores(self, query: str) -> List[tuple[Document, float]]:
    """Retrieve documents along with TF-IDF cosine-similarity scores.

    Lazily (re)fits the TF-IDF vectorizer on the FAQ question texts the
    first time it is needed, then scores the query against every stored
    document's question portion.

    Returns up to `self._k` (document, score) pairs, best first, keeping
    only matches above a minimal relevance threshold.
    """
    # Extract the question portion of each FAQ entry exactly once; the
    # original computed this list twice (for fitting and again for
    # scoring), duplicating both work and logic.
    question_texts = [
        doc.page_content.split("ANSWER:")[0].replace("QUESTION:", "").strip()
        if "QUESTION:" in doc.page_content else doc.page_content
        for doc in self._documents
    ]

    # Lazily build/fit the vectorizer if it is missing or was never fitted.
    if not hasattr(self, '_vectorizer') or self._vectorizer is None or not hasattr(self._vectorizer, 'vocabulary_') or not self._vectorizer.vocabulary_:
        self._vectorizer = TfidfVectorizer(
            max_features=3000,
            stop_words='english',
            ngram_range=(1, 2),
            min_df=1,
            max_df=0.9
        )
        self._vectorizer.fit(question_texts)

    query_vector = self._vectorizer.transform([query.lower().strip()])
    question_vectors = self._vectorizer.transform(question_texts)
    similarities = cosine_similarity(query_vector, question_vectors).flatten()

    # Best-first indices of the k highest similarities.
    top_indices = similarities.argsort()[-self._k:][::-1]
    # 0.1 filters out near-irrelevant matches so callers never see noise.
    return [(self._documents[i], float(similarities[i])) for i in top_indices if similarities[i] > 0.1]
|
151 |
+
|
152 |
|
153 |
def _get_relevant_documents(self, query: str) -> List[Document]:
|
154 |
"""Retrieve documents based on semantic similarity."""
|
|
|
479 |
self.retriever = retriever
|
480 |
self.rag_chain = rag_chain
|
481 |
|
482 |
+
# def run(self, question: str) -> str:
|
483 |
+
# """Runs the RAG pipeline for a given question."""
|
484 |
+
# try:
|
485 |
+
# # Input validation
|
486 |
+
# if not question or not question.strip():
|
487 |
+
# return "Please provide a valid question."
|
488 |
|
489 |
+
# question = question.strip()
|
490 |
+
# print(f"\nProcessing question: '{question}'")
|
491 |
|
492 |
+
# # Retrieve relevant documents
|
493 |
+
# documents = self.retriever.invoke(question)
|
494 |
|
495 |
+
# # Debug: Print retrieved documents
|
496 |
+
# print(f"DEBUG: Retrieved {len(documents)} documents")
|
497 |
+
# for i, doc in enumerate(documents):
|
498 |
+
# print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
|
499 |
|
500 |
+
# # Extract content from retrieved documents
|
501 |
+
# doc_texts = "\n\n".join([doc.page_content for doc in documents])
|
502 |
|
503 |
+
# # Limit the total input length to prevent token overflow
|
504 |
+
# max_input_length = 500 # Conservative limit
|
505 |
+
# if len(doc_texts) > max_input_length:
|
506 |
+
# doc_texts = doc_texts[:max_input_length] + "..."
|
507 |
+
# print(f"DEBUG: Truncated document text to {max_input_length} characters")
|
508 |
|
509 |
+
# print(f"DEBUG: Combined document text length: {len(doc_texts)}")
|
510 |
|
511 |
+
# # Get the answer from the language model
|
512 |
+
# print("DEBUG: Calling language model...")
|
513 |
+
# answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
|
514 |
+
# print(f"DEBUG: Language model response: {answer}")
|
515 |
|
516 |
+
# return answer
|
517 |
|
518 |
+
# except Exception as e:
|
519 |
+
# print(f"Error in RAG application: {str(e)}")
|
520 |
+
# import traceback
|
521 |
+
# traceback.print_exc()
|
522 |
+
# return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
|
523 |
+
def run(self, question: str) -> str:
    """Run the RAG pipeline for *question* and return the model's answer.

    Validates the input, retrieves supporting documents (with confidence
    scores when the retriever supports them), truncates the combined
    context to avoid token overflow, and invokes the RAG chain.

    Never raises: any unexpected error is logged and converted into a
    user-facing apology string.
    """
    # Conservative cap on combined document text, to avoid token overflow.
    MAX_CONTEXT_CHARS = 500
    try:
        # Guard against None as well as empty/whitespace-only input; the
        # previous `question.strip()`-only check raised AttributeError on
        # None, which the broad except then masked as a generic apology.
        if not question or not question.strip():
            return "Please provide a valid question."
        question = question.strip()

        print(f"\nProcessing question: '{question}'")

        # Prefer confidence-aware retrieval when the retriever offers it.
        if hasattr(self.retriever, "get_documents_with_confidence"):
            docs_with_scores = self.retriever.get_documents_with_confidence(question)
            documents = [Document(page_content=d["document"]) for d in docs_with_scores]
            confidence_info = "\n".join([f"- Score: {d['confidence']}, Snippet: {d['document'][:100]}..." for d in docs_with_scores])
        else:
            documents = self.retriever.invoke(question)
            confidence_info = "Confidence scoring not available."

        print(f"Retrieved {len(documents)} documents")
        print(confidence_info)

        doc_texts = "\n\n".join([doc.page_content for doc in documents])
        if len(doc_texts) > MAX_CONTEXT_CHARS:
            doc_texts = doc_texts[:MAX_CONTEXT_CHARS] + "..."

        answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})

        # Append confidence footer
        footer = "\n\n(Note: This answer is based on documents with confidence scores. Review full context if critical.)"
        return answer.strip() + footer

    except Exception as e:
        print(f"Error in RAG application: {str(e)}")
        import traceback
        traceback.print_exc()
        return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
|
556 |
|
557 |
+
|
558 |
# Main execution block
|
559 |
if __name__ == "__main__":
|
560 |
load_dotenv()
|