AdithyaSK committed
Commit d047a23 · 1 Parent(s): db64b10

Refactor app.py: improve code formatting and enhance readability

Files changed (1)
  1. app.py +79 -28
app.py CHANGED
@@ -23,9 +23,16 @@ import seaborn as sns
 from einops import rearrange
 
 # Import from colpali_engine
-from colpali_engine.models import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGemmaProcessor3
+from colpali_engine.models import (
+    BiGemma3,
+    BiGemmaProcessor3,
+    ColGemma3,
+    ColGemmaProcessor3,
+)
 from colpali_engine.interpretability import get_similarity_maps_from_embeddings
-from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map
+from colpali_engine.interpretability.similarity_map_utils import (
+    normalize_similarity_map,
+)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -33,6 +40,7 @@ print(f"Device: {device}")
 if torch.cuda.is_available():
     print(f"GPU: {torch.cuda.get_device_name(0)}")
 
+
 # Global state for models and indexed documents
 class DocumentIndex:
     def __init__(self):
@@ -44,6 +52,7 @@ class DocumentIndex:
         self.colgemma_model = None
         self.colgemma_processor = None
 
+
 doc_index = DocumentIndex()
 
 
@@ -113,7 +122,7 @@ def index_bigemma_images(images: List[Image.Image]):
     # Process in smaller batches to avoid memory issues
     batch_size = 2
     for i in range(0, len(images), batch_size):
-        batch = images[i:i+batch_size]
+        batch = images[i : i + batch_size]
         batch_images = processor.process_images(batch).to(device)
 
         with torch.no_grad():
@@ -122,7 +131,9 @@ def index_bigemma_images(images: List[Image.Image]):
 
     # Concatenate all embeddings
     all_embeddings = torch.cat(embeddings_list, dim=0)
-    print(f"✓ Indexed {len(images)} pages with BiGemma3 (shape: {all_embeddings.shape})")
+    print(
+        f"✓ Indexed {len(images)} pages with BiGemma3 (shape: {all_embeddings.shape})"
+    )
 
     return all_embeddings
 
@@ -138,7 +149,7 @@ def index_colgemma_images(images: List[Image.Image]):
     # Process in smaller batches to avoid memory issues
     batch_size = 2
     for i in range(0, len(images), batch_size):
-        batch = images[i:i+batch_size]
+        batch = images[i : i + batch_size]
         batch_images = processor.process_images(batch).to(device)
 
         with torch.no_grad():
@@ -147,7 +158,9 @@ def index_colgemma_images(images: List[Image.Image]):
 
     # Concatenate all embeddings
     all_embeddings = torch.cat(embeddings_list, dim=0)
-    print(f"✓ Indexed {len(images)} pages with ColGemma3 (shape: {all_embeddings.shape})")
+    print(
+        f"✓ Indexed {len(images)} pages with ColGemma3 (shape: {all_embeddings.shape})"
+    )
 
     return all_embeddings
 
@@ -182,11 +195,14 @@ def index_document(pdf_files, model_choice: str) -> str:
             doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
             status_messages.append("✓ Indexed with ColGemma3")
 
-        final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
+        final_status = (
+            "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
+        )
         return final_status
 
     except Exception as e:
         import traceback
+
         error_details = traceback.format_exc()
         print(f"Indexing error: {error_details}")
         return f"❌ Error indexing document: {str(e)}"
@@ -211,14 +227,18 @@ def generate_colgemma_heatmap(
         image_mask = batch_images["input_ids"] == image_token_id
     else:
         image_mask = torch.ones(
-            image_embedding.shape[0], image_embedding.shape[1],
-            dtype=torch.bool, device=device
+            image_embedding.shape[0],
+            image_embedding.shape[1],
+            dtype=torch.bool,
+            device=device,
         )
 
     # Calculate n_patches
     num_image_tokens = image_mask.sum().item()
     n_side = int(math.sqrt(num_image_tokens))
-    n_patches = (n_side, n_side) if n_side * n_side == num_image_tokens else (16, 16)
+    n_patches = (
+        (n_side, n_side) if n_side * n_side == num_image_tokens else (16, 16)
+    )
 
     # Generate similarity maps
     similarity_maps_list = get_similarity_maps_from_embeddings(
@@ -235,12 +255,14 @@ def generate_colgemma_heatmap(
 
     # Create heatmap overlay
     img_array = np.array(image.convert("RGBA"))
-    similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
+    similarity_map_array = (
+        normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
+    )
     similarity_map_array = rearrange(similarity_map_array, "h w -> w h")
 
-    similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
-        image.size, Image.Resampling.BICUBIC
-    )
+    similarity_map_image = Image.fromarray(
+        (similarity_map_array * 255).astype("uint8")
+    ).resize(image.size, Image.Resampling.BICUBIC)
 
     # Create matplotlib figure
     fig, ax = plt.subplots(figsize=(10, 10))
@@ -287,7 +309,12 @@ def query_documents(
         # Query with BiGemma3
         if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
             if doc_index.bigemma_embeddings is None:
-                return None, "⚠️ Please index the document with BiGemma3 first.", None, None
+                return (
+                    None,
+                    "⚠️ Please index the document with BiGemma3 first.",
+                    None,
+                    None,
+                )
 
             model, processor = load_bigemma_model()
 
@@ -311,15 +338,27 @@ def query_documents(
             bigemma_text = "### BiGemma3 (NetraEmbed) Results\n\n"
             for rank, idx in enumerate(top_indices):
                 score = scores[0, idx].item()
-                bigemma_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
+                bigemma_text += (
+                    f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
+                )
                 bigemma_results.append(
-                    (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
+                    (
+                        doc_index.images[idx.item()],
+                        f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})",
+                    )
                 )
 
         # Query with ColGemma3
         if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_embeddings is None:
-                return bigemma_results if bigemma_results else None, bigemma_text if bigemma_text else "⚠️ Please index the document with ColGemma3 first.", None, None
+                return (
+                    bigemma_results if bigemma_results else None,
+                    bigemma_text
+                    if bigemma_text
+                    else "⚠️ Please index the document with ColGemma3 first.",
+                    None,
+                    None,
+                )
 
             model, processor = load_colgemma_model()
 
@@ -343,7 +382,9 @@ def query_documents(
             colgemma_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
             for rank, idx in enumerate(top_indices):
                 score = scores[0, idx].item()
-                colgemma_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
+                colgemma_text += (
+                    f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
+                )
 
                 # Generate heatmap if requested
                 if show_heatmap:
@@ -353,11 +394,17 @@ def query_documents(
                         image_embedding=doc_index.colgemma_embeddings[idx.item()],
                     )
                     colgemma_results.append(
-                        (heatmap_image, f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
+                        (
+                            heatmap_image,
+                            f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
+                        )
                     )
                 else:
                     colgemma_results.append(
-                        (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
+                        (
+                            doc_index.images[idx.item()],
+                            f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
+                        )
                     )
 
         # Return results based on model choice
@@ -370,6 +417,7 @@ def query_documents(
 
     except Exception as e:
         import traceback
+
         error_details = traceback.format_exc()
         print(f"Query error: {error_details}")
         return None, f"❌ Error during query: {str(e)}", None, None
@@ -390,14 +438,14 @@ with gr.Blocks(title="NetraEmbed Demo") as demo:
         <a href="https://github.com/adithya-s-k/colpali" target="_blank">
             <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
         </a>
-        <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
+        <a href="https://huggingface.co/Cognitive-Lab/NetraEmbed" target="_blank">
             <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
         </a>
         <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
             <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
         </a>
-        <a href="https://cloud.cognitivelab.in" target="_blank">
-            <img src="https://img.shields.io/badge/Demo-Try%20it%20out-green" alt="Demo">
+        <a href="https://huggingface.co/spaces/AdithyaSK/NetraEmbed" target="_blank">
+            <img src="https://img.shields.io/badge/🤗%20Demo-HuggingFace%20Space-yellow" alt="Demo">
         </a>
     </div>
     """
@@ -443,9 +491,7 @@ with gr.Blocks(title="NetraEmbed Demo") as demo:
                 )
 
                 pdf_upload = gr.File(
-                    label="Upload PDFs",
-                    file_types=[".pdf"],
-                    file_count="multiple"
+                    label="Upload PDFs", file_types=[".pdf"], file_count="multiple"
                 )
                 index_btn = gr.Button("📥 Index Documents", variant="primary", size="sm")
 
@@ -531,7 +577,12 @@ with gr.Blocks(title="NetraEmbed Demo") as demo:
     query_btn.click(
        fn=query_documents,
         inputs=[query_input, model_select, top_k_slider, heatmap_checkbox],
-        outputs=[bigemma_gallery, bigemma_results_text, colgemma_gallery, colgemma_results_text],
+        outputs=[
+            bigemma_gallery,
+            bigemma_results_text,
+            colgemma_gallery,
+            colgemma_results_text,
+        ],
     )
 
     # Enable queue for handling multiple requests
 
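The hunks touching index_bigemma_images and index_colgemma_images only re-wrap the batched indexing loop to Black's style; behavior is unchanged. For readers skimming the diff, here is a minimal, self-contained sketch of that pattern, where TinyEncoder is a hypothetical stand-in for the BiGemma3/ColGemma3 encoders and not part of colpali_engine:

import torch


class TinyEncoder(torch.nn.Module):
    # Hypothetical stand-in for BiGemma3/ColGemma3: maps a batch to embeddings.
    def __init__(self, in_dim: int = 16, out_dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(in_dim, out_dim)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self.proj(batch)


def index_in_batches(inputs, model, batch_size=2):
    # Same shape as the app's loop: slice into small batches, embed each
    # batch without tracking gradients, then concatenate the results.
    embeddings_list = []
    for i in range(0, len(inputs), batch_size):
        batch = inputs[i : i + batch_size]  # the slice this commit re-spaces
        with torch.no_grad():
            embeddings_list.append(model(batch))
    return torch.cat(embeddings_list, dim=0)


pages = torch.randn(5, 16)  # stand-in for processed page images
print(index_in_batches(pages, TinyEncoder()).shape)  # torch.Size([5, 8])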
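The generate_colgemma_heatmap hunks likewise only split a normalize, transpose, and resize chain across lines. A rough, runnable sketch of that chain under stated assumptions (a random 16x16 array stands in for the colpali_engine similarity map, and min-max scaling is an assumption about what normalize_similarity_map does):

import numpy as np
from PIL import Image

# Stand-in patch-level similarity map; the app gets this from
# get_similarity_maps_from_embeddings.
sim = np.random.rand(16, 16).astype("float32")

# Scale to [0, 1] (assumed equivalent of normalize_similarity_map).
sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)

# Swap height and width, as rearrange(..., "h w -> w h") does in the app.
sim = sim.T

# Upsample the patch grid to page size with bicubic resampling, matching
# the Image.fromarray(...).resize(...) call in the diff.
overlay = Image.fromarray((sim * 255).astype("uint8")).resize(
    (640, 640), Image.Resampling.BICUBIC
)
print(overlay.size, overlay.mode)  # (640, 640) L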