import json

import boto3
import chromadb
import streamlit as st
from datasets import load_dataset
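

# Bedrock client used for both Titan embeddings and Claude calls. The region is
# hard-coded to us-east-1 below; if your Bedrock model access is enabled in a
# different region, adjust region_name accordingly.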
def connect_to_bedrock():
    client = boto3.client('bedrock-runtime', region_name='us-east-1')
    return client


def load_wikipedia_docs(num_docs=100):
    st.write(f"Loading {num_docs} Wikipedia documents...")

    dataset = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train")

    documents = []
    for i in range(min(num_docs, len(dataset))):
        doc = dataset[i]
        documents.append({
            'text': doc['text'],
            'title': doc.get('title', f'Document {i+1}'),
            'id': str(i)
        })

    return documents
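

# Naive fixed-size chunking: each document is cut into chunk_size-character windows
# (500 by default) with no overlap, and fragments shorter than about 50 characters
# are dropped so near-empty tail chunks don't pollute the index.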
def split_into_chunks(documents, chunk_size=500):
    st.write(f"Splitting documents into {chunk_size}-character chunks...")

    chunks = []
    chunk_id = 0

    for doc in documents:
        text = doc['text']
        title = doc['title']

        for i in range(0, len(text), chunk_size):
            chunk_text = text[i:i + chunk_size]
            if len(chunk_text.strip()) > 50:
                chunks.append({
                    'id': str(chunk_id),
                    'text': chunk_text,
                    'title': title,
                    'doc_id': doc['id']
                })
                chunk_id += 1

    return chunks
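

# Embeds a single string with Amazon Titan Text Embeddings (amazon.titan-embed-text-v1).
# Indexing therefore makes one Bedrock call per chunk; for larger corpora you may need
# retry/backoff around invoke_model to handle throttling (not included here).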
def get_embeddings(bedrock_client, text):
    body = json.dumps({
        "inputText": text
    })

    response = bedrock_client.invoke_model(
        modelId="amazon.titan-embed-text-v1",
        body=body
    )

    result = json.loads(response['body'].read())
    return result['embedding']
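

# Builds an in-memory ChromaDB collection. chromadb.Client() does not persist to disk,
# so the index is rebuilt each time the Streamlit process restarts. A minimal sketch of
# a persistent alternative, assuming a local "./chroma_db" directory (not used as written):
#
#     chroma_client = chromadb.PersistentClient(path="./chroma_db")
#
# Chunks are embedded one at a time and written to the collection in batches of 100.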
def store_in_chromadb(bedrock_client, chunks):
    st.write("Storing chunks in ChromaDB with embeddings...")

    chroma_client = chromadb.Client()

    # Remove any collection left over from a previous run before recreating it.
    try:
        chroma_client.delete_collection("wikipedia_chunks")
    except Exception:
        pass

    collection = chroma_client.create_collection("wikipedia_chunks")

    ids = []
    texts = []
    metadatas = []
    embeddings = []

    progress_bar = st.progress(0)

    for i, chunk in enumerate(chunks):
        embedding = get_embeddings(bedrock_client, chunk['text'])

        ids.append(chunk['id'])
        texts.append(chunk['text'])
        metadatas.append({
            'title': chunk['title'],
            'doc_id': chunk['doc_id']
        })
        embeddings.append(embedding)

        progress_bar.progress((i + 1) / len(chunks))

        # Flush to ChromaDB in batches of 100, plus a final partial batch.
        if len(ids) == 100 or i == len(chunks) - 1:
            collection.add(
                ids=ids,
                documents=texts,
                metadatas=metadatas,
                embeddings=embeddings
            )
            ids, texts, metadatas, embeddings = [], [], [], []

    return collection
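

# Baseline retrieval: embed the query with Titan and ask ChromaDB for the top_k nearest
# chunks. With Chroma's default collection settings the reported distance is an
# L2-style distance, so smaller values mean closer matches.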
def simple_retrieval(collection, bedrock_client, query, top_k=10):
    query_embedding = get_embeddings(bedrock_client, query)

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )

    retrieved_docs = []
    for i in range(len(results['documents'][0])):
        retrieved_docs.append({
            'text': results['documents'][0][i],
            'title': results['metadatas'][0][i]['title'],
            'distance': results['distances'][0][i]
        })

    return retrieved_docs
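

# LLM-based re-ranking: Claude 3 Haiku sees the query plus the first 200 characters of
# each candidate chunk and replies with a comma-separated ranking of document numbers.
# The reply is parsed back into indexes; if parsing fails, the original vector-search
# order is kept as a fallback.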
def rerank_with_claude(bedrock_client, query, documents, top_k=5):
    # Keep the prompt short: show only the first 200 characters of each candidate.
    docs_text = ""
    for i, doc in enumerate(documents):
        docs_text += f"[{i+1}] {doc['text'][:200]}...\n\n"

    prompt = f"""
Given the query: "{query}"

Please rank the following documents by relevance to the query.
Return only the numbers (1, 2, 3, etc.) of the most relevant documents in order, separated by commas.
Return exactly {top_k} numbers.

Documents:
{docs_text}

Most relevant document numbers (in order):
"""

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": prompt}]
    })

    response = bedrock_client.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        body=body
    )

    result = json.loads(response['body'].read())
    ranking_text = result['content'][0]['text'].strip()

    try:
        # Convert the 1-based numbers returned by the model into 0-based list indexes.
        rankings = [int(x.strip()) - 1 for x in ranking_text.split(',')]

        reranked_docs = []
        for rank in rankings[:top_k]:
            if 0 <= rank < len(documents):
                reranked_docs.append(documents[rank])

        return reranked_docs
    except ValueError:
        # If the model's reply can't be parsed, keep the original retrieval order.
        return documents[:top_k]
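

# Answer generation: the retrieved chunks (with their source titles) are concatenated
# into a context block and Claude 3 Haiku is asked to answer the query based on that
# information.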
def generate_answer(bedrock_client, query, documents):
    context = "\n\n".join([f"Source: {doc['title']}\n{doc['text']}" for doc in documents])

    prompt = f"""
Based on the following information, please answer the question.

Question: {query}

Information:
{context}

Please provide a clear and comprehensive answer based on the information above.
"""

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 500,
        "messages": [{"role": "user", "content": prompt}]
    })

    response = bedrock_client.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        body=body
    )

    result = json.loads(response['body'].read())
    return result['content'][0]['text']
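

# Streamlit UI: the ChromaDB collection and a setup flag are stored in st.session_state
# so the index built during setup survives Streamlit's script reruns (the whole file
# re-executes on every widget interaction).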
def main():
    st.title("Wikipedia Retrieval with Re-ranking")
    st.write("Compare search results with and without re-ranking!")

    if 'collection' not in st.session_state:
        st.session_state.collection = None
    if 'setup_done' not in st.session_state:
        st.session_state.setup_done = False

    if not st.session_state.setup_done:
        st.subheader("Setup")

        if st.button("Load Wikipedia Data and Setup ChromaDB"):
            try:
                with st.spinner("Setting up... This may take a few minutes..."):
                    bedrock_client = connect_to_bedrock()

                    documents = load_wikipedia_docs(100)
                    st.success(f"Loaded {len(documents)} documents")

                    chunks = split_into_chunks(documents, 500)
                    st.success(f"Created {len(chunks)} chunks")

                    collection = store_in_chromadb(bedrock_client, chunks)
                    st.session_state.collection = collection
                    st.session_state.setup_done = True

                    st.success("Setup complete! You can now test queries below.")
                    st.balloons()

            except Exception as e:
                st.error(f"Setup failed: {str(e)}")

    else:
        st.success("Setup completed! ChromaDB is ready with Wikipedia data.")

        st.subheader("Test Queries")

        sample_queries = [
            "What are the main causes of climate change?",
            "How does quantum computing work?",
            "What were the social impacts of the industrial revolution?"
        ]

        query_option = st.radio("Choose a query:",
                                ["Custom Query"] + sample_queries)

        if query_option == "Custom Query":
            query = st.text_input("Enter your custom query:")
        else:
            query = query_option
            st.write(f"Selected query: **{query}**")

        if query:
            if st.button("Compare Retrieval Methods"):
                try:
                    bedrock_client = connect_to_bedrock()

                    st.write("---")

                    st.subheader("Method 1: Simple Retrieval (Baseline)")
                    with st.spinner("Performing simple retrieval..."):
                        simple_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
                        simple_top5 = simple_results[:5]

                        st.write("**Top 5 Results:**")
                        for i, doc in enumerate(simple_top5, 1):
                            with st.expander(f"{i}. {doc['title']} (Distance: {doc['distance']:.3f})"):
                                st.write(doc['text'][:300] + "...")

                        simple_answer = generate_answer(bedrock_client, query, simple_top5)
                        st.write("**Answer using Simple Retrieval:**")
                        st.info(simple_answer)

                    st.write("---")

                    st.subheader("Method 2: Retrieval with Re-ranking")
                    with st.spinner("Performing retrieval with re-ranking..."):
                        initial_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
                        reranked_results = rerank_with_claude(bedrock_client, query, initial_results, 5)

                        st.write("**Top 5 Re-ranked Results:**")
                        for i, doc in enumerate(reranked_results, 1):
                            with st.expander(f"{i}. {doc['title']} (Re-ranked)"):
                                st.write(doc['text'][:300] + "...")

                        reranked_answer = generate_answer(bedrock_client, query, reranked_results)
                        st.write("**Answer using Re-ranked Retrieval:**")
                        st.success(reranked_answer)

                    st.write("---")
                    st.subheader("Comparison Summary")
                    st.write("**Simple Retrieval:** Uses only vector similarity to find relevant documents.")
                    st.write("**Re-ranked Retrieval:** Uses Claude 3 Haiku to intelligently reorder results for better relevance.")

                except Exception as e:
                    st.error(f"Error during retrieval: {str(e)}")

    if st.button("Reset Setup"):
        st.session_state.collection = None
        st.session_state.setup_done = False
        st.rerun()


def show_installation_guide():
    with st.expander("Installation Guide"):
        st.markdown("""
**Step 1: Install Required Libraries**
```bash
pip install streamlit boto3 chromadb datasets
```

**Step 2: Set up AWS**
```bash
aws configure
```
Enter your AWS access keys when prompted.

**Step 3: Run the App**
```bash
streamlit run reranking_app.py
```

**What this app does:**
1. Loads 100 Wikipedia documents
2. Splits them into 500-character chunks
3. Creates embeddings using Bedrock Titan
4. Stores in local ChromaDB
5. Compares simple vs re-ranked retrieval
""")


if __name__ == "__main__":
    show_installation_guide()
    main()