jrkotun committed on
Commit
96a5a78
·
verified ·
1 Parent(s): 0d560c7

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ src/fragrance_faiss.index filter=lfs diff=lfs merge=lfs -text
src/fragrance_faiss.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a235f0b4596acedbea9331f0fdc4f7c354e2cbbf3631adca4f5cf83b8778b988
3
+ size 73921581
src/fragrance_metadata.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a3dbbdd17c47cd65492a0ff104e4be008879e7862dfa5d26b7f896fa9831a4d
3
+ size 5195729
src/streamlit_app.py CHANGED
@@ -1,40 +1,179 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import faiss
4
+ import numpy as np
5
+ import pandas as pd
6
+ from sentence_transformers import SentenceTransformer
7
+ import pickle
8
+ from langchain_ollama import ChatOllama
9
+
10
+ # Fragrance card function
11
def create_fragrance_card(name, rating, brand, perfumer_text, top_notes, middle_notes, base_notes, accords_text, explanation):
    """Build the HTML snippet for a single fragrance recommendation card.

    Every argument is converted with ``str`` and HTML-escaped before it is
    placed in the markup. The card is rendered by the caller with
    ``unsafe_allow_html=True``, and the values come from dataset fields and
    model-generated text, so unescaped ``<`` / ``&`` characters would
    otherwise be interpreted as markup.

    Args:
        name: Fragrance name shown in the card heading.
        rating: Rating shown next to the name (any value; rendered via ``str``).
        brand: Brand name.
        perfumer_text: Perfumer(s) text.
        top_notes: Top notes text.
        middle_notes: Heart (middle) notes text.
        base_notes: Base notes text.
        accords_text: Main accords text.
        explanation: Explanation text for why the fragrance was recommended.

    Returns:
        str: The card as an HTML string containing a single ``<div>`` element.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    def esc(value):
        # Non-string values (e.g. a float rating) must be stringified first.
        return escape(str(value))

    style = (
        "border: 1px solid #ddd; padding: 15px; margin: 10px; border-radius: 15px; "
        "background: linear-gradient(to bottom right, #ffffff, #f2f6fc); "
        "width: 400px; color: #222; box-shadow: 0 4px 8px rgba(0,0,0,0.1);"
    )

    # Lines are kept flush-left: Markdown renderers may treat lines indented
    # by four or more spaces as a code block instead of HTML.
    card_lines = [
        f'<div style="{style}">',
        f'<h3 style="color: #3a3a3a; text-align: center;">{esc(name)} ⭐{esc(rating)}</h3>',
        f'<p><strong>🏷️ Brand:</strong> {esc(brand)}</p>',
        f'<p><strong>👃 Perfumer(s):</strong> {esc(perfumer_text)}</p>',
        f'<p><strong>🌿 Top Notes:</strong> {esc(top_notes)}</p>',
        f'<p><strong>💖 Heart Notes:</strong> {esc(middle_notes)}</p>',
        f'<p><strong>🌲 Base Notes:</strong> {esc(base_notes)}</p>',
        f'<p><strong>🎼 Main Accords:</strong> {esc(accords_text)}</p>',
        f'<p><strong>💡 AI Explanation:</strong> {esc(explanation)}</p>',
        '</div>',
    ]
    card_html = "\n" + "\n".join(card_lines) + "\n"

    return card_html
29
+
30
+ # Load FAISS database, metadata, and encoder with cache
31
@st.cache_resource
def load_resources():
    """Load the FAISS index, the pickled metadata and the sentence encoder.

    The two data files are looked up next to this script first, so the app
    works no matter which working directory it is launched from. If a file
    is not found there, the original bare relative path is used as a
    fallback.

    Returns:
        tuple: ``(index, metadata, encoder)`` where ``index`` is the object
        returned by ``faiss.read_index``, ``metadata`` is whatever was
        pickled in ``fragrance_metadata.pkl`` (the rest of the app uses it
        as a pandas DataFrame), and ``encoder`` is a ``SentenceTransformer``
        for the ``paraphrase-mpnet-base-v2`` model.
    """
    # Function-scope import keeps this block self-contained.
    from pathlib import Path

    script_path = globals().get("__file__")
    base_dir = Path(script_path).resolve().parent if script_path else Path.cwd()

    def locate(filename):
        # Prefer the copy beside the script; otherwise keep the old behaviour.
        candidate = base_dir / filename
        return str(candidate) if candidate.exists() else filename

    index = faiss.read_index(locate('fragrance_faiss.index'))

    # NOTE(security): pickle.load can execute arbitrary code from the file.
    # Only ship a metadata file produced by a trusted build step.
    with open(locate('fragrance_metadata.pkl'), 'rb') as f:
        metadata = pickle.load(f)

    encoder = SentenceTransformer('paraphrase-mpnet-base-v2')
    return index, metadata, encoder
38
+
39
+ # Gets a brief explanation from Ollama for why this fragrance matches the user's query
40
def get_ollama_explanation(query, description):
    """Ask the module-level ``llm`` why a fragrance matches the user's query.

    Args:
        query: The user's free-text description of the fragrance they want.
        description: Natural-language description of one recommended fragrance.

    Returns:
        str: The model's reply with surrounding whitespace stripped, or a
        short fallback sentence when the model call fails, so that one
        failed call does not prevent the remaining cards from rendering.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    # Built line by line so the prompt carries no source-indentation whitespace.
    prompt = (
        f'A user is searching for a fragrance with this description: "{query}"\n'
        "\n"
        "One recommendation is:\n"
        f"{description}\n"
        "\n"
        "Explain in 1-2 sentences, in plain English, why this fragrance matches the user's query.\n"
    )

    try:
        response = llm.invoke(prompt)
    except Exception:  # UI boundary: log the failure and degrade gracefully
        logging.getLogger(__name__).exception("LLM explanation request failed")
        return "Explanation unavailable (the language model could not be reached)."

    content = response.content
    # NOTE(review): message content is assumed to normally be a str; anything
    # else is stringified rather than failing on .strip() -- confirm.
    if not isinstance(content, str):
        content = str(content)
    return content.strip()
