DINGOLANI commited on
Commit
59f66b5
·
verified ·
1 Parent(s): fb61b7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -33
app.py CHANGED
@@ -1,17 +1,14 @@
1
  from flask import Flask, request, render_template
2
- from transformers import pipeline
3
  import re
4
 
5
  app = Flask(__name__)
6
 
7
- # Use a PUBLICLY AVAILABLE model that works on free tier
8
- ner_pipeline = pipeline(
9
- "ner",
10
- model="dslim/bert-base-NER", # Verified public model
11
- aggregation_strategy="simple"
12
- )
13
 
14
- def extract_entities(query):
 
15
  result = {
16
  "Brand": None,
17
  "Category": None,
@@ -19,37 +16,70 @@ def extract_entities(query):
19
  "Price": None
20
  }
21
 
22
- # Extract entities
23
- entities = ner_pipeline(query)
24
-
25
- # Process entities
26
- for entity in entities:
27
- if entity["entity_group"] == "ORG": # Organizations are likely brands
28
- result["Brand"] = entity["word"]
29
-
30
- # Add keyword-based extraction for other fields
31
- query_lower = query.lower()
32
- if "perfume" in query_lower or "cologne" in query_lower:
33
- result["Category"] = "Perfume"
34
- if "men" in query_lower:
35
- result["Gender"] = "Men"
36
- elif "women" in query_lower:
37
- result["Gender"] = "Women"
38
 
39
- # Price extraction
40
- price_match = re.search(r"under (\d+)\s*AED", query, re.IGNORECASE)
41
- if price_match:
42
- result["Price"] = f"Under {price_match.group(1)} AED"
 
 
 
 
 
 
 
 
 
 
43
 
44
  return result
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  @app.route("/", methods=["GET", "POST"])
47
  def index():
 
 
48
  if request.method == "POST":
49
- query = request.form["query"]
50
- result = extract_entities(query)
51
- return render_template("index.html", result=result, query=query)
52
- return render_template("index.html", result=None)
53
 
54
  if __name__ == "__main__":
55
- app.run(debug=True, host="0.0.0.0", port=7860)
 
1
  from flask import Flask, request, render_template
2
+ from huggingface_hub import InferenceClient
3
  import re
4
 
5
  app = Flask(__name__)
6
 
7
+ # Initialize DeepSeek-R1 client
8
+ client = InferenceClient(model="deepseek-ai/deepseek-llm-67b-chat")
 
 
 
 
9
 
10
+ def parse_llm_response(response):
11
+ """Improved parsing that handles model's raw responses"""
12
  result = {
13
  "Brand": None,
14
  "Category": None,
 
16
  "Price": None
17
  }
18
 
19
+ # Enhanced pattern matching for flexible JSON extraction
20
+ patterns = {
21
+ "brand": r'"brand":\s*"([^"]*)"',
22
+ "category": r'"category":\s*"([^"]*)"',
23
+ "gender": r'"gender":\s*"([^"]*)"',
24
+ "price_range": r'"price_range":\s*"([^"]*)"'
25
+ }
 
 
 
 
 
 
 
 
 
26
 
27
+ for key, pattern in patterns.items():
28
+ match = re.search(pattern, response, re.IGNORECASE)
29
+ if match:
30
+ value = match.group(1).strip()
31
+ if value.lower() in ["null", "n/a", ""]:
32
+ continue
33
+ if key == "brand":
34
+ result["Brand"] = value.title()
35
+ elif key == "category":
36
+ result["Category"] = value.title()
37
+ elif key == "gender":
38
+ result["Gender"] = value.title()
39
+ elif key == "price_range":
40
+ result["Price"] = value.upper()
41
 
42
  return result
43
 
44
+ def analyze_query(query):
45
+ """Enhanced prompt for luxury brand understanding"""
46
+ prompt = f"""Analyze this fashion query and extract structured data. Follow these rules:
47
+
48
+ 1. Brand: Identify the luxury fashion brand mentioned (e.g., Gucci, Prada, Balenciaga)
49
+ 2. Category: Product type (perfume, bag, shoes, etc.)
50
+ 3. Gender: men, women, or unisex
51
+ 4. Price: Exact price range from query
52
+
53
+ Return JSON format:
54
+
55
+ {{
56
+ "brand": "<brand name>",
57
+ "category": "<product category>",
58
+ "gender": "<target gender>",
59
+ "price_range": "<price info>"
60
+ }}
61
+
62
+ Query: "{query}"
63
+ """
64
+
65
+ response = client.text_generation(
66
+ prompt=prompt,
67
+ max_new_tokens=200,
68
+ temperature=0.01, # More deterministic output
69
+ stop_sequences=["\n\n"] # Prevent extra text
70
+ )
71
+
72
+ return parse_llm_response(response)
73
+
74
  @app.route("/", methods=["GET", "POST"])
75
  def index():
76
+ result = None
77
+ query = ""
78
  if request.method == "POST":
79
+ query = request.form.get("query", "")
80
+ if query.strip():
81
+ result = analyze_query(query)
82
+ return render_template("index.html", result=result, query=query)
83
 
84
  if __name__ == "__main__":
85
+ app.run(host="0.0.0.0", port=7860)