import streamlit as st
import logging
import os
from dotenv import load_dotenv
from datasets import load_dataset
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_qdrant import Qdrant
from langchain_aws import ChatBedrock, BedrockEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from qdrant_client import QdrantClient
import re
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def load_environment():
"""Load and validate environment variables."""
try:
load_dotenv()
required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
logger.error(f"Missing environment variables: {missing_vars}")
st.error(f"Missing environment variables: {missing_vars}")
raise ValueError(f"Missing environment variables: {missing_vars}")
logger.info("Environment variables loaded successfully")
except Exception as e:
logger.error(f"Error loading environment variables: {e}")
st.error(f"Error loading environment variables: {e}")
raise
@st.cache_resource
def load_wikipedia_documents():
"""Load 100 Wikipedia documents from Cohere's HF dataset."""
try:
dataset = load_dataset(
"Cohere/wikipedia-22-12-simple-embeddings",
split="train[:100]" # Load only 100 entries
)
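        # Only the article text is used here; any precomputed embeddings the
        # dataset ships with are ignored, since we re-embed with Bedrock below.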
documents = [Document(page_content=item["text"]) for item in dataset]
logger.info(f"Loaded {len(documents)} Wikipedia documents")
if not documents:
logger.error("No documents loaded from dataset")
st.error("No documents loaded from dataset")
return []
return documents
except Exception as e:
logger.error(f"Error loading dataset: {e}")
st.error(f"Failed to load dataset: {e}")
return []
@st.cache_resource
def split_documents(_documents):
"""Split documents into chunks."""
try:
if not _documents:
logger.error("No documents provided for splitting")
st.error("No documents provided for splitting")
return []
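        # ~500-character chunks with 50-character overlap so sentences that
        # straddle a chunk boundary keep some surrounding context.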
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = splitter.split_documents(_documents)
logger.info(f"Split into {len(chunks)} chunks")
if not chunks:
logger.error("No chunks created from documents")
st.error("No chunks created from documents")
return []
return chunks
except Exception as e:
logger.error(f"Error splitting documents: {e}")
st.error(f"Failed to split documents: {e}")
return []
@st.cache_resource
def initialize_embeddings():
"""Initialize AWS Bedrock embeddings."""
try:
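        # amazon.titan-embed-text-v1 returns 1536-dimensional vectors.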
embeddings = BedrockEmbeddings(
model_id="amazon.titan-embed-text-v1",
region_name=os.getenv("AWS_REGION")
)
logger.info("Initialized Bedrock embeddings")
return embeddings
except Exception as e:
logger.error(f"Error initializing embeddings: {e}")
st.error(f"Failed to initialize embeddings: {e}")
return None
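# Cached like the loaders above; the leading underscores on the parameters
# tell Streamlit's cache not to try hashing the chunk list or the embeddings
# client, so the (destructive) rebuild below runs once per session instead of
# on every widget interaction.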
@st.cache_resource
def store_in_qdrant(_chunks, _embeddings):
"""Store document chunks in a hosted Qdrant instance after deleting all collections."""
try:
# Initialize Qdrant client
client = QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
timeout=30
)
# Test Qdrant connection
try:
client.get_collections()
logger.info("Successfully connected to Qdrant at %s", os.getenv("QDRANT_URL"))
except Exception as e:
logger.error("Failed to connect to Qdrant: %s", e)
st.error(f"Failed to connect to Qdrant: {e}")
return None
# Delete all existing collections
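        # Caution: this wipes every collection on the instance, not just this
        # app's; `force_recreate=True` below would suffice for a single one.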
try:
collections = client.get_collections().collections
for collection in collections:
client.delete_collection(collection.name)
logger.info(f"Deleted Qdrant collection: {collection.name}")
logger.info("All Qdrant collections deleted")
except Exception as e:
logger.warning(f"Error deleting collections: {e}")
st.warning(f"Error deleting collections: {e}")
# Validate input chunks
if not _chunks:
logger.error("No chunks provided for Qdrant storage")
st.error("No chunks provided for Qdrant storage")
return None
# Create and populate new collection
collection_name = "wikipedia_chunks"
try:
vector_store = Qdrant.from_documents(
documents=_chunks,
embedding=_embeddings,
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
collection_name=collection_name,
force_recreate=True # Ensure fresh collection
)
logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks")
except Exception as e:
logger.error(f"Error creating Qdrant collection: {e}")
st.error(f"Failed to create Qdrant collection: {e}")
return None
# Verify storage
try:
collection_info = client.get_collection(collection_name)
stored_points = collection_info.points_count
logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}")
if stored_points == 0:
logger.error("No documents stored in Qdrant collection")
st.error("No documents stored in Qdrant collection")
return None
if stored_points != len(_chunks):
logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
return vector_store
except Exception as e:
logger.error(f"Error verifying Qdrant storage: {e}")
st.error(f"Failed to verify Qdrant storage: {e}")
return None
except Exception as e:
logger.error(f"Error in Qdrant storage process: {e}")
st.error(f"Failed to store documents in Qdrant: {e}")
return None
@st.cache_resource
def initialize_llm():
"""Initialize AWS Bedrock Claude 3.5 Sonnet model."""
try:
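        # Cap responses at 1,000 tokens; the rerank calls below only need the
        # model to emit a single number.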
llm = ChatBedrock(
model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
region_name=os.getenv("AWS_REGION"),
model_kwargs={"max_tokens": 1000}
)
logger.info("Initialized Claude 3.5 Sonnet")
return llm
except Exception as e:
logger.error(f"Error initializing LLM: {e}")
st.error(f"Failed to initialize LLM: {e}")
return None
def extract_score_from_text(text):
"""Extract the first float number between 0 and 1 from the text using regex."""
try:
matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text)
if not matches:
logger.warning("No score found in text")
return None
score = float(matches[0])
if 0.0 <= score <= 1.0:
return score
logger.warning(f"Score {score} out of expected range 0-1")
return None
except ValueError as e:
logger.warning(f"Cannot convert match to float: {e}")
return None
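# Illustrative behavior on hypothetical LLM replies:
#   extract_score_from_text("0.8")              -> 0.8
#   extract_score_from_text("Relevance: 0.35")  -> 0.35
#   extract_score_from_text("highly relevant")  -> None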
def claude_rerank(docs, query, llm, top_n=5):
"""Rerank documents based on relevance using the LLM."""
try:
rerank_prompt = ChatPromptTemplate.from_template(
"""
Given the query: "{query}" and the document chunk: "{chunk}", please rate
the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).
Respond with a number only, like: 0.8
"""
)
scored_docs = []
for idx, doc in enumerate(docs):
prompt = rerank_prompt.format(query=query, chunk=doc.page_content)
response = llm.invoke(prompt)
text = response.content.strip()
logger.info(f"Doc {idx} rerank raw output: {text}")
score = extract_score_from_text(text)
if score is None:
logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
score = 0.0
scored_docs.append((doc, score))
scored_docs.sort(key=lambda x: x[1], reverse=True)
logger.info(f"Reranked top {top_n} docs based on scores")
return [doc for doc, _ in scored_docs[:top_n]]
except Exception as e:
logger.error(f"Error in reranking: {e}")
st.error(f"Error in reranking: {e}")
return docs[:top_n] # Fallback to original docs
def create_rag_chain(vector_store, llm, use_rerank=False):
"""Create a RAG chain with or without reranking."""
try:
        prompt_template = ChatPromptTemplate.from_template(
            """You are a helpful assistant. Use the following context to answer the question concisely.

Context:
{context}

Question: {question}

Answer:"""
        )
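        # Over-retrieve (k=20) when reranking so the LLM can choose the best 5;
        # the baseline chain just takes the top 5 by vector similarity.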
retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})
def rerank_context(inputs):
try:
docs = retriever.invoke(inputs["question"])
if not docs:
logger.warning("No documents retrieved for query")
return {"context": "", "question": inputs["question"]}
if use_rerank:
docs = claude_rerank(docs, inputs["question"], llm)
return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
except Exception as e:
logger.error(f"Error in rerank_context: {e}")
return {"context": "", "question": inputs["question"]}
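        # LCEL coerces the plain rerank_context function into a RunnableLambda
        # when it is piped into the prompt template.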
chain = rerank_context | prompt_template | llm | StrOutputParser()
logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
return chain
except Exception as e:
logger.error(f"Error creating RAG chain: {e}")
st.error(f"Failed to create RAG chain: {e}")
return None
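# Example usage (hypothetical question):
#   chain = create_rag_chain(vector_store, llm, use_rerank=True)
#   answer = chain.invoke({"question": "Who wrote Hamlet?"})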
def main():
st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
st.write("Enter a question to get answers using baseline and reranked retrieval methods.")
# Load environment variables
try:
load_environment()
    except Exception:
return
# Initialize components
documents = load_wikipedia_documents()
if not documents:
st.error("Cannot proceed without documents")
return
chunks = split_documents(documents)
if not chunks:
st.error("Cannot proceed without document chunks")
return
embeddings = initialize_embeddings()
if embeddings is None:
st.error("Cannot proceed without embeddings")
return
vector_store = store_in_qdrant(chunks, embeddings)
if vector_store is None:
st.error("Cannot proceed without vector store")
return
llm = initialize_llm()
if llm is None:
st.error("Cannot proceed without LLM")
return
baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
if baseline_chain is None:
st.error("Cannot proceed without baseline chain")
return
rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
if rerank_chain is None:
st.error("Cannot proceed without rerank chain")
return
# Streamlit input
query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")
if query:
with st.spinner("Processing your query..."):
try:
baseline_response = baseline_chain.invoke({"question": query})
rerank_response = rerank_chain.invoke({"question": query})
st.subheader("Results")
st.write("**Query:**", query)
st.write("**Baseline Answer:**")
st.write(baseline_response)
st.write("**Reranked Answer:**")
st.write(rerank_response)
except Exception as e:
logger.error(f"Error processing query: {e}")
st.error(f"Error processing query: {e}")
if __name__ == "__main__":
main()
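
# To run locally (assuming a .env file providing AWS_ACCESS_KEY_ID,
# AWS_SECRET_ACCESS_KEY, AWS_REGION, QDRANT_URL, and QDRANT_API_KEY):
#   streamlit run app.py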