import streamlit as st
import boto3
import json
import chromadb
from datasets import load_dataset
import uuid
import time
# Simple function to connect to AWS Bedrock
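# boto3 picks up AWS credentials from the environment or from `aws configure`;
# Bedrock must be available (and model access granted) in the selected region.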
def connect_to_bedrock():
    client = boto3.client('bedrock-runtime', region_name='us-east-1')
    return client

# Simple function to load Wikipedia documents
def load_wikipedia_docs(num_docs=100):
    st.write(f"πŸ“š Loading {num_docs} Wikipedia documents...")
    # Load Wikipedia dataset from Hugging Face
    dataset = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train")
    # Take only the first num_docs documents
    documents = []
    for i in range(min(num_docs, len(dataset))):
        doc = dataset[i]
        documents.append({
            'text': doc['text'],
            'title': doc.get('title', f'Document {i+1}'),
            'id': str(i)
        })
    return documents

# Simple function to split text into chunks
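# Chunks are fixed-size character windows with no overlap; fragments of 50 characters
# or fewer (after stripping whitespace) are dropped.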
def split_into_chunks(documents, chunk_size=500):
    st.write("βœ‚οΈ Splitting documents into 500-character chunks...")
    chunks = []
    chunk_id = 0
    for doc in documents:
        text = doc['text']
        title = doc['title']
        # Split text into chunks of chunk_size characters
        for i in range(0, len(text), chunk_size):
            chunk_text = text[i:i + chunk_size]
            if len(chunk_text.strip()) > 50:  # Only keep meaningful chunks
                chunks.append({
                    'id': str(chunk_id),
                    'text': chunk_text,
                    'title': title,
                    'doc_id': doc['id']
                })
                chunk_id += 1
    return chunks

# Get embeddings from Bedrock Titan model
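# amazon.titan-embed-text-v1 returns one fixed-length vector (1536 dimensions) per input string.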
def get_embeddings(bedrock_client, text):
    body = json.dumps({
        "inputText": text
    })
    response = bedrock_client.invoke_model(
        modelId="amazon.titan-embed-text-v1",
        body=body
    )
    result = json.loads(response['body'].read())
    return result['embedding']

# Store chunks in ChromaDB
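# Note: chromadb.Client() is an in-memory instance, so the index only lives for the current
# process; chromadb.PersistentClient(path=...) could be used instead to keep it on disk.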
def store_in_chromadb(bedrock_client, chunks):
    st.write("πŸ’Ύ Storing chunks in ChromaDB with embeddings...")
    # Create ChromaDB client
    chroma_client = chromadb.Client()
    # Delete any existing collection so reruns start fresh, then create a new one
    try:
        chroma_client.delete_collection("wikipedia_chunks")
    except Exception:
        pass
    collection = chroma_client.create_collection("wikipedia_chunks")
    # Prepare data for ChromaDB
    ids = []
    texts = []
    metadatas = []
    embeddings = []
    progress_bar = st.progress(0)
    for i, chunk in enumerate(chunks):
        # Get embedding for each chunk
        embedding = get_embeddings(bedrock_client, chunk['text'])
        ids.append(chunk['id'])
        texts.append(chunk['text'])
        metadatas.append({
            'title': chunk['title'],
            'doc_id': chunk['doc_id']
        })
        embeddings.append(embedding)
        # Update progress
        progress_bar.progress((i + 1) / len(chunks))
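        # The final, possibly partial, batch is flushed on the last loop iteration below.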
        # Add to ChromaDB in batches of 100
        if len(ids) == 100 or i == len(chunks) - 1:
            collection.add(
                ids=ids,
                documents=texts,
                metadatas=metadatas,
                embeddings=embeddings
            )
            ids, texts, metadatas, embeddings = [], [], [], []
    return collection

# Simple retrieval without re-ranking
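# Chroma collections use L2 distance by default, so a smaller distance means a closer match.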
def simple_retrieval(collection, bedrock_client, query, top_k=10):
    # Get query embedding
    query_embedding = get_embeddings(bedrock_client, query)
    # Search in ChromaDB
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )
    # Format results
    retrieved_docs = []
    for i in range(len(results['documents'][0])):
        retrieved_docs.append({
            'text': results['documents'][0][i],
            'title': results['metadatas'][0][i]['title'],
            'distance': results['distances'][0][i]
        })
    return retrieved_docs

# Re-ranking using Claude 3 Haiku on Bedrock
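# The model sees the query plus a 200-character snippet of each candidate and replies with a
# comma-separated ranking of document numbers (listwise re-ranking).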
def rerank_with_claude(bedrock_client, query, documents, top_k=5):
    # Create prompt for re-ranking
    docs_text = ""
    for i, doc in enumerate(documents):
        docs_text += f"[{i+1}] {doc['text'][:200]}...\n\n"

    prompt = f"""
Given the query: "{query}"
Please rank the following documents by relevance to the query.
Return only the numbers (1, 2, 3, etc.) of the most relevant documents in order, separated by commas.
Return exactly {top_k} numbers.
Documents:
{docs_text}
Most relevant document numbers (in order):
"""

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": prompt}]
    })
    response = bedrock_client.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        body=body
    )
    result = json.loads(response['body'].read())
    ranking_text = result['content'][0]['text'].strip()

    try:
        # Parse the ranking (convert to 0-based indices)
        rankings = [int(x.strip()) - 1 for x in ranking_text.split(',')]
        # Reorder documents based on ranking
        reranked_docs = []
        for rank in rankings[:top_k]:
            if 0 <= rank < len(documents):
                reranked_docs.append(documents[rank])
        return reranked_docs
    except (ValueError, IndexError):
        # If parsing fails, return original order
        return documents[:top_k]

# Generate answer using retrieved documents
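# The retrieved chunks are concatenated into the prompt so the model can base its answer on them.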
def generate_answer(bedrock_client, query, documents):
    # Combine documents into context
    context = "\n\n".join([f"Source: {doc['title']}\n{doc['text']}" for doc in documents])

    prompt = f"""
Based on the following information, please answer the question.
Question: {query}
Information:
{context}
Please provide a clear and comprehensive answer based on the information above.
"""

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 500,
        "messages": [{"role": "user", "content": prompt}]
    })
    response = bedrock_client.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        body=body
    )
    result = json.loads(response['body'].read())
    return result['content'][0]['text']

# Main app
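# Streamlit reruns the whole script on every interaction; st.session_state keeps the
# ChromaDB collection and the setup flag alive across those reruns.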
def main():
    st.title("πŸ” Wikipedia Document Retrieval with Re-ranking")
    st.write("Compare search results with and without re-ranking!")

    # Initialize session state
    if 'collection' not in st.session_state:
        st.session_state.collection = None
    if 'setup_done' not in st.session_state:
        st.session_state.setup_done = False

    # Setup section
    if not st.session_state.setup_done:
        st.subheader("πŸ› οΈ Setup")
        if st.button("πŸš€ Load Wikipedia Data and Setup ChromaDB"):
            try:
                with st.spinner("Setting up... This may take a few minutes..."):
                    # Connect to Bedrock
                    bedrock_client = connect_to_bedrock()
                    # Load Wikipedia documents
                    documents = load_wikipedia_docs(100)
                    st.success(f"βœ… Loaded {len(documents)} documents")
                    # Split into chunks
                    chunks = split_into_chunks(documents, 500)
                    st.success(f"βœ… Created {len(chunks)} chunks")
                    # Store in ChromaDB
                    collection = store_in_chromadb(bedrock_client, chunks)
                    st.session_state.collection = collection
                    st.session_state.setup_done = True
                st.success("πŸŽ‰ Setup complete! You can now test queries below.")
                st.balloons()
            except Exception as e:
                st.error(f"❌ Setup failed: {str(e)}")
    else:
        st.success("βœ… Setup completed! ChromaDB is ready with Wikipedia data.")
    # Query testing section
    st.subheader("πŸ” Test Queries")

    # Predefined queries
    sample_queries = [
        "What are the main causes of climate change?",
        "How does quantum computing work?",
        "What were the social impacts of the industrial revolution?"
    ]

    # Query selection
    query_option = st.radio("Choose a query:",
                            ["Custom Query"] + sample_queries)

    if query_option == "Custom Query":
        query = st.text_input("Enter your custom query:")
    else:
        query = query_option
        st.write(f"Selected query: **{query}**")
    if query:
        if st.button("πŸ” Compare Retrieval Methods"):
            try:
                bedrock_client = connect_to_bedrock()
                st.write("---")

                # Method 1: Simple Retrieval
                st.subheader("πŸ“‹ Method 1: Simple Retrieval (Baseline)")
                with st.spinner("Performing simple retrieval..."):
                    simple_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
                    simple_top5 = simple_results[:5]

                st.write("**Top 5 Results:**")
                for i, doc in enumerate(simple_top5, 1):
                    with st.expander(f"{i}. {doc['title']} (Distance: {doc['distance']:.3f})"):
                        st.write(doc['text'][:300] + "...")

                # Generate answer with simple retrieval
                simple_answer = generate_answer(bedrock_client, query, simple_top5)
                st.write("**Answer using Simple Retrieval:**")
                st.info(simple_answer)

                st.write("---")

                # Method 2: Retrieval with Re-ranking
                st.subheader("🎯 Method 2: Retrieval with Re-ranking")
                with st.spinner("Performing retrieval with re-ranking..."):
                    # First get more results
                    initial_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
                    # Then re-rank them
                    reranked_results = rerank_with_claude(bedrock_client, query, initial_results, 5)

                st.write("**Top 5 Re-ranked Results:**")
                for i, doc in enumerate(reranked_results, 1):
                    with st.expander(f"{i}. {doc['title']} (Re-ranked)"):
                        st.write(doc['text'][:300] + "...")

                # Generate answer with re-ranked results
                reranked_answer = generate_answer(bedrock_client, query, reranked_results)
                st.write("**Answer using Re-ranked Retrieval:**")
                st.success(reranked_answer)

                st.write("---")
                st.subheader("πŸ“Š Comparison Summary")
                st.write("**Simple Retrieval:** Uses only vector similarity to find relevant documents.")
                st.write("**Re-ranked Retrieval:** Uses Claude 3 Haiku to intelligently reorder results for better relevance.")
            except Exception as e:
                st.error(f"❌ Error during retrieval: {str(e)}")
    # Reset button
    if st.button("πŸ”„ Reset Setup"):
        st.session_state.collection = None
        st.session_state.setup_done = False
        st.rerun()

# Installation guide
def show_installation_guide():
    with st.expander("πŸ“– Installation Guide"):
        st.markdown("""
**Step 1: Install Required Libraries**
```bash
pip install streamlit boto3 chromadb datasets
```

**Step 2: Set up AWS**
```bash
aws configure
```
Enter your AWS access keys when prompted.
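
**Note:** The AWS account also needs Bedrock model access enabled in `us-east-1` for
`amazon.titan-embed-text-v1` and `anthropic.claude-3-haiku-20240307-v1:0`; request it in the
Bedrock console if the invocation calls fail with an access error.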

**Step 3: Run the App**
```bash
streamlit run app.py
```

**What this app does:**
1. Loads 100 Wikipedia documents
2. Splits them into 500-character chunks
3. Creates embeddings using Bedrock Titan
4. Stores them in a local ChromaDB index
5. Compares simple vs re-ranked retrieval
""")

# Run the app
if __name__ == "__main__":
    show_installation_guide()
    main()