# chatbot.py
import streamlit as st
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
import torch
from typing import List, Dict

class ChatbotManager:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()
        self.initialize_chat()

    def load_model(self):
        """Load the Blenderbot model and tokenizer locally"""
        try:
            with st.spinner("Loading AI model (this may take a minute)..."):
                model_name = "facebook/blenderbot-400M-distill"
                self.tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
                self.model = BlenderbotForConditionalGeneration.from_pretrained(model_name).to(self.device)
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"⚠️ Failed to load model: {str(e)}")
            self.model = None

    def initialize_chat(self):
        """Initialize chat session state"""
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_chat(self):
        """Reset chat history"""
        st.session_state.chat_history = []
        st.success("Chat history cleared!")

    def add_message(self, role: str, content: str):
        """Add a message to chat history"""
        st.session_state.chat_history.append({"role": role, "content": content})

    def get_chat_history(self) -> List[Dict]:
        """Retrieve chat history"""
        return st.session_state.chat_history

    def generate_response(self, prompt: str) -> str:
        """Generate an AI response using Blenderbot"""
        if not self.model:
            return "Model not loaded. Please try again later."
        try:
            # Frame the question with business context before passing it to the model
            business_prompt = f"""You are a professional business advisor. Provide helpful, concise advice on:
- Business strategy
- Marketing
- Product development
- Startup growth
User Question: {prompt}
Answer:"""
            # Truncate to the tokenizer's model_max_length (128 tokens for
            # Blenderbot) so long prompts don't overflow the position embeddings
            inputs = self.tokenizer([business_prompt], return_tensors="pt", truncation=True).to(self.device)
            reply_ids = self.model.generate(**inputs, max_length=200)
            response = self.tokenizer.decode(reply_ids[0], skip_special_tokens=True)
            return response
        except Exception as e:
            return f"⚠️ Error generating response: {str(e)}"

    def render_chat_interface(self):
        """Render the complete chat UI"""
        st.header("💬 AI Business Mentor (Blenderbot)")

        # Sidebar controls
        with st.sidebar:
            st.subheader("Settings")
            if st.button("Clear Chat"):
                self.clear_chat()
                st.rerun()
            st.markdown("---")
            st.caption("Model: facebook/blenderbot-400M-distill")
            st.caption(f"Device: {self.device.upper()}")

        # Display chat history
        for message in self.get_chat_history():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask about business..."):
            self.add_message("user", prompt)
            # Display user message immediately
            with st.chat_message("user"):
                st.markdown(prompt)
            # Generate and display AI response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = self.generate_response(prompt)
                    st.markdown(response)
            # Add response to history
            self.add_message("assistant", response)
            # Auto-refresh so the new turn renders from history
            st.rerun()
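
# Entry-point sketch (an assumption, not part of the original file): persist
# the manager in st.session_state so the 400M model is loaded once rather
# than on every Streamlit rerun. Run with: streamlit run chatbot.py
if __name__ == "__main__":
    if "chatbot_manager" not in st.session_state:
        st.session_state.chatbot_manager = ChatbotManager()
    st.session_state.chatbot_manager.render_chat_interface()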