"""Streamlit app that recommends fragrances with a FAISS similarity search.

The free-text query is embedded with the local ``paraphrase-mpnet-base-v2``
checkpoint, compared against the vectors of the fragrances that survive the
user's filters, and every hit is rendered as a card together with a short
LLM-written explanation.
"""
import html
import os
import pickle

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from langchain_huggingface import HuggingFaceEndpoint
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

ENCODER_PATH = 'src/paraphrase-mpnet-base-v2'
NOTE_COLUMNS = ('top', 'middle', 'base')


@st.cache_resource
def load_encoder():
    """Load the tokenizer and the transformer model a single time.

    Streamlit re-executes the whole script on every widget interaction, so a
    plain module-level ``from_pretrained`` call would reload the model on each
    rerun. ``st.cache_resource`` keeps one shared instance instead.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(ENCODER_PATH),
        AutoModel.from_pretrained(ENCODER_PATH),
    )


def mean_pooling(model_output, attention_mask):
    """Average the token embeddings, ignoring padded positions.

    Args:
        model_output: Model output whose first element holds the token
            embeddings.
        attention_mask: Mask tensor from the tokenizer (1 for real tokens,
            0 for padding).

    Returns:
        One pooled embedding per input sentence.
    """
    token_embeddings = model_output[0]  # First element is token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp so a fully masked row cannot cause a division by zero.
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts


def encode(sentences):
    """Embed ``sentences`` and return the pooled vectors as a NumPy array.

    Relies on the module-level ``tokenizer`` and ``model``, which are
    assigned further down before the first call.
    """
    encoded_input = tokenizer(
        sentences, padding=True, truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        model_output = model(**encoded_input)
    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    return pooled.cpu().numpy()


# Fragrance card function
def create_fragrance_card(name, rating, brand, perfumer_text, top_notes, middle_notes, base_notes, accords_text, explanation):
    """Build the markup for one fragrance card.

    The card is rendered with ``unsafe_allow_html=True``, so every value is
    converted to text and HTML-escaped first. Otherwise metadata or LLM
    output containing markup would be injected into the page verbatim.

    Returns:
        The card markup as a string.
    """
    (name, rating, brand, perfumer_text, top_notes, middle_notes,
     base_notes, accords_text, explanation) = (
        html.escape(str(value))
        for value in (
            name, rating, brand, perfumer_text, top_notes, middle_notes,
            base_notes, accords_text, explanation,
        )
    )

    # Create fragrance card HTML
    card_html = f"""

{name} ⭐{rating}

🏷️ Brand: {brand}

👃 Perfumer(s): {perfumer_text}

🌿 Top Notes: {top_notes}

💖 Heart Notes: {middle_notes}

🌲 Base Notes: {base_notes}

🎼 Main Accords: {accords_text}

💡 AI Explanation: {explanation}

"""
    return card_html


# Load FAISS database and metadata with cache
@st.cache_resource
def load_resources():
    """Load the FAISS index and the fragrance metadata table once.

    The rating column is coerced to numeric here, inside the cached loader,
    so the shared cached object is not mutated again on every rerun.

    Returns:
        An ``(index, metadata)`` tuple.
    """
    index = faiss.read_index('src/fragrance_faiss.index')
    # NOTE(security): pickle.load executes arbitrary code from the file;
    # only ever ship a metadata pickle produced by a trusted pipeline.
    with open('src/fragrance_metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    # Convert rating_values to numeric; unparsable entries become NaN
    if 'rating_value' in metadata.columns:
        metadata['rating_value'] = pd.to_numeric(
            metadata['rating_value'], errors='coerce')
    return index, metadata


# Gets a brief explanation from the LLM for why this fragrance matches the user's query
def get_ollama_explanation(query, description):
    """Ask the module-level ``llm`` why ``description`` fits ``query``.

    Returns:
        The model's answer with surrounding whitespace removed.
    """
    prompt = (
        f' A user is searching for a fragrance with this description: "{query}" '
        f'One recommendation is: {description} '
        "Explain in 1-2 sentences, in plain English, "
        "why this fragrance matches the user's query. "
    )
    response = llm.invoke(prompt)
    return response.strip()


def _parse_notes(note_filter):
    """Split a comma-separated note string into non-empty lowercase tokens.

    Empty tokens (e.g. from a trailing comma) are dropped: an empty string is
    a substring of every field and would make the note filter match all rows.
    """
    tokens = (token.strip().lower() for token in note_filter.split(","))
    return [token for token in tokens if token]


def _row_has_any_note(row, notes):
    """Return True when any of ``notes`` occurs in the row's note columns."""
    fields = [
        str(row[column]).lower() if pd.notna(row[column]) else ""
        for column in NOTE_COLUMNS
    ]
    return any(note in field for note in notes for field in fields)


