Spaces:
Runtime error
Runtime error
File size: 2,564 Bytes
75fdda9 171a063 8c1ee79 75fdda9 8c1ee79 171a063 75fdda9 85a27c5 8c1ee79 85a27c5 8c1ee79 85a27c5 8c1ee79 85a27c5 8c1ee79 75fdda9 8c1ee79 75fdda9 8c1ee79 75fdda9 8c1ee79 85a27c5 171a063 85a27c5 8c1ee79 85a27c5 8c1ee79 171a063 85a27c5 8c1ee79 85a27c5 171a063 8c1ee79 171a063 8c1ee79 75fdda9 8c1ee79 75fdda9 171a063 75fdda9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import re
# FashionCLIP checkpoint: the model and its processor are loaded once at
# import time and shared by every request.
model_name = "patrickjohncyh/fashion-clip"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Price expression matcher (case-insensitive). Capture groups, in order:
#   1. optional qualifier word (under / above / below / between)
#   2. the amount, 1 to 5 digits
#   3. optional currency code (AED / USD / EUR)
price_pattern = re.compile(
    r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?'
    r'\s?(\d{1,5})'
    r'\s?(AED|USD|EUR)?',
    flags=re.IGNORECASE,
)
def get_text_embedding(text_list):
    """
    Convert a list of input texts into embeddings using FashionCLIP.

    Args:
        text_list: List of strings to embed.

    Returns:
        The tensor produced by ``model.get_text_features`` for the whole
        batch, computed without gradient tracking.
    """
    # CLIP's text encoder has a fixed number of position embeddings. Without
    # truncation, a query that tokenizes to more tokens than that fails
    # inside the model instead of being embedded, so cap the length at the
    # limit declared by the loaded model's own text config.
    inputs = processor(
        text=text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model.config.text_config.max_position_embeddings,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model.get_text_features(**inputs)
def extract_attributes(query):
    """
    Extract structured fashion attributes from a free-text query.

    The query is embedded with FashionCLIP and compared (cosine similarity)
    against the embeddings of the attribute names; the whole query text is
    stored under the closest attribute. A price found by ``price_pattern``
    then overwrites the "Price" entry.

    Args:
        query: Free-text search query. Empty, whitespace-only or ``None``
            queries are returned as all-"Unknown" without running the model.

    Returns:
        dict with the keys "Brand", "Category", "Gender" and "Price";
        entries that were not filled in keep the value "Unknown".
    """
    structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}

    # Nothing to classify: avoid filing an empty string under whichever
    # attribute happens to score highest.
    if not query or not query.strip():
        return structured_output

    # Embed the query and the attribute names (the dict keys, in order).
    query_embedding = get_text_embedding([query])
    reference_labels = list(structured_output)
    reference_embeddings = get_text_embedding(reference_labels)

    # The single query row is broadcast against one row per attribute name,
    # giving one similarity score per attribute.
    similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
    best_match_index = similarities.argmax().item()

    # The full query text is stored under the best-matching attribute.
    attribute_type = reference_labels[best_match_index]
    structured_output[attribute_type] = query

    # Price extraction. The pattern matches any short digit run, so a query
    # such as "Air Max 90 under 500 AED" yields several candidates. Prefer
    # the first one that carries a qualifier or a currency; fall back to the
    # first number otherwise.
    price_matches = list(price_pattern.finditer(query))
    if price_matches:
        price_match = next(
            (m for m in price_matches if m.group(1) or m.group(3)),
            price_matches[0],
        )
        condition, amount, currency = price_match.groups()
        parts = [
            condition.capitalize() if condition else "",
            amount,
            # The pattern is case-insensitive; normalise so "aed" and "AED"
            # produce the same output as the default.
            currency.upper() if currency else "AED",
        ]
        structured_output["Price"] = " ".join(parts).strip()

    return structured_output
# Gradio callback
def parse_query(user_query):
    """
    Handle a "Parse Query" click: run attribute extraction on the text the
    user entered and hand the resulting dictionary back to Gradio.
    """
    return extract_attributes(user_query)
# Define Gradio UI: one text input, a button, and a JSON view of the result.
with gr.Blocks() as demo:
    gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")
    query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., Gucci men’s perfume under 200AED")
    output_box = gr.JSON(label="Parsed Output")
    parse_button = gr.Button("Parse Query")
    # Clicking the button sends the textbox content to parse_query and
    # renders the returned dictionary in the JSON component.
    parse_button.click(parse_query, inputs=[query_input], outputs=[output_box])

# Start the web app.
demo.launch()