Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,14 @@
|
|
1 |
from flask import Flask, request, render_template
|
2 |
-
from
|
3 |
import re
|
4 |
|
5 |
app = Flask(__name__)
|
6 |
|
7 |
-
#
|
8 |
-
|
9 |
-
"ner",
|
10 |
-
model="dslim/bert-base-NER", # Verified public model
|
11 |
-
aggregation_strategy="simple"
|
12 |
-
)
|
13 |
|
14 |
-
def
|
|
|
15 |
result = {
|
16 |
"Brand": None,
|
17 |
"Category": None,
|
@@ -19,37 +16,70 @@ def extract_entities(query):
|
|
19 |
"Price": None
|
20 |
}
|
21 |
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
# Add keyword-based extraction for other fields
|
31 |
-
query_lower = query.lower()
|
32 |
-
if "perfume" in query_lower or "cologne" in query_lower:
|
33 |
-
result["Category"] = "Perfume"
|
34 |
-
if "men" in query_lower:
|
35 |
-
result["Gender"] = "Men"
|
36 |
-
elif "women" in query_lower:
|
37 |
-
result["Gender"] = "Women"
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
return result
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
@app.route("/", methods=["GET", "POST"])
|
47 |
def index():
|
|
|
|
|
48 |
if request.method == "POST":
|
49 |
-
query = request.form
|
50 |
-
|
51 |
-
|
52 |
-
return render_template("index.html", result=
|
53 |
|
54 |
if __name__ == "__main__":
|
55 |
-
app.run(
|
|
|
1 |
from flask import Flask, request, render_template
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
import re
|
4 |
|
5 |
app = Flask(__name__)
|
6 |
|
7 |
+
# Initialize DeepSeek-R1 client
|
8 |
+
client = InferenceClient(model="deepseek-ai/deepseek-llm-67b-chat")
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
def parse_llm_response(response):
|
11 |
+
"""Improved parsing that handles model's raw responses"""
|
12 |
result = {
|
13 |
"Brand": None,
|
14 |
"Category": None,
|
|
|
16 |
"Price": None
|
17 |
}
|
18 |
|
19 |
+
# Enhanced pattern matching for flexible JSON extraction
|
20 |
+
patterns = {
|
21 |
+
"brand": r'"brand":\s*"([^"]*)"',
|
22 |
+
"category": r'"category":\s*"([^"]*)"',
|
23 |
+
"gender": r'"gender":\s*"([^"]*)"',
|
24 |
+
"price_range": r'"price_range":\s*"([^"]*)"'
|
25 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
for key, pattern in patterns.items():
|
28 |
+
match = re.search(pattern, response, re.IGNORECASE)
|
29 |
+
if match:
|
30 |
+
value = match.group(1).strip()
|
31 |
+
if value.lower() in ["null", "n/a", ""]:
|
32 |
+
continue
|
33 |
+
if key == "brand":
|
34 |
+
result["Brand"] = value.title()
|
35 |
+
elif key == "category":
|
36 |
+
result["Category"] = value.title()
|
37 |
+
elif key == "gender":
|
38 |
+
result["Gender"] = value.title()
|
39 |
+
elif key == "price_range":
|
40 |
+
result["Price"] = value.upper()
|
41 |
|
42 |
return result
|
43 |
|
44 |
+
def analyze_query(query):
|
45 |
+
"""Enhanced prompt for luxury brand understanding"""
|
46 |
+
prompt = f"""Analyze this fashion query and extract structured data. Follow these rules:
|
47 |
+
|
48 |
+
1. Brand: Identify the luxury fashion brand mentioned (e.g., Gucci, Prada, Balenciaga)
|
49 |
+
2. Category: Product type (perfume, bag, shoes, etc.)
|
50 |
+
3. Gender: men, women, or unisex
|
51 |
+
4. Price: Exact price range from query
|
52 |
+
|
53 |
+
Return JSON format:
|
54 |
+
|
55 |
+
{{
|
56 |
+
"brand": "<brand name>",
|
57 |
+
"category": "<product category>",
|
58 |
+
"gender": "<target gender>",
|
59 |
+
"price_range": "<price info>"
|
60 |
+
}}
|
61 |
+
|
62 |
+
Query: "{query}"
|
63 |
+
"""
|
64 |
+
|
65 |
+
response = client.text_generation(
|
66 |
+
prompt=prompt,
|
67 |
+
max_new_tokens=200,
|
68 |
+
temperature=0.01, # More deterministic output
|
69 |
+
stop_sequences=["\n\n"] # Prevent extra text
|
70 |
+
)
|
71 |
+
|
72 |
+
return parse_llm_response(response)
|
73 |
+
|
74 |
@app.route("/", methods=["GET", "POST"])
|
75 |
def index():
|
76 |
+
result = None
|
77 |
+
query = ""
|
78 |
if request.method == "POST":
|
79 |
+
query = request.form.get("query", "")
|
80 |
+
if query.strip():
|
81 |
+
result = analyze_query(query)
|
82 |
+
return render_template("index.html", result=result, query=query)
|
83 |
|
84 |
if __name__ == "__main__":
|
85 |
+
app.run(host="0.0.0.0", port=7860)
|