brendon-ai committed on
Commit
43d9f2a
·
verified ·
1 Parent(s): 2cd1b80

Update src/RAGSample.py

Browse files
Files changed (1) hide show
  1. src/RAGSample.py +93 -27
src/RAGSample.py CHANGED
@@ -117,6 +117,38 @@ class SmartFAQRetriever(BaseRetriever):
117
  @property
118
  def k(self):
119
  return self._k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def _get_relevant_documents(self, query: str) -> List[Document]:
122
  """Retrieve documents based on semantic similarity."""
@@ -447,48 +479,82 @@ class RAGApplication:
447
  self.retriever = retriever
448
  self.rag_chain = rag_chain
449
 
450
- def run(self, question: str) -> str:
451
- """Runs the RAG pipeline for a given question."""
452
- try:
453
- # Input validation
454
- if not question or not question.strip():
455
- return "Please provide a valid question."
456
 
457
- question = question.strip()
458
- print(f"\nProcessing question: '{question}'")
459
 
460
- # Retrieve relevant documents
461
- documents = self.retriever.invoke(question)
462
 
463
- # Debug: Print retrieved documents
464
- print(f"DEBUG: Retrieved {len(documents)} documents")
465
- for i, doc in enumerate(documents):
466
- print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
467
 
468
- # Extract content from retrieved documents
469
- doc_texts = "\n\n".join([doc.page_content for doc in documents])
470
 
471
- # Limit the total input length to prevent token overflow
472
- max_input_length = 500 # Conservative limit
473
- if len(doc_texts) > max_input_length:
474
- doc_texts = doc_texts[:max_input_length] + "..."
475
- print(f"DEBUG: Truncated document text to {max_input_length} characters")
476
 
477
- print(f"DEBUG: Combined document text length: {len(doc_texts)}")
478
 
479
- # Get the answer from the language model
480
- print("DEBUG: Calling language model...")
481
- answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
482
- print(f"DEBUG: Language model response: {answer}")
483
 
484
- return answer
485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  except Exception as e:
487
  print(f"Error in RAG application: {str(e)}")
488
  import traceback
489
  traceback.print_exc()
490
  return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
491
 
 
492
  # Main execution block
493
  if __name__ == "__main__":
494
  load_dotenv()
 
117
  @property
118
  def k(self):
119
  return self._k
120
+
121
def get_documents_with_confidence(self, query: str) -> List[dict]:
    """Return the best-matching documents for *query* with their scores.

    Delegates the ranking to ``_get_relevant_documents_with_scores`` and
    reshapes each ``(document, score)`` pair into a plain dictionary.

    Args:
        query: The user's question text.

    Returns:
        One dict per retrieved document, in the order produced by the
        scorer, with keys ``"document"`` (the document's ``page_content``)
        and ``"confidence"`` (the similarity score rounded to 3 decimals).
    """
    scored_entries = []
    for matched_doc, similarity in self._get_relevant_documents_with_scores(query):
        entry = {
            "document": matched_doc.page_content,
            "confidence": round(similarity, 3),
        }
        scored_entries.append(entry)
    return scored_entries
125
+
126
+
127
def _get_relevant_documents_with_scores(self, query: str) -> List[tuple[Document, float]]:
    """Retrieve the top-k documents together with their similarity scores.

    Each stored document is reduced to its question text: when the content
    contains ``"QUESTION:"``, only the part before ``"ANSWER:"`` is kept
    (with the ``"QUESTION:"`` label removed); otherwise the full content is
    used. The query and the question texts are turned into TF-IDF vectors
    and compared with cosine similarity.

    Args:
        query: The user's question text; it is lower-cased and stripped
            before vectorization.

    Returns:
        Up to ``self._k`` ``(document, score)`` pairs ordered from most to
        least similar. Pairs whose score is not above 0.1 are dropped. An
        empty list is returned when there are no documents or ``self._k``
        is not positive.
    """
    # Guard explicitly: with k == 0 the slice ``[-0:]`` would select *every*
    # index, and fitting TF-IDF on an empty corpus raises.
    if not self._documents or self._k <= 0:
        return []

    def question_text(doc):
        # Match only against the question part of a "QUESTION: ... ANSWER: ..." entry.
        content = doc.page_content
        if "QUESTION:" in content:
            return content.split("ANSWER:")[0].replace("QUESTION:", "").strip()
        return content

    # Computed once and reused for both fitting and transforming.
    question_texts = [question_text(doc) for doc in self._documents]

    # (Re)build the vectorizer lazily when it is missing or was never fitted.
    vectorizer = getattr(self, '_vectorizer', None)
    if vectorizer is None or not getattr(vectorizer, 'vocabulary_', None):
        self._vectorizer = TfidfVectorizer(
            max_features=3000,
            stop_words='english',
            ngram_range=(1, 2),
            min_df=1,
            max_df=0.9
        )
        self._vectorizer.fit(question_texts)

    query_vector = self._vectorizer.transform([query.lower().strip()])
    question_vectors = self._vectorizer.transform(question_texts)
    similarities = cosine_similarity(query_vector, question_vectors).flatten()

    # argsort is ascending: take the last k indices and reverse for best-first.
    top_indices = similarities.argsort()[-self._k:][::-1]
    return [(self._documents[i], float(similarities[i])) for i in top_indices if similarities[i] > 0.1]
