DINGOLANI committed on
Commit
85a27c5
·
verified ·
1 Parent(s): 8c1ee79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -8,14 +8,14 @@ model_name = "patrickjohncyh/fashion-clip"
8
  model = CLIPModel.from_pretrained(model_name)
9
  processor = CLIPProcessor.from_pretrained(model_name)
10
 
11
- # Price extraction regex
12
  price_pattern = re.compile(r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?', re.IGNORECASE)
13
 
14
- def get_text_embedding(text):
15
  """
16
- Converts input text into an embedding using FashionCLIP.
17
  """
18
- inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
19
  with torch.no_grad():
20
  text_embedding = model.get_text_features(**inputs)
21
  return text_embedding
@@ -27,18 +27,19 @@ def extract_attributes(query):
27
  structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
28
 
29
  # Get embedding for the query
30
- query_embedding = get_text_embedding(query)
31
 
32
- # Compare with embeddings of common fashion attribute words (using FashionCLIP)
33
  reference_labels = ["Brand", "Category", "Gender", "Price"]
34
  reference_embeddings = get_text_embedding(reference_labels)
35
 
36
- # Compute cosine similarity to classify the type of query
37
  similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
38
  best_match_index = similarities.argmax().item()
39
 
40
- # Assign type dynamically
41
  attribute_type = reference_labels[best_match_index]
 
42
 
43
  # Extract price dynamically
44
  price_match = price_pattern.search(query)
@@ -46,9 +47,6 @@ def extract_attributes(query):
46
  condition, amount, currency = price_match.groups()
47
  structured_output["Price"] = f"{condition.capitalize() if condition else ''} {amount} {currency if currency else 'AED'}".strip()
48
 
49
- # Extract brand & category dynamically using FashionCLIP similarity
50
- structured_output[attribute_type] = query # Assigning full query text to matched attribute
51
-
52
  return structured_output
53
 
54
  # Define Gradio UI
 
8
  model = CLIPModel.from_pretrained(model_name)
9
  processor = CLIPProcessor.from_pretrained(model_name)
10
 
11
+ # Regex for price extraction
12
  price_pattern = re.compile(r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?', re.IGNORECASE)
13
 
14
def get_text_embedding(text_list):
    """
    Convert a list of input texts into embeddings using FashionCLIP.

    Args:
        text_list: list of strings to embed (a single query should be
            passed as a one-element list).

    Returns:
        torch.Tensor of shape (len(text_list), embedding_dim), one
        embedding row per input text.
    """
    # truncation=True guards against inputs longer than CLIP's token
    # context window, which would otherwise raise during tokenization.
    inputs = processor(text=text_list, return_tensors="pt",
                       padding=True, truncation=True)
    with torch.no_grad():  # inference only — no gradients needed
        text_embedding = model.get_text_features(**inputs)
    return text_embedding
 
27
  structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
28
 
29
  # Get embedding for the query
30
+ query_embedding = get_text_embedding([query])
31
 
32
+ # Reference labels for classification
33
  reference_labels = ["Brand", "Category", "Gender", "Price"]
34
  reference_embeddings = get_text_embedding(reference_labels)
35
 
36
+ # Compute cosine similarity
37
  similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
38
  best_match_index = similarities.argmax().item()
39
 
40
+ # Assign attribute dynamically
41
  attribute_type = reference_labels[best_match_index]
42
+ structured_output[attribute_type] = query # Assigns the query text to the detected attribute
43
 
44
  # Extract price dynamically
45
  price_match = price_pattern.search(query)
 
47
  condition, amount, currency = price_match.groups()
48
  structured_output["Price"] = f"{condition.capitalize() if condition else ''} {amount} {currency if currency else 'AED'}".strip()
49
 
 
 
 
50
  return structured_output
51
 
52
  # Define Gradio UI