# chatbot.py
import streamlit as st
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
import torch
from typing import List, Dict
class ChatbotManager:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()
        self.initialize_chat()
    def load_model(self):
        """Load the Blenderbot model locally."""
        try:
            with st.spinner("Loading AI model (this may take a minute)..."):
                model_name = "facebook/blenderbot-400M-distill"
                self.tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
                self.model = BlenderbotForConditionalGeneration.from_pretrained(model_name).to(self.device)
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"⚠️ Failed to load model: {str(e)}")
            self.model = None
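
    # Note: ChatbotManager is re-created on every Streamlit rerun, so
    # load_model() reloads the 400M-parameter model on each interaction.
    # A minimal sketch of a common fix (assumed helper, not part of the
    # original code): cache the model and tokenizer once per process with
    # st.cache_resource, then assign them inside load_model().
    #
    # @st.cache_resource
    # def _load_blenderbot(model_name="facebook/blenderbot-400M-distill"):
    #     tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
    #     model = BlenderbotForConditionalGeneration.from_pretrained(model_name)
    #     return tokenizer, model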
    def initialize_chat(self):
        """Initialize chat session state."""
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_chat(self):
        """Reset chat history."""
        st.session_state.chat_history = []
        st.success("Chat history cleared!")

    def add_message(self, role: str, content: str):
        """Add a message to chat history."""
        st.session_state.chat_history.append({"role": role, "content": content})

    def get_chat_history(self) -> List[Dict]:
        """Retrieve chat history."""
        return st.session_state.chat_history
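
    # Optional helper (a sketch, not in the original code): long sessions grow
    # session_state without bound, so capping the stored history keeps memory
    # and rendering costs flat. The 50-message default is an assumption.
    def trim_chat_history(self, max_messages: int = 50):
        """Keep only the most recent messages in session state."""
        st.session_state.chat_history = st.session_state.chat_history[-max_messages:]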
    def generate_response(self, prompt: str) -> str:
        """Generate an AI response using Blenderbot."""
        if not self.model:
            return "Model not loaded. Please try again later."
        try:
            # Frame the question with business context
            business_prompt = f"""You are a professional business advisor. Provide helpful, concise advice on:
- Business strategy
- Marketing
- Product development
- Startup growth
User Question: {prompt}
Answer:"""
            # Blenderbot-400M-distill accepts at most 128 input tokens, so
            # truncate to avoid an index error on long prompts.
            inputs = self.tokenizer([business_prompt], return_tensors="pt", truncation=True).to(self.device)
            reply_ids = self.model.generate(**inputs, max_length=200)
            response = self.tokenizer.decode(reply_ids[0], skip_special_tokens=True)
            return response
        except Exception as e:
            return f"⚠️ Error generating response: {str(e)}"
    def render_chat_interface(self):
        """Render the complete chat UI."""
        st.header("💬 AI Business Mentor (Blenderbot)")

        # Sidebar controls
        with st.sidebar:
            st.subheader("Settings")
            if st.button("Clear Chat"):
                self.clear_chat()
                st.rerun()
            st.markdown("---")
            st.caption("Model: facebook/blenderbot-400M-distill")
            st.caption(f"Device: {self.device.upper()}")
        # Display chat history
        for message in self.get_chat_history():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask about business..."):
            self.add_message("user", prompt)
            # Display the user message immediately
            with st.chat_message("user"):
                st.markdown(prompt)
            # Generate and display the AI response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = self.generate_response(prompt)
                    st.markdown(response)
            # Persist the response. The messages above are already rendered,
            # so no manual st.rerun() is needed here; an extra rerun would
            # re-execute the whole script and reload the model.
            self.add_message("assistant", response)
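

# Minimal usage sketch (entry point assumed, not shown in the original);
# launch with `streamlit run chatbot.py`.
if __name__ == "__main__":
    manager = ChatbotManager()
    manager.render_chat_interface()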