# Visualized_m3 / app.py
import gradio as gr
import torch
from visual_bge.modeling import Visualized_BGE
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from PIL import Image
from io import BytesIO
import requests
import matplotlib.pyplot as plt
from matplotlib import font_manager
import textwrap
import os
import tempfile
from huggingface_hub import hf_hub_download
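
# Fetch the Visualized-M3 projection weights from the Hugging Face Hub;
# hf_hub_download caches the checkpoint locally, so restarts reuse it.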
model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")

# Load Thai font for rendering product names in the result plots
thai_font = font_manager.FontProperties(fname='./Sarabun-Regular.ttf')

# Load the multimodal encoder: the BGE-M3 text backbone plus the visual projection weights
model = Visualized_BGE(
    model_name_bge="BAAI/bge-m3",
    model_weight=model_weight
)
model.eval()  # inference only: disable dropout and other training-mode behavior

# Connect to the Qdrant vector database (credentials are read from environment secrets)
qdrant_client = QdrantClient(
    url=os.environ.get("QDRANT_URL"),
    api_key=os.environ.get("QDRANT_API_KEY")
)
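
# Both search handlers below assume each stored point's payload carries
# `image_url` and `name` fields (these are read in visualize_results).
# Note that the text and image tabs query two separately named collections.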

# Visual helper: plot the retrieved items in a grid and save the figure to a temp PNG
def visualize_results(results):
    cols = 4
    rows = (len(results) + cols - 1) // cols
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
    # subplots returns a single Axes when rows * cols == 1, so normalize to a flat list
    axs = axs.flatten() if hasattr(axs, 'flatten') else [axs]
    for i, res in enumerate(results):
        try:
            image_url = res.payload['image_url']
            # The payload may hold either a local file path or a remote URL
            if os.path.exists(image_url):
                img = Image.open(image_url)
            else:
                img = Image.open(BytesIO(requests.get(image_url, timeout=10).content))
            # Truncate and wrap long product names so titles fit above each cell
            name = res.payload['name']
            if len(name) > 30:
                name = name[:27] + "..."
            wrapped_name = textwrap.fill(name, width=15)
            axs[i].imshow(img)
            axs[i].set_title(f"{wrapped_name}\nScore: {res.score:.2f}", fontproperties=thai_font, fontsize=10)
            axs[i].axis('off')
        except Exception as e:
            axs[i].text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center', fontsize=8)
            axs[i].axis('off')
    # Hide any unused cells at the end of the grid
    for j in range(len(results), len(axs)):
        axs[j].axis('off')
    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(hspace=0.5)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        fig.savefig(tmpfile.name)
    plt.close(fig)  # release the figure so repeated queries don't leak memory
    return tmpfile.name

# Text Query Handler
def search_by_text(text_input):
    if not text_input.strip():
        return "Please provide a text input.", None
    # Embed the query text; encode returns a (1, dim) tensor, so take row 0
    with torch.no_grad():
        query_vector = model.encode(text=text_input)[0].tolist()
    results = qdrant_client.query_points(
        collection_name="bge_visualized_m3_demo",
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return f"Results for: {text_input}", image_path

# Image Query Handler
def search_by_image(image_input):
    if image_input is None:
        return "Please upload an image.", None
    # Embed the uploaded image with the same encoder used for text queries
    with torch.no_grad():
        query_vector = model.encode(image=image_input)[0].tolist()
    results = qdrant_client.query_points(
        collection_name="bge_visualized_m3",
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return "Results for image query", image_path

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Visualized BGE: Multimodal Search with Qdrant")
    with gr.Tab("📝 Text Query"):
        text_input = gr.Textbox(label="Enter text to search")
        text_output = gr.Textbox(label="Query Info")
        text_image = gr.Image(label="Results", type="filepath")
        text_btn = gr.Button("Search")
        text_btn.click(fn=search_by_text, inputs=text_input, outputs=[text_output, text_image])
    with gr.Tab("🖼️ Image Query"):
        image_input = gr.Image(label="Upload image to search", type="pil")
        image_output = gr.Textbox(label="Query Info")
        image_result = gr.Image(label="Results", type="filepath")
        image_btn = gr.Button("Search")
        image_btn.click(fn=search_by_image, inputs=image_input, outputs=[image_output, image_result])

demo.launch()