# Streamlit demo: Wikipedia retrieval with and without LLM re-ranking
# (AWS Bedrock Titan embeddings + local ChromaDB + Claude as re-ranker).
import json
import time
import uuid

import boto3
import chromadb
import streamlit as st
from datasets import load_dataset
# Simple function to connect to AWS Bedrock | |
def connect_to_bedrock(): | |
client = boto3.client('bedrock-runtime', region_name='us-east-1') | |
return client | |
# Simple function to load Wikipedia documents | |
def load_wikipedia_docs(num_docs=100): | |
st.write(f"π Loading {num_docs} Wikipedia documents...") | |
# Load Wikipedia dataset from Hugging Face | |
dataset = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train") | |
# Take only the first num_docs documents | |
documents = [] | |
for i in range(min(num_docs, len(dataset))): | |
doc = dataset[i] | |
documents.append({ | |
'text': doc['text'], | |
'title': doc.get('title', f'Document {i+1}'), | |
'id': str(i) | |
}) | |
return documents | |
# Simple function to split text into chunks | |
def split_into_chunks(documents, chunk_size=500): | |
st.write("βοΈ Splitting documents into 500-character chunks...") | |
chunks = [] | |
chunk_id = 0 | |
for doc in documents: | |
text = doc['text'] | |
title = doc['title'] | |
# Split text into chunks of 500 characters | |
for i in range(0, len(text), chunk_size): | |
chunk_text = text[i:i + chunk_size] | |
if len(chunk_text.strip()) > 50: # Only keep meaningful chunks | |
chunks.append({ | |
'id': str(chunk_id), | |
'text': chunk_text, | |
'title': title, | |
'doc_id': doc['id'] | |
}) | |
chunk_id += 1 | |
return chunks | |
# Get embeddings from Bedrock Titan model | |
def get_embeddings(bedrock_client, text): | |
body = json.dumps({ | |
"inputText": text | |
}) | |
response = bedrock_client.invoke_model( | |
modelId="amazon.titan-embed-text-v1", | |
body=body | |
) | |
result = json.loads(response['body'].read()) | |
return result['embedding'] | |
# Store chunks in ChromaDB | |
def store_in_chromadb(bedrock_client, chunks): | |
st.write("πΎ Storing chunks in ChromaDB with embeddings...") | |
# Create ChromaDB client | |
chroma_client = chromadb.Client() | |
# Create or get collection | |
try: | |
collection = chroma_client.get_collection("wikipedia_chunks") | |
chroma_client.delete_collection("wikipedia_chunks") | |
except: | |
pass | |
collection = chroma_client.create_collection("wikipedia_chunks") | |
# Prepare data for ChromaDB | |
ids = [] | |
texts = [] | |
metadatas = [] | |
embeddings = [] | |
progress_bar = st.progress(0) | |
for i, chunk in enumerate(chunks): | |
# Get embedding for each chunk | |
embedding = get_embeddings(bedrock_client, chunk['text']) | |
ids.append(chunk['id']) | |
texts.append(chunk['text']) | |
metadatas.append({ | |
'title': chunk['title'], | |
'doc_id': chunk['doc_id'] | |
}) | |
embeddings.append(embedding) | |
# Update progress | |
progress_bar.progress((i + 1) / len(chunks)) | |
# Add to ChromaDB in batches of 100 | |
if len(ids) == 100 or i == len(chunks) - 1: | |
collection.add( | |
ids=ids, | |
documents=texts, | |
metadatas=metadatas, | |
embeddings=embeddings | |
) | |
ids, texts, metadatas, embeddings = [], [], [], [] | |
return collection | |
# Simple retrieval without re-ranking | |
def simple_retrieval(collection, bedrock_client, query, top_k=10): | |
# Get query embedding | |
query_embedding = get_embeddings(bedrock_client, query) | |
# Search in ChromaDB | |
results = collection.query( | |
query_embeddings=[query_embedding], | |
n_results=top_k | |
) | |
# Format results | |
retrieved_docs = [] | |
for i in range(len(results['documents'][0])): | |
retrieved_docs.append({ | |
'text': results['documents'][0][i], | |
'title': results['metadatas'][0][i]['title'], | |
'distance': results['distances'][0][i] | |
}) | |
return retrieved_docs | |
# Re-ranking using Claude 3.5 | |
def rerank_with_claude(bedrock_client, query, documents, top_k=5): | |
# Create prompt for re-ranking | |
docs_text = "" | |
for i, doc in enumerate(documents): | |
docs_text += f"[{i+1}] {doc['text'][:200]}...\n\n" | |
prompt = f""" | |
Given the query: "{query}" | |
Please rank the following documents by relevance to the query. | |
Return only the numbers (1, 2, 3, etc.) of the most relevant documents in order, separated by commas. | |
Return exactly {top_k} numbers. | |
Documents: | |
{docs_text} | |
Most relevant document numbers (in order): | |
""" | |
body = json.dumps({ | |
"anthropic_version": "bedrock-2023-05-31", | |
"max_tokens": 100, | |
"messages": [{"role": "user", "content": prompt}] | |
}) | |
response = bedrock_client.invoke_model( | |
modelId="anthropic.claude-3-haiku-20240307-v1:0", | |
body=body | |
) | |
result = json.loads(response['body'].read()) | |
ranking_text = result['content'][0]['text'].strip() | |
try: | |
# Parse the ranking | |
rankings = [int(x.strip()) - 1 for x in ranking_text.split(',')] # Convert to 0-based index | |
# Reorder documents based on ranking | |
reranked_docs = [] | |
for rank in rankings[:top_k]: | |
if 0 <= rank < len(documents): | |
reranked_docs.append(documents[rank]) | |
return reranked_docs | |
except: | |
# If parsing fails, return original order | |
return documents[:top_k] | |
# Generate answer using retrieved documents | |
def generate_answer(bedrock_client, query, documents): | |
# Combine documents into context | |
context = "\n\n".join([f"Source: {doc['title']}\n{doc['text']}" for doc in documents]) | |
prompt = f""" | |
Based on the following information, please answer the question. | |
Question: {query} | |
Information: | |
{context} | |
Please provide a clear and comprehensive answer based on the information above. | |
""" | |
body = json.dumps({ | |
"anthropic_version": "bedrock-2023-05-31", | |
"max_tokens": 500, | |
"messages": [{"role": "user", "content": prompt}] | |
}) | |
response = bedrock_client.invoke_model( | |
modelId="anthropic.claude-3-haiku-20240307-v1:0", | |
body=body | |
) | |
result = json.loads(response['body'].read()) | |
return result['content'][0]['text'] | |
# Main app | |
def main(): | |
st.title("π Wikipedia-Documents Retrieval with Re-ranking") | |
st.write("Compare search results with and without re-ranking!") | |
# Initialize session state | |
if 'collection' not in st.session_state: | |
st.session_state.collection = None | |
if 'setup_done' not in st.session_state: | |
st.session_state.setup_done = False | |
# Setup section | |
if not st.session_state.setup_done: | |
st.subheader("π οΈ Setup") | |
if st.button("π Load Wikipedia Data and Setup ChromaDB"): | |
try: | |
with st.spinner("Setting up... This may take a few minutes..."): | |
# Connect to Bedrock | |
bedrock_client = connect_to_bedrock() | |
# Load Wikipedia documents | |
documents = load_wikipedia_docs(100) | |
st.success(f"β Loaded {len(documents)} documents") | |
# Split into chunks | |
chunks = split_into_chunks(documents, 500) | |
st.success(f"β Created {len(chunks)} chunks") | |
# Store in ChromaDB | |
collection = store_in_chromadb(bedrock_client, chunks) | |
st.session_state.collection = collection | |
st.session_state.setup_done = True | |
st.success("π Setup complete! You can now test queries below.") | |
st.balloons() | |
except Exception as e: | |
st.error(f"β Setup failed: {str(e)}") | |
else: | |
st.success("β Setup completed! ChromaDB is ready with Wikipedia data.") | |
# Query testing section | |
st.subheader("π Test Queries") | |
# Predefined queries | |
sample_queries = [ | |
"What are the main causes of climate change?", | |
"How does quantum computing work?", | |
"What were the social impacts of the industrial revolution?" | |
] | |
# Query selection | |
query_option = st.radio("Choose a query:", | |
["Custom Query"] + sample_queries) | |
if query_option == "Custom Query": | |
query = st.text_input("Enter your custom query:") | |
else: | |
query = query_option | |
st.write(f"Selected query: **{query}**") | |
if query: | |
if st.button("π Compare Retrieval Methods"): | |
try: | |
bedrock_client = connect_to_bedrock() | |
st.write("---") | |
# Method 1: Simple Retrieval | |
st.subheader("π Method 1: Simple Retrieval (Baseline)") | |
with st.spinner("Performing simple retrieval..."): | |
simple_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10) | |
simple_top5 = simple_results[:5] | |
st.write("**Top 5 Results:**") | |
for i, doc in enumerate(simple_top5, 1): | |
with st.expander(f"{i}. {doc['title']} (Distance: {doc['distance']:.3f})"): | |
st.write(doc['text'][:300] + "...") | |
# Generate answer with simple retrieval | |
simple_answer = generate_answer(bedrock_client, query, simple_top5) | |
st.write("**Answer using Simple Retrieval:**") | |
st.info(simple_answer) | |
st.write("---") | |
# Method 2: Retrieval with Re-ranking | |
st.subheader("π― Method 2: Retrieval with Re-ranking") | |
with st.spinner("Performing retrieval with re-ranking..."): | |
# First get more results | |
initial_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10) | |
# Then re-rank them | |
reranked_results = rerank_with_claude(bedrock_client, query, initial_results, 5) | |
st.write("**Top 5 Re-ranked Results:**") | |
for i, doc in enumerate(reranked_results, 1): | |
with st.expander(f"{i}. {doc['title']} (Re-ranked)"): | |
st.write(doc['text'][:300] + "...") | |
# Generate answer with re-ranked results | |
reranked_answer = generate_answer(bedrock_client, query, reranked_results) | |
st.write("**Answer using Re-ranked Retrieval:**") | |
st.success(reranked_answer) | |
st.write("---") | |
st.subheader("π Comparison Summary") | |
st.write("**Simple Retrieval:** Uses only vector similarity to find relevant documents.") | |
st.write("**Re-ranked Retrieval:** Uses Claude 3.5 to intelligently reorder results for better relevance.") | |
except Exception as e: | |
st.error(f"β Error during retrieval: {str(e)}") | |
# Reset button | |
if st.button("π Reset Setup"): | |
st.session_state.collection = None | |
st.session_state.setup_done = False | |
st.rerun() | |
# Installation guide | |
def show_installation_guide(): | |
with st.expander("π Installation Guide"): | |
st.markdown(""" | |
**Step 1: Install Required Libraries** | |
```bash | |
pip install streamlit boto3 chromadb datasets | |
``` | |
**Step 2: Set up AWS** | |
```bash | |
aws configure | |
``` | |
Enter your AWS access keys when prompted. | |
**Step 3: Run the App** | |
```bash | |
streamlit run reranking_app.py | |
``` | |
**What this app does:** | |
1. Loads 100 Wikipedia documents | |
2. Splits them into 500-character chunks | |
3. Creates embeddings using Bedrock Titan | |
4. Stores in local ChromaDB | |
5. Compares simple vs re-ranked retrieval | |
""") | |
# Run the app | |
if __name__ == "__main__": | |
show_installation_guide() | |
main() |