File size: 2,564 Bytes
75fdda9
171a063
 
8c1ee79
75fdda9
8c1ee79
171a063
 
 
75fdda9
85a27c5
8c1ee79
 
85a27c5
8c1ee79
85a27c5
8c1ee79
85a27c5
8c1ee79
 
 
 
 
75fdda9
8c1ee79
75fdda9
8c1ee79
75fdda9
8c1ee79
85a27c5
171a063
85a27c5
8c1ee79
 
 
85a27c5
8c1ee79
 
171a063
85a27c5
8c1ee79
85a27c5
171a063
8c1ee79
 
 
 
 
 
 
171a063
 
8c1ee79
 
 
 
 
 
 
75fdda9
8c1ee79
 
75fdda9
171a063
 
75fdda9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import re

# FashionCLIP checkpoint identifier handed to `from_pretrained`; the processor
# and the model are both loaded from it once, when this module is executed.
model_name = "patrickjohncyh/fashion-clip"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Regex for price extraction. Capture groups, in order:
#   1. optional qualifier word: under / above / below / between
#   2. the amount: a run of 1-5 digits that is not part of a longer digit run
#   3. optional currency code: AED / USD / EUR
# The digit lookarounds stop a long number such as "1234567" from being
# reported as the price "12345"; the \b after the currency stops "EUR" from
# matching inside a longer word; \s* tolerates repeated whitespace.
# NOTE: for "between A and B" only the first amount is captured.
price_pattern = re.compile(
    r'(?:\b(under|above|below|between)\b)?\s*'
    r'(?<!\d)(\d{1,5})(?!\d)\s*'
    r'(?:(AED|USD|EUR)\b)?',
    re.IGNORECASE,
)

def get_text_embedding(text_list):
    """
    Converts a list of input texts into embeddings using FashionCLIP.

    Args:
        text_list: List of strings to embed in a single batch.

    Returns:
        The tensor produced by ``model.get_text_features`` for the batch
        (presumably one embedding row per input text; confirm against the
        Transformers documentation).
    """
    # padding=True pads the batch to a common length; truncation=True clips
    # any text that exceeds the tokenizer's maximum length instead of passing
    # an over-long sequence to the model.
    inputs = processor(
        text=text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: no gradients are needed for feature extraction.
    with torch.no_grad():
        return model.get_text_features(**inputs)

# Attribute names the query is classified against; they are also the keys of
# the dictionary returned by extract_attributes, in this order.
_REFERENCE_LABELS = ("Brand", "Category", "Gender", "Price")

# Lazily filled cache: the label texts never change, so their embeddings are
# computed on the first query and reused afterwards.
_reference_embedding_cache = {}


def _get_reference_embeddings():
    """Return the embeddings of ``_REFERENCE_LABELS``, computing them on first use."""
    if "embeddings" not in _reference_embedding_cache:
        _reference_embedding_cache["embeddings"] = get_text_embedding(list(_REFERENCE_LABELS))
    return _reference_embedding_cache["embeddings"]


def extract_attributes(query):
    """
    Extract structured fashion attributes dynamically using FashionCLIP.

    The query is embedded and compared (cosine similarity) with the embeddings
    of the labels "Brand", "Category", "Gender" and "Price"; the whole query
    text is stored under the most similar label. If the query also contains a
    price expression matched by ``price_pattern``, the "Price" entry is then
    overwritten with a normalised "<Qualifier> <amount> <CURRENCY>" string.

    Args:
        query: Free-text search query.

    Returns:
        dict with the keys "Brand", "Category", "Gender" and "Price"; entries
        that were not filled in hold "Unknown". A missing, empty or
        whitespace-only query returns all "Unknown" without running the model.
    """
    structured_output = {label: "Unknown" for label in _REFERENCE_LABELS}

    # Nothing to classify: avoid a model call and avoid storing an empty
    # string under an arbitrary attribute.
    if not query or not query.strip():
        return structured_output

    # Compare the query embedding with every label embedding. This relies on
    # cosine_similarity broadcasting the single query row against the label
    # rows along its default dim=1.
    query_embedding = get_text_embedding([query])
    similarities = torch.nn.functional.cosine_similarity(
        query_embedding, _get_reference_embeddings()
    )
    best_label = _REFERENCE_LABELS[similarities.argmax().item()]

    # The full query text is assigned to the best-matching attribute.
    structured_output[best_label] = query

    # A regex price match takes precedence over whatever was stored in "Price".
    price_match = price_pattern.search(query)
    if price_match:
        condition, amount, currency = price_match.groups()
        parts = (
            condition.capitalize() if condition else "",
            amount,
            # Default to AED when no currency is given; upper-case the code so
            # "200aed" and "200AED" produce the same output.
            (currency or "AED").upper(),
        )
        structured_output["Price"] = " ".join(part for part in parts if part)

    return structured_output

# Gradio click handler (the UI layout itself is defined further down)
def parse_query(user_query):
    """
    Run attribute extraction on the user's query text.

    Returns the dictionary produced by ``extract_attributes`` unchanged, so it
    can be shown as structured JSON.
    """
    return extract_attributes(user_query)

# Page layout: a title, the query textbox, a JSON panel for the result, and a
# button that sends the textbox content through parse_query into that panel.
with gr.Blocks() as demo:
    gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")

    query_input = gr.Textbox(
        label="Enter your search query",
        placeholder="e.g., Gucci men’s perfume under 200AED",
    )
    output_box = gr.JSON(label="Parsed Output")

    parse_button = gr.Button("Parse Query")
    parse_button.click(fn=parse_query, inputs=query_input, outputs=output_box)

# Launch the app whenever this module is executed.
demo.launch()