Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from visual_bge.modeling import Visualized_BGE | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.models import Filter, FieldCondition, MatchValue | |
from PIL import Image | |
from io import BytesIO | |
import requests | |
import matplotlib.pyplot as plt | |
from matplotlib import font_manager | |
import textwrap | |
import os | |
import tempfile | |
from huggingface_hub import hf_hub_download | |
# Fetch the Visualized-BGE checkpoint from the Hugging Face Hub.
model_weight = hf_hub_download(
    repo_id="BAAI/bge-visualized",
    filename="Visualized_m3.pth",
)

# Thai-capable font used for matplotlib result titles.
thai_font = font_manager.FontProperties(fname='./Sarabun-Regular.ttf')

# Multimodal embedding model: maps text and/or images into one vector space.
model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight=model_weight)

# Qdrant connection; URL and API key come from the environment.
qdrant_client = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)
# Visual helper function | |
def visualize_results(results):
    """Render search hits as an image grid and return the PNG file path.

    Each hit's payload must provide 'image_url' (a local path or an HTTP
    URL) and 'name'; the Qdrant similarity score is printed under the
    title. A hit that fails to load renders an error cell instead of
    aborting the whole figure.
    """
    cols = 4
    rows = (len(results) + cols - 1) // cols  # ceil division
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
    # plt.subplots returns a bare Axes (not an array) when rows*cols == 1.
    axs = axs.flatten() if hasattr(axs, 'flatten') else [axs]
    for i, res in enumerate(results):
        try:
            image_url = res.payload['image_url']
            if os.path.exists(image_url):
                img = Image.open(image_url)
            else:
                # Remote image: bounded timeout so one dead URL cannot
                # hang the whole request; non-2xx responses raise and are
                # shown as an error cell below.
                resp = requests.get(image_url, timeout=10)
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content))
            name = res.payload['name']
            if len(name) > 30:
                name = name[:27] + "..."
            wrapped_name = textwrap.fill(name, width=15)
            axs[i].imshow(img)
            axs[i].set_title(f"{wrapped_name}\nScore: {res.score:.2f}", fontproperties=thai_font, fontsize=10)
            axs[i].axis('off')
        except Exception as e:
            # Best-effort: surface the failure in-place rather than
            # failing the entire results grid.
            axs[i].text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center', fontsize=8)
            axs[i].axis('off')
    # Blank out unused trailing cells of the grid.
    for j in range(len(results), len(axs)):
        axs[j].axis('off')
    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(hspace=0.5)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        fig.savefig(tmpfile.name)
    # Release the figure: the Gradio process is long-lived, and unclosed
    # figures accumulate (matplotlib keeps a reference) — a memory leak.
    plt.close(fig)
    return tmpfile.name
# Text Query Handler | |
def search_by_text(text_input):
    """Embed a text query and fetch the nearest images from Qdrant.

    Returns a (info_message, png_path) pair; png_path is None when the
    input is missing or blank.
    """
    # Gradio can deliver None for an untouched textbox; the original
    # None.strip() would raise AttributeError here.
    if not text_input or not text_input.strip():
        return "Please provide a text input.", None
    query_vector = model.encode(text=text_input)[0].tolist()
    # NOTE(review): the image handler queries "bge_visualized_m3" while
    # this one queries "bge_visualized_m3_demo" — confirm which collection
    # name is intended; they likely should match.
    results = qdrant_client.query_points(
        collection_name="bge_visualized_m3_demo",
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return f"Results for: {text_input}", image_path
# Image Query Handler | |
def search_by_image(image_input):
    """Embed an uploaded image and fetch its nearest neighbours from Qdrant.

    Returns a (message, png_path) pair; png_path is None when no image
    was supplied.
    """
    if image_input is None:
        return "Please upload an image.", None
    embedding = model.encode(image=image_input)[0].tolist()
    # NOTE(review): collection name differs from the text handler's
    # "bge_visualized_m3_demo" — verify the mismatch is intentional.
    hits = qdrant_client.query_points(
        collection_name="bge_visualized_m3",
        query=embedding,
        with_payload=True,
    ).points
    return "Results for image query", visualize_results(hits)
# Gradio UI | |
# --- Gradio UI: one tab per query modality, both funnel into Qdrant ---
with gr.Blocks() as demo:
    gr.Markdown("# π Visualized BGE: Multimodal Search with Qdrant")

    with gr.Tab("π Text Query"):
        text_input = gr.Textbox(label="Enter text to search")
        text_output = gr.Textbox(label="Query Info")
        text_image = gr.Image(label="Results", type="filepath")
        text_btn = gr.Button("Search")
        # Wire the button to the text handler: (message, image-path) out.
        text_btn.click(search_by_text, text_input, [text_output, text_image])

    with gr.Tab("πΌοΈ Image Query"):
        image_input = gr.Image(label="Upload image to search", type="pil")
        image_output = gr.Textbox(label="Query Info")
        image_result = gr.Image(label="Results", type="filepath")
        image_btn = gr.Button("Search")
        # Wire the button to the image handler: (message, image-path) out.
        image_btn.click(search_by_image, image_input, [image_output, image_result])

demo.launch()