51
+
52
+ # Load Ollama
53
# Chat model used to generate the per-card explanations
llm = ChatOllama(model="llama3.2")

# Page configuration (wide layout) followed by the main heading
st.set_page_config(page_title="Fragrance Recommendation System", layout="wide")
st.title("Fragrance Recommendation System")

# The sidebar receives its header here; the query box itself is in the main area
st.sidebar.header("Filters")
query = st.text_input("Describe your ideal fragrance:")

# Two sliders side by side: result count on the left, minimum rating on the right
col1, col2 = st.columns(2)
k = col1.slider("Number of recommendations:", 1, 10, 5)
min_rating = col2.slider("Minimum rating:", 1.0, 5.0, 3.5)

# Sidebar filters: gender choice, brand text (title-cased), notes text (lower-cased)
gender_filter = st.sidebar.selectbox("Gender:", ["All", "Male", "Female", "Unisex"])
brand_input = st.sidebar.text_input("Brand (leave empty for all):", "")
brand_filter = brand_input.title()
note_input = st.sidebar.text_input("Notes (comma-separated):", "")
note_filter = note_input.lower()

# FAISS index, metadata table and sentence encoder (cached by load_resources)
index, metadata, encoder = load_resources()

# Coerce the rating column to numbers; unparseable entries become NaN
if 'rating_value' in metadata.columns:
    numeric_ratings = pd.to_numeric(metadata['rating_value'], errors='coerce')
    metadata['rating_value'] = numeric_ratings