def _filter_fragrances(metadata, gender_filter, brand_filter, min_rating, note_filter):
    """Apply the gender, brand, rating and note filters in sequence.

    Returns:
        A filtered copy of ``metadata``; the input frame is left untouched.
    """
    current_df = metadata.copy()

    # Gender filter
    if gender_filter != "All":
        same_gender = current_df['gender'].str.lower() == gender_filter.lower()
        current_df = current_df[same_gender]

    # Brand filter: literal substring match, so characters such as "(" or
    # "*" in the user's text are not interpreted as a regular expression.
    brand = brand_filter.strip()
    if brand:
        brand_match = current_df['brand'].str.contains(
            brand, case=False, na=False, regex=False)
        current_df = current_df[brand_match]

    # Rating filter (NaN ratings compare False and are dropped)
    if 'rating_value' in current_df.columns:
        current_df = current_df[current_df['rating_value'].ge(min_rating)]

    # Note filter
    notes = _parse_notes(note_filter)
    if notes and not current_df.empty:
        note_match = current_df.apply(
            _row_has_any_note, axis=1, notes=notes)
        current_df = current_df[note_match]

    return current_df


# Load LLM
llm = HuggingFaceEndpoint(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    task="text-generation",
    huggingfacehub_api_token=os.environ["LLM_TOKEN"],
)

# Initialize app
st.set_page_config(page_title="Fragrance Recommendation System", layout="wide")

# Add title to top of app interface
st.title("Fragrance Recommendation System")

# Sidebar filters
st.sidebar.header("Filters")
query = st.text_input("Describe your ideal fragrance:")
col1, col2 = st.columns(2)
with col1:
    k = st.slider("Number of recommendations:", 1, 10, 5)
with col2:
    min_rating = st.slider("Minimum rating:", 1.0, 5.0, 3.5)
gender_filter = st.sidebar.selectbox("Gender:", ["All", "Male", "Female", "Unisex"])
brand_filter = st.sidebar.text_input("Brand (leave empty for all):", "").title()
note_filter = st.sidebar.text_input("Notes (comma-separated):", "").lower()

# Load resources (cached across reruns). Done after set_page_config so the
# cache spinner is never the first Streamlit element emitted.
index, metadata = load_resources()
tokenizer, model = load_encoder()

# Press button and start recommendations
if st.button('Get Recommendations'):
    with st.spinner('Finding your fragrance recs...'):
        if not query.strip():
            st.warning("No query entered.")
        else:
            # Apply filters sequentially
            current_df = _filter_fragrances(
                metadata, gender_filter, brand_filter, min_rating, note_filter)
            valid_indices = current_df.index.tolist()

            # Check if any fragrances remain
            if not valid_indices:
                st.warning("No fragrances match all your filters. Try relaxing some criteria.")
                st.stop()

            # Grab the vectors for fragrances still present after the filters.
            # NOTE(review): assumes the metadata index labels are the FAISS
            # row ids - confirm against the index-building script.
            filtered_vectors = np.vstack(
                [index.reconstruct(int(row_id)) for row_id in valid_indices])
            temp_index = faiss.IndexFlatIP(filtered_vectors.shape[1])
            temp_index.add(filtered_vectors)

            # Encode the query and normalize it for cosine similarity
            query_vector = encode([query])
            faiss.normalize_L2(query_vector)

            # Search: returns similarity scores and positions (into
            # valid_indices) of the most similar vectors
            top_k = min(k, len(valid_indices))
            scores, positions = temp_index.search(query_vector, top_k)

            # Pair each recommended fragrance's metadata label with its score
            results = [
                (valid_indices[position], score)
                for position, score in zip(positions[0], scores[0])
            ]

            # Display results
            st.subheader(f"Recommended Fragrances ({len(results)} results)")
            cols = st.columns(3)

            for card_number, (result_idx, _score) in enumerate(results):
                rec = metadata.loc[result_idx]

                # Extract data with fallbacks
                name = rec.get('perfume', 'Unknown')
                brand = rec.get('brand', 'Unknown')
                perfumer_text = rec.get('perfumer', 'Unknown')
                top_notes = rec.get('top', 'Unknown')
                middle_notes = rec.get('middle', 'Unknown')
                base_notes = rec.get('base', 'Unknown')
                accords_text = rec.get('accord', 'Unknown')
                rating = rec.get('rating_value', '?')

                # Create natural language fragrance description
                description = (
                    f"The fragrance is called {name}. It is by {brand}. "
                    f"The perfumer is {perfumer_text}. The top notes are {top_notes}, "
                    f"the heart notes are {middle_notes}, and the base notes are {base_notes}. "
                    f"The main accords are {accords_text}."
                )

                explanation = get_ollama_explanation(query, description)

                # Add rating to card
                card = create_fragrance_card(
                    name,
                    rating,
                    brand,
                    perfumer_text,
                    top_notes,
                    middle_notes,
                    base_notes,
                    accords_text,
                    explanation
                )

                # Lay the cards out left-to-right across the three columns
                cols[card_number % 3].markdown(card, unsafe_allow_html=True)