"""Streamlit RAG demo: answers questions over Wikipedia chunks stored in Qdrant,
using AWS Bedrock Titan embeddings and Claude 3.5 Sonnet, and compares baseline
retrieval against LLM-reranked retrieval."""

import logging
import os
import re

import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
# langchain-aws supersedes the deprecated Bedrock classes in langchain_community.
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_environment():
    """Load and validate environment variables."""
    try:
        load_dotenv()
        required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            logger.error(f"Missing environment variables: {missing_vars}")
            st.error(f"Missing environment variables: {missing_vars}")
            raise ValueError(f"Missing environment variables: {missing_vars}")
        logger.info("Environment variables loaded successfully")
    except Exception as e:
        logger.error(f"Error loading environment variables: {e}")
        st.error(f"Error loading environment variables: {e}")
        raise

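# A sample .env for local runs (placeholder values are hypothetical; the
# variable names are exactly the ones load_environment() validates):
#
#   AWS_ACCESS_KEY_ID=AKIA...
#   AWS_SECRET_ACCESS_KEY=...
#   AWS_REGION=us-east-1
#   QDRANT_URL=https://your-cluster.qdrant.io:6333
#   QDRANT_API_KEY=...
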
@st.cache_resource
def load_wikipedia_documents():
    """Load 100 Wikipedia documents from Cohere's HF dataset."""
    try:
        dataset = load_dataset(
            "Cohere/wikipedia-22-12-simple-embeddings",
            split="train[:100]"
        )
        documents = [Document(page_content=item["text"]) for item in dataset]
        logger.info(f"Loaded {len(documents)} Wikipedia documents")
        if not documents:
            logger.error("No documents loaded from dataset")
            st.error("No documents loaded from dataset")
            return []
        return documents
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        st.error(f"Failed to load dataset: {e}")
        return []

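# Note: each dataset row is a single Simple English Wikipedia paragraph; per
# the dataset card it also carries a precomputed Cohere embedding ("emb"
# field), which this app ignores so that query and document vectors both come
# from the same Titan model.
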
@st.cache_resource
def split_documents(_documents):
    """Split documents into chunks."""
    try:
        if not _documents:
            logger.error("No documents provided for splitting")
            st.error("No documents provided for splitting")
            return []
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_documents(_documents)
        logger.info(f"Split into {len(chunks)} chunks")
        if not chunks:
            logger.error("No chunks created from documents")
            st.error("No chunks created from documents")
            return []
        return chunks
    except Exception as e:
        logger.error(f"Error splitting documents: {e}")
        st.error(f"Failed to split documents: {e}")
        return []

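# Streamlit convention: the leading underscore on `_documents` (and on the
# other cached helpers' parameters) tells st.cache_resource to skip hashing
# that argument when computing the cache key, since Document objects are not
# reliably hashable.
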
@st.cache_resource
def initialize_embeddings():
    """Initialize AWS Bedrock embeddings."""
    try:
        embeddings = BedrockEmbeddings(
            model_id="amazon.titan-embed-text-v1",
            region_name=os.getenv("AWS_REGION")
        )
        logger.info("Initialized Bedrock embeddings")
        return embeddings
    except Exception as e:
        logger.error(f"Error initializing embeddings: {e}")
        st.error(f"Failed to initialize embeddings: {e}")
        return None

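# amazon.titan-embed-text-v1 (Titan Embeddings G1 - Text) produces 1536-dim
# vectors; Qdrant.from_documents below sizes its collection from the first
# computed embedding, so the dimension never needs to be spelled out here.
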
# Cached like the other heavy setup steps: without this, every Streamlit rerun
# (i.e., every submitted question) would wipe the Qdrant instance and re-embed
# all chunks. The underscore-prefixed parameters already follow the
# st.cache_resource convention for unhashable arguments.
@st.cache_resource
def store_in_qdrant(_chunks, _embeddings):
    """Store document chunks in a hosted Qdrant instance after deleting all collections."""
    try:
        client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            timeout=30
        )

        # Verify connectivity before doing anything destructive.
        try:
            client.get_collections()
            logger.info("Successfully connected to Qdrant at %s", os.getenv("QDRANT_URL"))
        except Exception as e:
            logger.error("Failed to connect to Qdrant: %s", e)
            st.error(f"Failed to connect to Qdrant: {e}")
            return None

        # Destructive reset: drop every existing collection in the instance.
        try:
            collections = client.get_collections().collections
            for collection in collections:
                client.delete_collection(collection.name)
                logger.info(f"Deleted Qdrant collection: {collection.name}")
            logger.info("All Qdrant collections deleted")
        except Exception as e:
            logger.warning(f"Error deleting collections: {e}")
            st.warning(f"Error deleting collections: {e}")

        if not _chunks:
            logger.error("No chunks provided for Qdrant storage")
            st.error("No chunks provided for Qdrant storage")
            return None

        collection_name = "wikipedia_chunks"
        try:
            vector_store = Qdrant.from_documents(
                documents=_chunks,
                embedding=_embeddings,
                url=os.getenv("QDRANT_URL"),
                api_key=os.getenv("QDRANT_API_KEY"),
                collection_name=collection_name,
                force_recreate=True
            )
            logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks")
        except Exception as e:
            logger.error(f"Error creating Qdrant collection: {e}")
            st.error(f"Failed to create Qdrant collection: {e}")
            return None

        # Sanity-check that the point count matches the number of chunks sent.
        try:
            collection_info = client.get_collection(collection_name)
            stored_points = collection_info.points_count
            logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}")
            if stored_points == 0:
                logger.error("No documents stored in Qdrant collection")
                st.error("No documents stored in Qdrant collection")
                return None
            if stored_points != len(_chunks):
                logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
                st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
            return vector_store
        except Exception as e:
            logger.error(f"Error verifying Qdrant storage: {e}")
            st.error(f"Failed to verify Qdrant storage: {e}")
            return None

    except Exception as e:
        logger.error(f"Error in Qdrant storage process: {e}")
        st.error(f"Failed to store documents in Qdrant: {e}")
        return None

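# A minimal sketch of how later runs could attach to the existing collection
# instead of wiping and re-embedding everything (hypothetical helper; the
# keyword-argument form follows the legacy langchain_qdrant Qdrant constructor):
#
#   def reuse_existing_collection(client, embeddings):
#       return Qdrant(
#           client=client,
#           collection_name="wikipedia_chunks",
#           embeddings=embeddings,
#       )
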
@st.cache_resource
def initialize_llm():
    """Initialize AWS Bedrock Claude 3.5 Sonnet model."""
    try:
        llm = ChatBedrock(
            model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
            region_name=os.getenv("AWS_REGION"),
            model_kwargs={"max_tokens": 1000}
        )
        logger.info("Initialized Claude 3.5 Sonnet")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        st.error(f"Failed to initialize LLM: {e}")
        return None

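# max_tokens=1000 caps every completion from this client; the same instance
# serves both the final answers and the one-number rerank scores, so the cap
# is generous for the latter.
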
def extract_score_from_text(text):
    """Extract the first float between 0 and 1 from the text using a regex."""
    try:
        # Match 0, 1, 0.x, or 1.0...; the trailing lookahead stops e.g. "1.5"
        # from being truncated to a spurious "1".
        matches = re.findall(r'\b(?:0(?:\.\d+)?|1(?:\.0+)?)\b(?!\.\d)', text)
        if not matches:
            logger.warning("No score found in text")
            return None
        score = float(matches[0])
        if 0.0 <= score <= 1.0:
            return score
        logger.warning(f"Score {score} out of expected range 0-1")
        return None
    except ValueError as e:
        logger.warning(f"Cannot convert match to float: {e}")
        return None

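# Quick behavior check for the extractor (illustrative, not executed):
#   extract_score_from_text("0.8")            -> 0.8
#   extract_score_from_text("Relevance: 1.0") -> 1.0
#   extract_score_from_text("no digits here") -> None
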
def claude_rerank(docs, query, llm, top_n=5):
    """Rerank documents based on relevance using the LLM."""
    try:
        rerank_prompt = ChatPromptTemplate.from_template(
            """
            Given the query: "{query}" and the document chunk: "{chunk}", please rate
            the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).

            Respond with a number only, like: 0.8
            """
        )
        scored_docs = []
        for idx, doc in enumerate(docs):
            # format_messages() yields a proper message list; plain .format()
            # would bake a "Human:" prefix into the prompt text.
            messages = rerank_prompt.format_messages(query=query, chunk=doc.page_content)
            response = llm.invoke(messages)
            text = response.content.strip()
            logger.info(f"Doc {idx} rerank raw output: {text}")
            score = extract_score_from_text(text)
            if score is None:
                logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
                score = 0.0
            scored_docs.append((doc, score))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        logger.info(f"Reranked top {top_n} docs based on scores")
        return [doc for doc, _ in scored_docs[:top_n]]
    except Exception as e:
        logger.error(f"Error in reranking: {e}")
        st.error(f"Error in reranking: {e}")
        return docs[:top_n]

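# Cost note: with the k=20 retriever in create_rag_chain, each reranked query
# issues 20 sequential Claude calls before the final answer call; batching all
# chunks into one scoring prompt would cut latency, at some accuracy risk.
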
def create_rag_chain(vector_store, llm, use_rerank=False):
    """Create a RAG chain with or without reranking."""
    try:
        prompt_template = ChatPromptTemplate.from_template(
            "You are a helpful assistant. Use the following context to answer the question concisely.\n\n"
            "Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
        )
        # Over-retrieve (k=20) when reranking so the LLM has candidates to prune.
        retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})

        def rerank_context(inputs):
            try:
                docs = retriever.invoke(inputs["question"])
                if not docs:
                    logger.warning("No documents retrieved for query")
                    return {"context": "", "question": inputs["question"]}
                if use_rerank:
                    docs = claude_rerank(docs, inputs["question"], llm)
                return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
            except Exception as e:
                logger.error(f"Error in rerank_context: {e}")
                return {"context": "", "question": inputs["question"]}

        chain = rerank_context | prompt_template | llm | StrOutputParser()
        logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
        return chain
    except Exception as e:
        logger.error(f"Error creating RAG chain: {e}")
        st.error(f"Failed to create RAG chain: {e}")
        return None

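# LCEL detail: a plain function on the left of `|` is coerced into a
# RunnableLambda by the Runnable it is piped into, which is why rerank_context
# slots into the chain above without an explicit wrapper.
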
def main():
    st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
    st.write("Enter a question to get answers using baseline and reranked retrieval methods.")

    # load_environment() re-raises non-ValueError failures too (e.g., dotenv
    # errors), so catch broadly; the error is already surfaced via st.error.
    try:
        load_environment()
    except Exception:
        return

    documents = load_wikipedia_documents()
    if not documents:
        st.error("Cannot proceed without documents")
        return
    chunks = split_documents(documents)
    if not chunks:
        st.error("Cannot proceed without document chunks")
        return
    embeddings = initialize_embeddings()
    if embeddings is None:
        st.error("Cannot proceed without embeddings")
        return
    vector_store = store_in_qdrant(chunks, embeddings)
    if vector_store is None:
        st.error("Cannot proceed without vector store")
        return
    llm = initialize_llm()
    if llm is None:
        st.error("Cannot proceed without LLM")
        return

    baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
    if baseline_chain is None:
        st.error("Cannot proceed without baseline chain")
        return
    rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
    if rerank_chain is None:
        st.error("Cannot proceed without rerank chain")
        return

    query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")
    if query:
        with st.spinner("Processing your query..."):
            try:
                baseline_response = baseline_chain.invoke({"question": query})
                rerank_response = rerank_chain.invoke({"question": query})

                st.subheader("Results")
                st.write("**Query:**", query)
                st.write("**Baseline Answer:**")
                st.write(baseline_response)
                st.write("**Reranked Answer:**")
                st.write(rerank_response)
            except Exception as e:
                logger.error(f"Error processing query: {e}")
                st.error(f"Error processing query: {e}")


if __name__ == "__main__":
    main()
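# Typical launch (assuming this file is saved as app.py):
#   streamlit run app.py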