151
+
152
 
153
  def _get_relevant_documents(self, query: str) -> List[Document]:
154
  """Retrieve documents based on semantic similarity."""
 
479
  self.retriever = retriever
480
  self.rag_chain = rag_chain
481
 
482
+ # def run(self, question: str) -> str:
483
+ # """Runs the RAG pipeline for a given question."""
484
+ # try:
485
+ # # Input validation
486
+ # if not question or not question.strip():
487
+ # return "Please provide a valid question."
488
 
489
+ # question = question.strip()
490
+ # print(f"\nProcessing question: '{question}'")
491
 
492
+ # # Retrieve relevant documents
493
+ # documents = self.retriever.invoke(question)
494
 
495
+ # # Debug: Print retrieved documents
496
+ # print(f"DEBUG: Retrieved {len(documents)} documents")
497
+ # for i, doc in enumerate(documents):
498
+ # print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
499
 
500
+ # # Extract content from retrieved documents
501
+ # doc_texts = "\n\n".join([doc.page_content for doc in documents])
502
 
503
+ # # Limit the total input length to prevent token overflow
504
+ # max_input_length = 500 # Conservative limit
505
+ # if len(doc_texts) > max_input_length:
506
+ # doc_texts = doc_texts[:max_input_length] + "..."
507
+ # print(f"DEBUG: Truncated document text to {max_input_length} characters")
508
 
509
+ # print(f"DEBUG: Combined document text length: {len(doc_texts)}")
510
 
511
+ # # Get the answer from the language model
512
+ # print("DEBUG: Calling language model...")
513
+ # answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
514
+ # print(f"DEBUG: Language model response: {answer}")
515
 
516
+ # return answer
517
 
518
+ # except Exception as e:
519
+ # print(f"Error in RAG application: {str(e)}")
520
+ # import traceback
521
+ # traceback.print_exc()
522
+ # return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
523
def run(self, question: str) -> str:
    """Run the RAG pipeline for a question and return the answer text.

    Retrieves documents through ``self.retriever`` (using
    ``get_documents_with_confidence`` when the retriever offers it, else
    ``invoke``), joins their contents, truncates the result to 500
    characters, and passes it with the question to ``self.rag_chain``.

    Args:
        question: The user's question. ``None``, empty, or whitespace-only
            input is rejected with a prompt for a valid question.

    Returns:
        The stripped chain answer followed by a fixed confidence note, the
        validation message for blank input, or an apology message that
        includes the error text if any step raises.
    """
    try:
        # Check for None/empty before calling .strip(): a None question would
        # otherwise raise and surface as a generic processing error.
        if not question or not question.strip():
            return "Please provide a valid question."

        # Use the trimmed question for retrieval and prompting.
        question = question.strip()
        print(f"\nProcessing question: '{question}'")

        if hasattr(self.retriever, "get_documents_with_confidence"):
            docs_with_scores = self.retriever.get_documents_with_confidence(question)
            documents = [Document(page_content=d["document"]) for d in docs_with_scores]
            confidence_info = "\n".join([f"- Score: {d['confidence']}, Snippet: {d['document'][:100]}..." for d in docs_with_scores])
        else:
            documents = self.retriever.invoke(question)
            confidence_info = "Confidence scoring not available."

        print(f"Retrieved {len(documents)} documents")
        print(confidence_info)

        # Cap the combined context length to limit the size of the model input.
        max_input_length = 500
        doc_texts = "\n\n".join([doc.page_content for doc in documents])
        if len(doc_texts) > max_input_length:
            doc_texts = doc_texts[:max_input_length] + "..."

        answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})

        # Append confidence footer
        footer = "\n\n(Note: This answer is based on documents with confidence scores. Review full context if critical.)"
        return answer.strip() + footer

    except Exception as e:
        # Top-level boundary: log the failure and return a user-facing message.
        print(f"Error in RAG application: {str(e)}")
        import traceback
        traceback.print_exc()
        return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
556
 
557
+
558
  # Main execution block
559
  if __name__ == "__main__":
560
  load_dotenv()