DarshanaD committed
Commit 933dbc2 · 1 Parent(s): 3faa6ea

Initial commit

Files changed (3)
  1. README.md +2 -2
  2. app.py +364 -52
  3. requirements.txt +10 -1
README.md CHANGED
@@ -3,8 +3,8 @@ title: Rag Re Ranking
 emoji: 💬
 colorFrom: yellow
 colorTo: purple
-sdk: gradio
-sdk_version: 5.0.1
+sdk: streamlit
+sdk_version: 1.35.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED
@@ -1,64 +1,376 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
+import streamlit as st
+import boto3
+import json
+import chromadb
+from datasets import load_dataset
+
+# Simple function to connect to AWS Bedrock
+def connect_to_bedrock():
+    client = boto3.client('bedrock-runtime', region_name='us-east-1')
+    return client
+
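+# Note: boto3 resolves credentials from the environment or the shared
+# ~/.aws config; the region is hardcoded to us-east-1, so the Titan and
+# Claude models used below must be enabled there.
+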
+# Simple function to load Wikipedia documents
+def load_wikipedia_docs(num_docs=100):
+    st.write(f"📚 Loading {num_docs} Wikipedia documents...")
+
+    # Load Wikipedia dataset from Hugging Face
+    dataset = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train")
+
+    # Take only the first num_docs documents
+    documents = []
+    for i in range(min(num_docs, len(dataset))):
+        doc = dataset[i]
+        documents.append({
+            'text': doc['text'],
+            'title': doc.get('title', f'Document {i+1}'),
+            'id': str(i)
+        })
+
+    return documents
+
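+# Note: this dataset ships with precomputed Cohere embeddings, but the app
+# ignores them and re-embeds with Titan so that query and chunk vectors
+# come from the same model.
+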
+# Simple function to split text into chunks
+def split_into_chunks(documents, chunk_size=500):
+    st.write(f"✂️ Splitting documents into {chunk_size}-character chunks...")
+
+    chunks = []
+    chunk_id = 0
+
+    for doc in documents:
+        text = doc['text']
+        title = doc['title']
+
+        # Split text into chunks of chunk_size characters
+        for i in range(0, len(text), chunk_size):
+            chunk_text = text[i:i + chunk_size]
+            if len(chunk_text.strip()) > 50:  # Only keep meaningful chunks
+                chunks.append({
+                    'id': str(chunk_id),
+                    'text': chunk_text,
+                    'title': title,
+                    'doc_id': doc['id']
+                })
+                chunk_id += 1
+
+    return chunks
+
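+# Note: fixed-size splitting can cut sentences mid-way; a common refinement
+# is to overlap consecutive chunks, e.g. stepping by chunk_size - overlap.
+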
+# Get embeddings from Bedrock Titan model
+def get_embeddings(bedrock_client, text):
+    body = json.dumps({
+        "inputText": text
+    })
+
+    response = bedrock_client.invoke_model(
+        modelId="amazon.titan-embed-text-v1",
+        body=body
+    )
+
+    result = json.loads(response['body'].read())
+    return result['embedding']
+
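+# Note: Titan text embeddings v1 returns one 1,536-dimensional vector per
+# call, so indexing makes a separate Bedrock request for every chunk.
+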
+# Store chunks in ChromaDB
+def store_in_chromadb(bedrock_client, chunks):
+    st.write("💾 Storing chunks in ChromaDB with embeddings...")
+
+    # Create ChromaDB client
+    chroma_client = chromadb.Client()
+
+    # Drop any existing collection so reruns start from a clean slate
+    try:
+        chroma_client.delete_collection("wikipedia_chunks")
+    except Exception:
+        pass
+
+    collection = chroma_client.create_collection("wikipedia_chunks")
+
+    # Prepare data for ChromaDB
+    ids = []
+    texts = []
+    metadatas = []
+    embeddings = []
+
+    progress_bar = st.progress(0)
+
+    for i, chunk in enumerate(chunks):
+        # Get embedding for each chunk
+        embedding = get_embeddings(bedrock_client, chunk['text'])
+
+        ids.append(chunk['id'])
+        texts.append(chunk['text'])
+        metadatas.append({
+            'title': chunk['title'],
+            'doc_id': chunk['doc_id']
+        })
+        embeddings.append(embedding)
+
+        # Update progress
+        progress_bar.progress((i + 1) / len(chunks))
+
+        # Add to ChromaDB in batches of 100
+        if len(ids) == 100 or i == len(chunks) - 1:
+            collection.add(
+                ids=ids,
+                documents=texts,
+                metadatas=metadatas,
+                embeddings=embeddings
+            )
+            ids, texts, metadatas, embeddings = [], [], [], []
+
+    return collection
+
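+# Note: chromadb.Client() is an in-memory instance, so the index is lost on
+# restart; e.g. chromadb.PersistentClient(path="./chroma_db") would keep it
+# on disk instead.
+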
+# Simple retrieval without re-ranking
+def simple_retrieval(collection, bedrock_client, query, top_k=10):
+    # Get query embedding
+    query_embedding = get_embeddings(bedrock_client, query)
+
+    # Search in ChromaDB
+    results = collection.query(
+        query_embeddings=[query_embedding],
+        n_results=top_k
+    )
+
+    # Format results
+    retrieved_docs = []
+    for i in range(len(results['documents'][0])):
+        retrieved_docs.append({
+            'text': results['documents'][0][i],
+            'title': results['metadatas'][0][i]['title'],
+            'distance': results['distances'][0][i]
+        })
+
+    return retrieved_docs
+
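+# Note: Chroma collections default to L2 distance, so smaller distance
+# values shown in the UI mean more similar chunks.
+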
+# Re-ranking using Claude 3 Haiku on Bedrock
+def rerank_with_claude(bedrock_client, query, documents, top_k=5):
+    # Create prompt for re-ranking
+    docs_text = ""
+    for i, doc in enumerate(documents):
+        docs_text += f"[{i+1}] {doc['text'][:200]}...\n\n"
+
+    prompt = f"""
+    Given the query: "{query}"
+
+    Please rank the following documents by relevance to the query.
+    Return only the numbers (1, 2, 3, etc.) of the most relevant documents in order, separated by commas.
+    Return exactly {top_k} numbers.
+
+    Documents:
+    {docs_text}
+
+    Most relevant document numbers (in order):
+    """
+
+    body = json.dumps({
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 100,
+        "messages": [{"role": "user", "content": prompt}]
+    })
+
+    response = bedrock_client.invoke_model(
+        modelId="anthropic.claude-3-haiku-20240307-v1:0",
+        body=body
+    )
+
+    result = json.loads(response['body'].read())
+    ranking_text = result['content'][0]['text'].strip()
+
+    try:
+        # Parse the ranking (convert to 0-based indices)
+        rankings = [int(x.strip()) - 1 for x in ranking_text.split(',')]
+
+        # Reorder documents based on ranking
+        reranked_docs = []
+        for rank in rankings[:top_k]:
+            if 0 <= rank < len(documents):
+                reranked_docs.append(documents[rank])
+
+        return reranked_docs
+    except ValueError:
+        # If parsing fails, return original order
+        return documents[:top_k]
+
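+# Note: this is listwise re-ranking via prompting; only the first 200
+# characters of each chunk reach the model, which keeps the prompt short
+# but can hide relevant text. A dedicated rerank model or cross-encoder is
+# a common alternative.
+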
+# Generate answer using retrieved documents
+def generate_answer(bedrock_client, query, documents):
+    # Combine documents into context
+    context = "\n\n".join([f"Source: {doc['title']}\n{doc['text']}" for doc in documents])
+
+    prompt = f"""
+    Based on the following information, please answer the question.
+
+    Question: {query}
+
+    Information:
+    {context}
+
+    Please provide a clear and comprehensive answer based on the information above.
+    """
+
+    body = json.dumps({
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 500,
+        "messages": [{"role": "user", "content": prompt}]
+    })
+
+    response = bedrock_client.invoke_model(
+        modelId="anthropic.claude-3-haiku-20240307-v1:0",
+        body=body
+    )
+
+    result = json.loads(response['body'].read())
+    return result['content'][0]['text']
+
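+# Note: answer generation reuses the same Claude 3 Haiku model; max_tokens=500
+# caps the response, so long answers may be truncated.
+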
+# Main app
+def main():
+    st.title("🔍 Wikipedia Retrieval with Re-ranking")
+    st.write("Compare search results with and without re-ranking!")
+
+    # Initialize session state
+    if 'collection' not in st.session_state:
+        st.session_state.collection = None
+    if 'setup_done' not in st.session_state:
+        st.session_state.setup_done = False
+
+    # Setup section
+    if not st.session_state.setup_done:
+        st.subheader("🛠️ Setup")
+
+        if st.button("🚀 Load Wikipedia Data and Setup ChromaDB"):
+            try:
+                with st.spinner("Setting up... This may take a few minutes..."):
+                    # Connect to Bedrock
+                    bedrock_client = connect_to_bedrock()
+
+                    # Load Wikipedia documents
+                    documents = load_wikipedia_docs(100)
+                    st.success(f"✅ Loaded {len(documents)} documents")
+
+                    # Split into chunks
+                    chunks = split_into_chunks(documents, 500)
+                    st.success(f"✅ Created {len(chunks)} chunks")
+
+                    # Store in ChromaDB
+                    collection = store_in_chromadb(bedrock_client, chunks)
+                    st.session_state.collection = collection
+                    st.session_state.setup_done = True
+
+                    st.success("🎉 Setup complete! You can now test queries below.")
+                    st.balloons()
+
+            except Exception as e:
+                st.error(f"❌ Setup failed: {str(e)}")
+
+    else:
+        st.success("✅ Setup completed! ChromaDB is ready with Wikipedia data.")
+
+        # Query testing section
+        st.subheader("🔍 Test Queries")
+
+        # Predefined queries
+        sample_queries = [
+            "What are the main causes of climate change?",
+            "How does quantum computing work?",
+            "What were the social impacts of the industrial revolution?"
+        ]
+
+        # Query selection
+        query_option = st.radio("Choose a query:",
+                                ["Custom Query"] + sample_queries)
+
+        if query_option == "Custom Query":
+            query = st.text_input("Enter your custom query:")
+        else:
+            query = query_option
+            st.write(f"Selected query: **{query}**")
+
+        if query:
+            if st.button("🔍 Compare Retrieval Methods"):
+                try:
+                    bedrock_client = connect_to_bedrock()
+
+                    st.write("---")
+
+                    # Method 1: simple retrieval
+                    st.subheader("📋 Method 1: Simple Retrieval (Baseline)")
+                    with st.spinner("Performing simple retrieval..."):
+                        simple_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
+                        simple_top5 = simple_results[:5]
+
+                    st.write("**Top 5 Results:**")
+                    for i, doc in enumerate(simple_top5, 1):
+                        with st.expander(f"{i}. {doc['title']} (Distance: {doc['distance']:.3f})"):
+                            st.write(doc['text'][:300] + "...")
+
+                    # Generate answer with simple retrieval
+                    simple_answer = generate_answer(bedrock_client, query, simple_top5)
+                    st.write("**Answer using Simple Retrieval:**")
+                    st.info(simple_answer)
+
+                    st.write("---")
+
+                    # Method 2: retrieval with re-ranking
+                    st.subheader("🎯 Method 2: Retrieval with Re-ranking")
+                    with st.spinner("Performing retrieval with re-ranking..."):
+                        # First retrieve a larger candidate set
+                        initial_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
+
+                        # Then re-rank the candidates
+                        reranked_results = rerank_with_claude(bedrock_client, query, initial_results, 5)
+
+                    st.write("**Top 5 Re-ranked Results:**")
+                    for i, doc in enumerate(reranked_results, 1):
+                        with st.expander(f"{i}. {doc['title']} (Re-ranked)"):
+                            st.write(doc['text'][:300] + "...")
+
+                    # Generate answer with re-ranked results
+                    reranked_answer = generate_answer(bedrock_client, query, reranked_results)
+                    st.write("**Answer using Re-ranked Retrieval:**")
+                    st.success(reranked_answer)
+
+                    st.write("---")
+                    st.subheader("📊 Comparison Summary")
+                    st.write("**Simple Retrieval:** Uses only vector similarity to find relevant documents.")
+                    st.write("**Re-ranked Retrieval:** Uses Claude 3 Haiku to reorder results for better relevance.")
+
+                except Exception as e:
+                    st.error(f"❌ Error during retrieval: {str(e)}")
+
+        # Reset button
+        if st.button("🔄 Reset Setup"):
+            st.session_state.collection = None
+            st.session_state.setup_done = False
+            st.rerun()
+
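+# Note: st.session_state keeps the Chroma collection across Streamlit reruns
+# within one browser session, but not across server restarts.
+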
+# Installation guide
+def show_installation_guide():
+    with st.expander("📖 Installation Guide"):
+        st.markdown("""
+        **Step 1: Install Required Libraries**
+        ```bash
+        pip install streamlit boto3 chromadb datasets
+        ```
+
+        **Step 2: Set up AWS**
+        ```bash
+        aws configure
+        ```
+        Enter your AWS access keys when prompted.
+
+        **Step 3: Run the App**
+        ```bash
+        streamlit run app.py
+        ```
+
+        **What this app does:**
+        1. Loads 100 Wikipedia documents
+        2. Splits them into 500-character chunks
+        3. Creates embeddings using Bedrock Titan
+        4. Stores them in a local ChromaDB
+        5. Compares simple vs. re-ranked retrieval
+        """)
+
+# Run the app
 if __name__ == "__main__":
-    demo.launch()
+    show_installation_guide()
+    main()
requirements.txt CHANGED
@@ -1 +1,7 @@
 huggingface_hub==0.25.2
+qdrant_client
+streamlit
+boto3
+PyPDF2
+chromadb
+datasets