83
+
84
+ # Press button and start recommendations
85
# Runs one recommendation pass when the button is pressed: filter the
# metadata, search the remaining vectors for the query, render one card each.
if st.button('Get Recommendations'):
    with st.spinner('Finding your fragrance recs...'):
        # A whitespace-only query is treated the same as an empty one.
        if not query.strip():
            st.warning("No query entered.")
        else:
            # Apply filters sequentially
            current_df = metadata.copy()

            # Gender filter (rows with a missing gender compare False and drop out)
            if gender_filter != "All":
                current_df = current_df[current_df['gender'].str.lower() == gender_filter.lower()]

            # Brand filter: literal, case-insensitive substring match.
            # regex=False keeps characters such as "(" or "+" in the user's
            # text from being interpreted as a regular expression.
            brand_term = brand_filter.strip()
            if brand_term:
                current_df = current_df[
                    current_df['brand'].str.contains(brand_term, case=False, na=False, regex=False)
                ]

            # Rating filter: NaN ratings compare False with .ge() and are dropped
            if 'rating_value' in current_df.columns:
                current_df = current_df[current_df['rating_value'].ge(min_rating)]

            # Note filter: keep rows where any requested note occurs in the
            # top, middle or base notes. Empty tokens (from "rose," or
            # "rose,,musk") are discarded, because "" is a substring of every
            # field and would make the filter match every row.
            notes = [n.strip().lower() for n in note_filter.split(",")]
            notes = [n for n in notes if n]
            if notes:
                def note_check(row):
                    note_fields = [
                        str(row[col]).lower() if pd.notna(row[col]) else ""
                        for col in ('top', 'middle', 'base')
                    ]
                    return any(note in field for note in notes for field in note_fields)

                current_df = current_df[current_df.apply(note_check, axis=1)]

            valid_indices = current_df.index.tolist()

            # Check if any fragrances remain
            if not valid_indices:
                st.warning("No fragrances match all your filters. Try relaxing some criteria.")
                st.stop()

            # Grab the vectors for fragrances still present after the filters.
            # NOTE(review): this assumes the DataFrame index labels are the
            # row positions used when the FAISS index was built -- confirm.
            filtered_vectors = np.vstack([index.reconstruct(int(idx)) for idx in valid_indices])
            temp_index = faiss.IndexFlatIP(filtered_vectors.shape[1])
            temp_index.add(filtered_vectors)

            # Encode the query and normalize it for cosine similarity.
            # FAISS expects a C-contiguous float32 matrix, so enforce that here.
            # NOTE(review): inner product equals cosine similarity only if the
            # stored vectors were L2-normalized when indexed -- confirm.
            query_vector = np.ascontiguousarray(encoder.encode([query]), dtype=np.float32)
            faiss.normalize_L2(query_vector)

            # Search: returns similarity scores and positions within temp_index
            sim_score, I = temp_index.search(query_vector, min(k, len(valid_indices)))

            # Map positions back to metadata labels, paired with their scores
            results = [(valid_indices[i], sim_score[0][j]) for j, i in enumerate(I[0])]

            def display_value(record, key, default='Unknown'):
                # Fall back to the default for a missing column and for
                # missing (None / NaN) cells, which would otherwise show "nan".
                value = record.get(key, default)
                if value is None or (isinstance(value, float) and np.isnan(value)):
                    return default
                return value

            # Display results
            st.subheader(f"Recommended Fragrances ({len(results)} results)")
            cols = st.columns(3)

            # The per-result score gets its own name so it no longer shadows
            # the sim_score array returned by the search.
            for idx, (result_idx, _similarity) in enumerate(results):
                rec = metadata.loc[result_idx]

                # Extract data with fallbacks
                name = display_value(rec, 'perfume')
                brand = display_value(rec, 'brand')
                perfumer_text = display_value(rec, 'perfumer')
                top_notes = display_value(rec, 'top')
                middle_notes = display_value(rec, 'middle')
                base_notes = display_value(rec, 'base')
                accords_text = display_value(rec, 'accord')
                rating = display_value(rec, 'rating_value', '?')

                # Create natural language fragrance description
                description = (
                    f"The fragrance is called {name}. It is by {brand}. "
                    f"The perfumer is {perfumer_text}. The top notes are {top_notes}, "
                    f"the heart notes are {middle_notes}, and the base notes are {base_notes}. "
                    f"The main accords are {accords_text}."
                )

                explanation = get_ollama_explanation(query, description)

                # Build the card and place it in one of the three columns, round-robin
                card = create_fragrance_card(
                    name,
                    rating,
                    brand,
                    perfumer_text,
                    top_notes,
                    middle_notes,
                    base_notes,
                    accords_text,
                    explanation,
                )
                cols[idx % 3].markdown(card, unsafe_allow_html=True)