import os
import tempfile
import textwrap
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from huggingface_hub import hf_hub_download
from matplotlib import font_manager
from PIL import Image
from qdrant_client import QdrantClient
from visual_bge.modeling import Visualized_BGE

# Download the Visualized BGE (bge-m3) checkpoint from the Hugging Face Hub
model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")

# Load a Thai font so result titles render correctly
thai_font = font_manager.FontProperties(fname="./Sarabun-Regular.ttf")

# Load the multimodal embedding model
model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight=model_weight)
model.eval()

# Connect to Qdrant
qdrant_client = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)

# Both search handlers query the same collection
COLLECTION_NAME = "bge_visualized_m3_demo"


# Visualization helper: render search hits in a grid and save to a temp PNG
def visualize_results(results):
    cols = 4
    rows = (len(results) + cols - 1) // cols
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
    axs = axs.flatten() if hasattr(axs, "flatten") else [axs]

    for i, res in enumerate(results):
        try:
            image_url = res.payload["image_url"]
            # The payload may hold either a local path or a remote URL
            img = (
                Image.open(image_url)
                if os.path.exists(image_url)
                else Image.open(BytesIO(requests.get(image_url, timeout=10).content))
            )
            name = res.payload["name"]
            if len(name) > 30:
                name = name[:27] + "..."
            wrapped_name = textwrap.fill(name, width=15)
            axs[i].imshow(img)
            axs[i].set_title(
                f"{wrapped_name}\nScore: {res.score:.2f}",
                fontproperties=thai_font,
                fontsize=10,
            )
            axs[i].axis("off")
        except Exception as e:
            axs[i].text(0.5, 0.5, f"Error: {e}", ha="center", va="center", fontsize=8)
            axs[i].axis("off")

    # Hide any unused cells in the grid
    for j in range(len(results), len(axs)):
        axs[j].axis("off")

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(hspace=0.5)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        fig.savefig(tmpfile.name)
    plt.close(fig)  # free the figure so repeated queries don't leak memory
    return tmpfile.name


# Text query handler
def search_by_text(text_input):
    if not text_input.strip():
        return "Please provide a text input.", None
    with torch.no_grad():
        query_vector = model.encode(text=text_input)[0].tolist()
    results = qdrant_client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return f"Results for: {text_input}", image_path


# Image query handler (Visualized_BGE.encode opens the image from a file path)
def search_by_image(image_input):
    if image_input is None:
        return "Please upload an image.", None
    with torch.no_grad():
        query_vector = model.encode(image=image_input)[0].tolist()
    results = qdrant_client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return "Results for image query", image_path


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Visualized BGE: Multimodal Search with Qdrant")

    with gr.Tab("📝 Text Query"):
        text_input = gr.Textbox(label="Enter text to search")
        text_output = gr.Textbox(label="Query Info")
        text_image = gr.Image(label="Results", type="filepath")
        text_btn = gr.Button("Search")
        text_btn.click(fn=search_by_text, inputs=text_input, outputs=[text_output, text_image])

    with gr.Tab("🖼️ Image Query"):
        # type="filepath" hands the handler a path, which Visualized_BGE expects
        image_input = gr.Image(label="Upload image to search", type="filepath")
        image_output = gr.Textbox(label="Query Info")
        image_result = gr.Image(label="Results", type="filepath")
        image_btn = gr.Button("Search")
        image_btn.click(fn=search_by_image, inputs=image_input, outputs=[image_output, image_result])

demo.launch()
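
# ---------------------------------------------------------------------------
# Indexing sketch (commented out; not part of the running demo).
# The handlers above assume the Qdrant collection already exists and that each
# point carries a 1024-dim Visualized-BGE (bge-m3) vector plus a payload with
# "image_url" and "name" fields. The hypothetical build_collection() below is
# one way such a collection could be populated; the `catalog` argument and its
# "image_path" field are illustrative assumptions, not part of this app.
# ---------------------------------------------------------------------------
# from qdrant_client.models import Distance, PointStruct, VectorParams
#
# def build_collection(catalog):
#     """catalog: iterable of dicts with 'image_path', 'image_url', 'name'."""
#     qdrant_client.create_collection(
#         collection_name=COLLECTION_NAME,
#         vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
#     )
#     points = []
#     for idx, item in enumerate(catalog):
#         with torch.no_grad():
#             vector = model.encode(image=item["image_path"])[0].tolist()
#         points.append(
#             PointStruct(
#                 id=idx,
#                 vector=vector,
#                 payload={"image_url": item["image_url"], "name": item["name"]},
#             )
#         )
#     qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)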