| """ | |
| NetraEmbed Demo - Document Retrieval with BiGemma3 and ColGemma3 | |
| This demo allows you to: | |
| 1. Select a model (NetraEmbed, ColNetraEmbed, or Both) | |
| 2. Upload PDF files and index them | |
| 3. Search for relevant pages based on your query | |
| HuggingFace Spaces deployment with ZeroGPU support. | |
| """ | |
import spaces  # imported first so ZeroGPU can hook in before any CUDA use

import io
import math
from typing import List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange
from pdf2image import convert_from_path
from PIL import Image
# Import from colpali_engine
from colpali_engine.models import (
    BiGemma3,
    BiGemmaProcessor3,
    ColGemma3,
    ColGemmaProcessor3,
)
from colpali_engine.interpretability import get_similarity_maps_from_embeddings
from colpali_engine.interpretability.similarity_map_utils import (
    normalize_similarity_map,
)
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Device: {device}") | |
| if torch.cuda.is_available(): | |
| print(f"GPU: {torch.cuda.get_device_name(0)}") | |
# Global state for models and indexed documents
class DocumentIndex:
    def __init__(self):
        self.images: List[Image.Image] = []
        self.bigemma_embeddings: Optional[torch.Tensor] = None
        self.colgemma_embeddings: Optional[torch.Tensor] = None
        self.bigemma_model = None
        self.bigemma_processor = None
        self.colgemma_model = None
        self.colgemma_processor = None


doc_index = DocumentIndex()
def load_bigemma_model():
    """Load BiGemma3 model and processor (lazily, on first use)."""
    if doc_index.bigemma_model is None:
        print("Loading BiGemma3 (NetraEmbed)...")
        doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
            "Cognitive-Lab/NetraEmbed",
            use_fast=True,
        )
        doc_index.bigemma_model = BiGemma3.from_pretrained(
            "Cognitive-Lab/NetraEmbed",
            torch_dtype=torch.bfloat16,
            device_map=device,
        ).eval()
        print("✅ BiGemma3 loaded successfully")
    return doc_index.bigemma_model, doc_index.bigemma_processor
def load_colgemma_model():
    """Load ColGemma3 model and processor (lazily, on first use)."""
    if doc_index.colgemma_model is None:
        print("Loading ColGemma3 (ColNetraEmbed)...")
        doc_index.colgemma_model = ColGemma3.from_pretrained(
            "Cognitive-Lab/ColNetraEmbed",
            torch_dtype=torch.bfloat16,
            device_map=device,
        ).eval()
        doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
            "Cognitive-Lab/ColNetraEmbed",
            use_fast=True,
        )
        print("✅ ColGemma3 loaded successfully")
    return doc_index.colgemma_model, doc_index.colgemma_processor
def pdf_to_images(pdf_paths: List[str]) -> List[Image.Image]:
    """Convert PDF files to a list of PIL Images (one per page)."""
    images = []
    for pdf_path in pdf_paths:
        try:
            print(f"Converting PDF to images: {pdf_path}")
            page_images = convert_from_path(pdf_path, dpi=200)
            images.extend(page_images)
            print(f"Converted {len(page_images)} pages from {pdf_path}")
        except Exception as e:
            print(f"❌ PDF conversion error for {pdf_path}: {str(e)}")
            raise gr.Error(f"Failed to convert PDF: {str(e)}")
    if len(images) >= 150:
        raise gr.Error("The total number of pages should be less than 150.")
    return images
def index_bigemma_images(images: List[Image.Image]):
    """Index images with BiGemma3 (one single-vector embedding per page)."""
    model, processor = load_bigemma_model()
    print(f"Indexing {len(images)} images with BiGemma3...")
    embeddings_list = []
    # Process in small batches to avoid memory issues
    batch_size = 2
    for i in range(0, len(images), batch_size):
        batch = images[i : i + batch_size]
        batch_images = processor.process_images(batch).to(device)
        with torch.no_grad():
            # embedding_dim=768 selects a 768-dim Matryoshka embedding
            embeddings = model(**batch_images, embedding_dim=768)
        embeddings_list.append(embeddings.cpu())
    # Concatenate all page embeddings into one (n_pages, 768) tensor
    all_embeddings = torch.cat(embeddings_list, dim=0)
    print(f"✅ Indexed {len(images)} pages with BiGemma3 (shape: {all_embeddings.shape})")
    return all_embeddings

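# The Matryoshka property mentioned in the UI text means a full-width page
# embedding can be truncated to a prefix and re-normalized while remaining a
# usable embedding. A minimal sketch of the idea (an illustration of the
# concept, not the library's actual `embedding_dim` implementation):
def _truncate_matryoshka(embeddings: torch.Tensor, dim: int) -> torch.Tensor:
    """Keep the first `dim` components of each vector and L2-re-normalize."""
    truncated = embeddings[..., :dim]
    return truncated / truncated.norm(dim=-1, keepdim=True).clamp_min(1e-12)
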
def index_colgemma_images(images: List[Image.Image]):
    """Index images with ColGemma3 (one multi-vector embedding per page)."""
    model, processor = load_colgemma_model()
    print(f"Indexing {len(images)} images with ColGemma3...")
    embeddings_list = []
    # Process in small batches to avoid memory issues
    batch_size = 2
    for i in range(0, len(images), batch_size):
        batch = images[i : i + batch_size]
        batch_images = processor.process_images(batch).to(device)
        with torch.no_grad():
            embeddings = model(**batch_images)
        embeddings_list.append(embeddings.cpu())
    # Concatenate all page embeddings
    all_embeddings = torch.cat(embeddings_list, dim=0)
    print(f"✅ Indexed {len(images)} pages with ColGemma3 (shape: {all_embeddings.shape})")
    return all_embeddings

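# ColGemma3 pages are scored by late interaction ("MaxSim"): each query token
# is matched against its best page token and the maxima are summed. A reference
# sketch of the scoring rule (an assumption about what
# processor.score_multi_vector computes, shown only for transparency):
def _maxsim_reference(query_tokens: torch.Tensor, page_tokens: torch.Tensor) -> torch.Tensor:
    """query_tokens: (n_q, d), page_tokens: (n_p, d) -> scalar MaxSim score."""
    sim = query_tokens @ page_tokens.T  # (n_q, n_p) token-to-token similarities
    return sim.max(dim=1).values.sum()  # best page token per query token, summed
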
# On ZeroGPU Spaces, GPU work must run inside a function decorated with
# @spaces.GPU; `duration` is the requested allocation window in seconds.
@spaces.GPU(duration=120)
def index_document(pdf_files, model_choice: str) -> str:
    """Upload and index PDF documents."""
    if not pdf_files:
        return "⚠️ Please upload PDF documents first."
    if not model_choice:
        return "⚠️ Please select a model first."
    try:
        status_messages = []
        # Convert PDFs to images
        status_messages.append("⏳ Converting PDFs to images...")
        pdf_paths = [f.name for f in pdf_files]
        doc_index.images = pdf_to_images(pdf_paths)
        num_pages = len(doc_index.images)
        status_messages.append(f"✅ Converted to {num_pages} images")
        # Index with BiGemma3
        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            status_messages.append("⏳ Indexing with BiGemma3...")
            doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
            status_messages.append("✅ Indexed with BiGemma3")
        # Index with ColGemma3
        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            status_messages.append("⏳ Indexing with ColGemma3...")
            doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
            status_messages.append("✅ Indexed with ColGemma3")
        return "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
    except Exception as e:
        import traceback

        print(f"Indexing error: {traceback.format_exc()}")
        return f"❌ Error indexing document: {str(e)}"
def generate_colgemma_heatmap(
    image: Image.Image,
    query_embedding: torch.Tensor,
    image_embedding: torch.Tensor,
) -> Image.Image:
    """Generate a heatmap overlay for ColGemma3 results."""
    try:
        model, processor = load_colgemma_model()
        # Re-process the single image
        batch_images = processor.process_images([image]).to(device)
        # Build a mask selecting only the image tokens in the sequence
        if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
            image_token_id = model.config.image_token_id
            image_mask = batch_images["input_ids"] == image_token_id
        else:
            image_mask = torch.ones(
                image_embedding.shape[0],
                image_embedding.shape[1],
                dtype=torch.bool,
                device=device,
            )
        # Infer the patch grid; fall back to 16x16 if the count is not square
        num_image_tokens = image_mask.sum().item()
        n_side = int(math.sqrt(num_image_tokens))
        n_patches = (
            (n_side, n_side) if n_side * n_side == num_image_tokens else (16, 16)
        )
        # Generate one similarity map per query token
        similarity_maps_list = get_similarity_maps_from_embeddings(
            image_embeddings=image_embedding.unsqueeze(0).to(device),
            query_embeddings=query_embedding.to(device),
            n_patches=n_patches,
            image_mask=image_mask,
        )
        similarity_map = similarity_maps_list[0]
        if similarity_map.dtype == torch.bfloat16:
            similarity_map = similarity_map.float()
        # Average the per-query-token maps into a single map
        aggregated_map = torch.mean(similarity_map, dim=0)
        # Create heatmap overlay
        img_array = np.array(image.convert("RGBA"))
        similarity_map_array = (
            normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
        )
        similarity_map_array = rearrange(similarity_map_array, "h w -> w h")
        similarity_map_image = Image.fromarray(
            (similarity_map_array * 255).astype("uint8")
        ).resize(image.size, Image.Resampling.BICUBIC)
        # Render the overlay with matplotlib
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(img_array)
        ax.imshow(
            similarity_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        ax.set_axis_off()
        plt.tight_layout()
        # Convert the rendered figure back to a PIL Image
        buffer = io.BytesIO()
        plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight", pad_inches=0)
        buffer.seek(0)
        heatmap_image = Image.open(buffer).copy()
        plt.close(fig)
        return heatmap_image
    except Exception as e:
        print(f"❌ Heatmap generation error: {str(e)}")
        return image

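# In query_documents below, BiGemma3 pages are ranked with processor.score on
# single-vector embeddings; for unit-normalized vectors this is expected to
# reduce to plain cosine similarity. A reference sketch (an assumption about
# the API, shown only to make the scoring transparent):
def _cosine_reference(query_vec: torch.Tensor, page_vecs: torch.Tensor) -> torch.Tensor:
    """query_vec: (d,), page_vecs: (n_pages, d) -> (n_pages,) cosine scores."""
    q = query_vec / query_vec.norm().clamp_min(1e-12)
    p = page_vecs / page_vecs.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return p @ q
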
@spaces.GPU
def query_documents(
    query: str, model_choice: str, top_k: int, show_heatmap: bool = False
) -> Tuple[Optional[List], Optional[str], Optional[List], Optional[str]]:
    """Query the indexed documents."""
    if not doc_index.images:
        return None, "⚠️ Please upload and index a document first.", None, None
    if not query.strip():
        return None, "⚠️ Please enter a query.", None, None
    try:
        bigemma_results = []
        bigemma_text = ""
        colgemma_results = []
        colgemma_text = ""
        # Query with BiGemma3
        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            if doc_index.bigemma_embeddings is None:
                return (
                    None,
                    "⚠️ Please index the document with BiGemma3 first.",
                    None,
                    None,
                )
            model, processor = load_bigemma_model()
            # Encode the query as a single 768-dim vector
            batch_query = processor.process_texts([query]).to(device)
            with torch.no_grad():
                query_embedding = model(**batch_query, embedding_dim=768)
            # Compute similarity scores against all indexed pages
            scores = processor.score(
                qs=[query_embedding[0].cpu()],
                ps=list(torch.unbind(doc_index.bigemma_embeddings)),
                device=device,
            )
            # Get top-k results
            top_k_actual = min(top_k, len(doc_index.images))
            top_indices = scores[0].argsort(descending=True)[:top_k_actual]
            # Format results
            bigemma_text = "### BiGemma3 (NetraEmbed) Results\n\n"
            for rank, idx in enumerate(top_indices):
                score = scores[0, idx].item()
                bigemma_text += (
                    f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
                )
                bigemma_results.append(
                    (
                        doc_index.images[idx.item()],
                        f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})",
                    )
                )
        # Query with ColGemma3
        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_embeddings is None:
                return (
                    bigemma_results if bigemma_results else None,
                    bigemma_text
                    if bigemma_text
                    else "⚠️ Please index the document with ColGemma3 first.",
                    None,
                    None,
                )
            model, processor = load_colgemma_model()
            # Encode the query as per-token embeddings
            batch_query = processor.process_queries([query]).to(device)
            with torch.no_grad():
                query_embedding = model(**batch_query)
            # Compute MaxSim (late-interaction) scores against all indexed pages
            scores = processor.score_multi_vector(
                qs=[query_embedding[0].cpu()],
                ps=list(torch.unbind(doc_index.colgemma_embeddings)),
                device=device,
            )
            # Get top-k results
            top_k_actual = min(top_k, len(doc_index.images))
            top_indices = scores[0].argsort(descending=True)[:top_k_actual]
            # Format results
            colgemma_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
            for rank, idx in enumerate(top_indices):
                score = scores[0, idx].item()
                colgemma_text += (
                    f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
                )
                # Generate a heatmap overlay if requested
                if show_heatmap:
                    heatmap_image = generate_colgemma_heatmap(
                        image=doc_index.images[idx.item()],
                        query_embedding=query_embedding,
                        image_embedding=doc_index.colgemma_embeddings[idx.item()],
                    )
                    colgemma_results.append(
                        (
                            heatmap_image,
                            f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
                        )
                    )
                else:
                    colgemma_results.append(
                        (
                            doc_index.images[idx.item()],
                            f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
                        )
                    )
        # Return results based on model choice
        if model_choice == "NetraEmbed (BiGemma3)":
            return bigemma_results, bigemma_text, None, None
        elif model_choice == "ColNetraEmbed (ColGemma3)":
            return None, None, colgemma_results, colgemma_text
        else:  # Both
            return bigemma_results, bigemma_text, colgemma_results, colgemma_text
    except Exception as e:
        import traceback

        print(f"Query error: {traceback.format_exc()}")
        return None, f"❌ Error during query: {str(e)}", None, None
# Create Gradio interface
with gr.Blocks(title="NetraEmbed Demo") as demo:
    # Header section
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# NetraEmbed")
            gr.HTML(
                """
                <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
                    <a href="https://arxiv.org/abs/2512.03514" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
                    </a>
                    <a href="https://github.com/adithya-s-k/colpali" target="_blank">
                        <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
                    </a>
                    <a href="https://huggingface.co/Cognitive-Lab/NetraEmbed" target="_blank">
                        <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
                    </a>
                    <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
                        <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
                    </a>
                    <a href="https://huggingface.co/spaces/AdithyaSK/NetraEmbed" target="_blank">
                        <img src="https://img.shields.io/badge/🤗%20Demo-HuggingFace%20Space-yellow" alt="Demo">
                    </a>
                </div>
                """
            )
            gr.Markdown(
                """
                **🌐 Universal Multilingual Multimodal Document Retrieval**

                Upload a PDF document, select your model(s), and query using semantic search.

                **Available Models:**
                - **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation.
                  Fast retrieval with cosine similarity.
                - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction.
                  High-quality retrieval with MaxSim scoring and attention heatmaps.
                """
            )
        with gr.Column(scale=1):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
                         alt="NetraEmbed Banner"
                         style="width: 100%; height: auto; border-radius: 8px;">
                </div>
                """
            )
    gr.Markdown("---")
    # Main interface
    with gr.Row():
        # Column 1: Model & Upload
        with gr.Column(scale=1):
            gr.Markdown("### 🤖 Select Model & Upload")
            model_select = gr.Radio(
                choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
                value="Both",
                label="Select Model(s)",
            )
            pdf_upload = gr.File(
                label="Upload PDFs", file_types=[".pdf"], file_count="multiple"
            )
            index_btn = gr.Button("📥 Index Documents", variant="primary", size="sm")
            index_status = gr.Textbox(
                label="Indexing Status",
                lines=8,
                interactive=False,
                value="Select model and upload PDFs to start",
            )
| # Column 2: Query & Results | |
| with gr.Column(scale=2): | |
| gr.Markdown("### π Query Documents") | |
| query_input = gr.Textbox( | |
| label="Enter Query", | |
| placeholder="e.g., financial report, organizational structure...", | |
| lines=2, | |
| ) | |
| with gr.Row(): | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=5, | |
| step=1, | |
| label="Top K Results", | |
| scale=2, | |
| ) | |
| heatmap_checkbox = gr.Checkbox( | |
| label="Show Heatmaps (ColGemma3)", | |
| value=False, | |
| scale=1, | |
| ) | |
| query_btn = gr.Button("π Search", variant="primary", size="sm") | |
| gr.Markdown("---") | |
| gr.Markdown("### π Results") | |
    # Results section
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            bigemma_results_text = gr.Markdown(
                value="*BiGemma3 results will appear here...*",
            )
            bigemma_gallery = gr.Gallery(
                label="BiGemma3 - Top Retrieved Pages",
                show_label=True,
                columns=2,
                height="auto",
                object_fit="contain",
            )
        with gr.Column(scale=1):
            colgemma_results_text = gr.Markdown(
                value="*ColGemma3 results will appear here...*",
            )
            colgemma_gallery = gr.Gallery(
                label="ColGemma3 - Top Retrieved Pages",
                show_label=True,
                columns=2,
                height="auto",
                object_fit="contain",
            )
    # Tips
    with gr.Accordion("💡 Tips", open=False):
        gr.Markdown(
            """
            - **Both models**: Compare results side-by-side
            - **Scores**: BiGemma3 uses cosine similarity (-1 to 1), ColGemma3 uses MaxSim (higher is better)
            - **Heatmaps**: Enable to visualize ColGemma3 attention patterns (brighter = higher attention)
            - **Refresh**: If you change documents, refresh the page to clear the index
            """
        )
    # Event handlers
    index_btn.click(
        fn=index_document,
        inputs=[pdf_upload, model_select],
        outputs=[index_status],
    )
    query_btn.click(
        fn=query_documents,
        inputs=[query_input, model_select, top_k_slider, heatmap_checkbox],
        outputs=[
            bigemma_gallery,
            bigemma_results_text,
            colgemma_gallery,
            colgemma_results_text,
        ],
    )
# Enable the request queue so multiple concurrent requests are handled
demo.queue(max_size=20)

if __name__ == "__main__":
    demo.launch(debug=True)