from huggingface_hub import InferenceClient import streamlit as st import logging import os from dotenv import load_dotenv from datasets import load_dataset from langchain_core.documents import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import BedrockEmbeddings from langchain_qdrant import Qdrant from langchain_aws import ChatBedrock from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnablePassthrough from langchain_core.output_parsers import StrOutputParser from qdrant_client import QdrantClient from qdrant_client.models import Distance, VectorParams import re import json from urllib.error import URLError # Set up logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def load_environment(): """Load and validate environment variables.""" try: load_dotenv() required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY'] missing_vars = [var for var in required_vars if not os.getenv(var)] if missing_vars: logger.error(f"Missing environment variables: {missing_vars}") st.error(f"Missing environment variables: {missing_vars}") raise ValueError(f"Missing environment variables: {missing_vars}") logger.info("Environment variables loaded successfully") except Exception as e: logger.error(f"Error loading environment variables: {e}") st.error(f"Error loading environment variables: {e}") raise @st.cache_resource def load_wikipedia_documents(): """Load 100 Wikipedia documents from Cohere's HF dataset.""" try: dataset = load_dataset( "Cohere/wikipedia-22-12-simple-embeddings", split="train[:100]" # Load only 100 entries ) documents = [Document(page_content=item["text"]) for item in dataset] logger.info(f"Loaded {len(documents)} Wikipedia documents") if not documents: logger.error("No documents loaded from dataset") st.error("No documents loaded from dataset") return [] return documents except Exception as e: logger.error(f"Error loading dataset: {e}") st.error(f"Failed to load dataset: {e}") return [] @st.cache_resource def split_documents(_documents): """Split documents into chunks.""" try: if not _documents: logger.error("No documents provided for splitting") st.error("No documents provided for splitting") return [] splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) chunks = splitter.split_documents(_documents) logger.info(f"Split into {len(chunks)} chunks") if not chunks: logger.error("No chunks created from documents") st.error("No chunks created from documents") return [] return chunks except Exception as e: logger.error(f"Error splitting documents: {e}") st.error(f"Failed to split documents: {e}") return [] @st.cache_resource def initialize_embeddings(): """Initialize AWS Bedrock embeddings.""" try: embeddings = BedrockEmbeddings( model_id="amazon.titan-embed-text-v1", region_name=os.getenv("AWS_REGION") ) logger.info("Initialized Bedrock embeddings") return embeddings except Exception as e: logger.error(f"Error initializing embeddings: {e}") st.error(f"Failed to initialize embeddings: {e}") return None def store_in_qdrant(_chunks, _embeddings): """Store document chunks in a hosted Qdrant instance after deleting all collections.""" try: # Initialize Qdrant client client = QdrantClient( url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"), timeout=30 ) # Test Qdrant connection try: client.get_collections() logger.info("Successfully connected to Qdrant at %s", os.getenv("QDRANT_URL")) except Exception as e: logger.error("Failed to connect to Qdrant: %s", e) st.error(f"Failed to connect to Qdrant: {e}") return None # Delete all existing collections try: collections = client.get_collections().collections for collection in collections: client.delete_collection(collection.name) logger.info(f"Deleted Qdrant collection: {collection.name}") logger.info("All Qdrant collections deleted") except Exception as e: logger.warning(f"Error deleting collections: {e}") st.warning(f"Error deleting collections: {e}") # Validate input chunks if not _chunks: logger.error("No chunks provided for Qdrant storage") st.error("No chunks provided for Qdrant storage") return None # Create and populate new collection collection_name = "wikipedia_chunks" try: vector_store = Qdrant.from_documents( documents=_chunks, embedding=_embeddings, url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"), collection_name=collection_name, force_recreate=True # Ensure fresh collection ) logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks") except Exception as e: logger.error(f"Error creating Qdrant collection: {e}") st.error(f"Failed to create Qdrant collection: {e}") return None # Verify storage try: collection_info = client.get_collection(collection_name) stored_points = collection_info.points_count logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}") if stored_points == 0: logger.error("No documents stored in Qdrant collection") st.error("No documents stored in Qdrant collection") return None if stored_points != len(_chunks): logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant") st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant") return vector_store except Exception as e: logger.error(f"Error verifying Qdrant storage: {e}") st.error(f"Failed to verify Qdrant storage: {e}") return None except Exception as e: logger.error(f"Error in Qdrant storage process: {e}") st.error(f"Failed to store documents in Qdrant: {e}") return None @st.cache_resource def initialize_llm(): """Initialize AWS Bedrock Claude 3.5 Sonnet model.""" try: llm = ChatBedrock( model_id="anthropic.claude-3-5-sonnet-20240620-v1:0", region_name=os.getenv("AWS_REGION"), model_kwargs={"max_tokens": 1000} ) logger.info("Initialized Claude 3.5 Sonnet") return llm except Exception as e: logger.error(f"Error initializing LLM: {e}") st.error(f"Failed to initialize LLM: {e}") return None def extract_score_from_text(text): """Extract the first float number between 0 and 1 from the text using regex.""" try: matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text) if not matches: logger.warning("No score found in text") return None score = float(matches[0]) if 0.0 <= score <= 1.0: return score logger.warning(f"Score {score} out of expected range 0-1") return None except ValueError as e: logger.warning(f"Cannot convert match to float: {e}") return None def claude_rerank(docs, query, llm, top_n=5): """Rerank documents based on relevance using the LLM.""" try: rerank_prompt = ChatPromptTemplate.from_template( """ Given the query: "{query}" and the document chunk: "{chunk}", please rate the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant). Respond with a number only, like: 0.8 """ ) scored_docs = [] for idx, doc in enumerate(docs): prompt = rerank_prompt.format(query=query, chunk=doc.page_content) response = llm.invoke(prompt) text = response.content.strip() logger.info(f"Doc {idx} rerank raw output: {text}") score = extract_score_from_text(text) if score is None: logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.") score = 0.0 scored_docs.append((doc, score)) scored_docs.sort(key=lambda x: x[1], reverse=True) logger.info(f"Reranked top {top_n} docs based on scores") return [doc for doc, _ in scored_docs[:top_n]] except Exception as e: logger.error(f"Error in reranking: {e}") st.error(f"Error in reranking: {e}") return docs[:top_n] # Fallback to original docs def create_rag_chain(vector_store, llm, use_rerank=False): """Create a RAG chain with or without reranking.""" try: prompt_template = ChatPromptTemplate.from_template( """You are a helpful assistant. Use the following context to answer the question concisely.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:""" ) retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5}) def rerank_context(inputs): try: docs = retriever.invoke(inputs["question"]) if not docs: logger.warning("No documents retrieved for query") return {"context": "", "question": inputs["question"]} if use_rerank: docs = claude_rerank(docs, inputs["question"], llm) return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]} except Exception as e: logger.error(f"Error in rerank_context: {e}") return {"context": "", "question": inputs["question"]} chain = rerank_context | prompt_template | llm | StrOutputParser() logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain") return chain except Exception as e: logger.error(f"Error creating RAG chain: {e}") st.error(f"Failed to create RAG chain: {e}") return None def main(): st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)") st.write("Enter a question to get answers using baseline and reranked retrieval methods.") # Load environment variables try: load_environment() except ValueError: return # Initialize components documents = load_wikipedia_documents() if not documents: st.error("Cannot proceed without documents") return chunks = split_documents(documents) if not chunks: st.error("Cannot proceed without document chunks") return embeddings = initialize_embeddings() if embeddings is None: st.error("Cannot proceed without embeddings") return vector_store = store_in_qdrant(chunks, embeddings) if vector_store is None: st.error("Cannot proceed without vector store") return llm = initialize_llm() if llm is None: st.error("Cannot proceed without LLM") return baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False) if baseline_chain is None: st.error("Cannot proceed without baseline chain") return rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True) if rerank_chain is None: st.error("Cannot proceed without rerank chain") return # Streamlit input query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?") if query: with st.spinner("Processing your query..."): try: baseline_response = baseline_chain.invoke({"question": query}) rerank_response = rerank_chain.invoke({"question": query}) st.subheader("Results") st.write("**Query:**", query) st.write("**Baseline Answer:**") st.write(baseline_response) st.write("**Reranked Answer:**") st.write(rerank_response) except Exception as e: logger.error(f"Error processing query: {e}") st.error(f"Error processing query: {e}") if __name__ == "__main__": main()