File size: 4,035 Bytes
34a309c
4421377
241e2a0
 
 
4421377
480d3e2
4421377
241e2a0
 
 
 
34a309c
 
241e2a0
 
0ee3833
241e2a0
 
 
 
 
0ee3833
241e2a0
0ee3833
 
34a309c
241e2a0
 
4421377
0ee3833
 
241e2a0
0ee3833
 
 
 
241e2a0
 
67ac5b4
0ee3833
241e2a0
0ee3833
 
241e2a0
 
 
 
 
67ac5b4
241e2a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67ac5b4
0ee3833
67ac5b4
452e69e
241e2a0
 
 
 
 
 
 
 
 
 
 
 
 
452e69e
 
 
 
 
 
 
241e2a0
452e69e
 
 
 
 
 
241e2a0
452e69e
241e2a0
 
452e69e
 
241e2a0
452e69e
 
241e2a0
452e69e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# chatbot.py
import streamlit as st
from transformers import pipeline, BlenderbotTokenizer, BlenderbotForConditionalGeneration
import torch
from typing import List, Dict

class ChatbotManager:
    """Streamlit chat front-end backed by a locally loaded Blenderbot model.

    Chat history is stored in ``st.session_state.chat_history`` as a list of
    ``{"role": ..., "content": ...}`` dicts so it persists across reruns.
    """

    # Hugging Face model id: used for loading and shown in the sidebar.
    MODEL_NAME = "facebook/blenderbot-400M-distill"

    # ``max_length`` passed to ``model.generate`` for each reply.
    MAX_RESPONSE_LENGTH = 200

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()
        self.initialize_chat()

    def load_model(self):
        """Load the Blenderbot tokenizer and model onto ``self.device``.

        On failure an error is shown in the UI and both ``self.model`` and
        ``self.tokenizer`` are left as ``None``, so ``generate_response``
        reports the model as unavailable instead of using a partial load.
        """
        # NOTE(review): Streamlit re-executes the script on every interaction.
        # If callers build a new ChatbotManager on each run, the model is
        # reloaded every time; consider caching it (e.g. st.cache_resource).
        try:
            with st.spinner("Loading AI model (this may take a minute)..."):
                self.tokenizer = BlenderbotTokenizer.from_pretrained(self.MODEL_NAME)
                self.model = BlenderbotForConditionalGeneration.from_pretrained(
                    self.MODEL_NAME
                ).to(self.device)
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"⚠️ Failed to load model: {str(e)}")
            # Reset both halves so a half-finished load is never used.
            self.model = None
            self.tokenizer = None

    def initialize_chat(self):
        """Create an empty chat history in session state if none exists yet."""
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_chat(self):
        """Reset chat history to an empty list."""
        st.session_state.chat_history = []
        # NOTE(review): render_chat_interface calls st.rerun() right after
        # this, so the success message below may never be visible.
        st.success("Chat history cleared!")

    def add_message(self, role: str, content: str):
        """Append one ``{"role", "content"}`` message to the chat history."""
        st.session_state.chat_history.append({"role": role, "content": content})

    def get_chat_history(self) -> List[Dict]:
        """Return the session's chat history list (the live object, not a copy)."""
        return st.session_state.chat_history

    @staticmethod
    def _build_prompt(question: str) -> str:
        """Wrap the user's question in the business-advisor preamble.

        Lines are joined explicitly, without source-code indentation, so no
        whitespace tokens are wasted in the model's small input window.
        """
        # NOTE(review): Blenderbot is a conversational model, not an
        # instruction-tuned one; this preamble may only loosely steer it.
        return (
            "You are a professional business advisor. Provide helpful, concise advice on:\n"
            "- Business strategy\n"
            "- Marketing\n"
            "- Product development\n"
            "- Startup growth\n"
            "\n"
            f"User Question: {question}\n"
            "\n"
            "Answer:"
        )

    def _max_input_tokens(self):
        """Return the model's ``max_position_embeddings``, or None if unavailable."""
        config = getattr(self.model, "config", None)
        limit = getattr(config, "max_position_embeddings", None)
        return limit if isinstance(limit, int) and limit > 0 else None

    def generate_response(self, prompt: str) -> str:
        """Generate a reply to ``prompt`` with the loaded Blenderbot model.

        Args:
            prompt: The user's question.

        Returns:
            The decoded reply, or a human-readable message if the model is not
            loaded or generation raised (errors are returned, not raised, so the
            chat UI can display them).
        """
        if self.model is None or self.tokenizer is None:
            return "Model not loaded. Please try again later."

        try:
            business_prompt = self._build_prompt(prompt)
            encoded = self.tokenizer([business_prompt], return_tensors="pt")

            # Inputs longer than the model's position embeddings make the
            # embedding lookup fail. Keep the most recent tokens: the tail
            # holds the user's question and the trailing "Answer:" cue.
            limit = self._max_input_tokens()
            inputs = {
                name: (tensor[:, -limit:] if limit else tensor).to(self.device)
                for name, tensor in encoded.items()
            }

            reply_ids = self.model.generate(**inputs, max_length=self.MAX_RESPONSE_LENGTH)
            return self.tokenizer.decode(reply_ids[0], skip_special_tokens=True)
        except Exception as e:
            return f"⚠️ Error generating response: {str(e)}"

    def render_chat_interface(self):
        """Render the sidebar, the stored conversation, and the input box.

        A submitted question is shown immediately, answered, stored in the
        history together with the reply, and then the script is rerun.
        """
        st.header("💬 AI Business Mentor (Blenderbot)")

        # Sidebar controls
        with st.sidebar:
            st.subheader("Settings")
            if st.button("Clear Chat"):
                self.clear_chat()
                st.rerun()

            st.markdown("---")
            st.caption(f"Model: {self.MODEL_NAME}")
            st.caption(f"Device: {self.device.upper()}")

        # Replay the stored conversation (Streamlit redraws from scratch each run).
        for message in self.get_chat_history():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask about business..."):
            self.add_message("user", prompt)

            # Display user message immediately
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate and display AI response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = self.generate_response(prompt)
                    st.markdown(response)

            # Add response to history
            self.add_message("assistant", response)

            # Rerun so the next pass renders everything from history.
            st.rerun()