Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
import numpy as np
|
6 |
from datetime import datetime, timedelta
|
7 |
from typing import Dict, List, Any
|
|
|
8 |
|
9 |
# --- Data Processing Class ---
|
10 |
class DataProcessor:
|
@@ -377,6 +378,84 @@ def render_chat():
|
|
377 |
st.markdown(response)
|
378 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
def main():
|
381 |
st.set_page_config(
|
382 |
page_title="Prospira",
|
@@ -391,7 +470,7 @@ def main():
|
|
391 |
|
392 |
page = st.radio(
|
393 |
"Navigation",
|
394 |
-
["Dashboard", "Analytics", "Brainstorm", "Chat"]
|
395 |
)
|
396 |
|
397 |
if page == "Dashboard":
|
@@ -400,8 +479,7 @@ def main():
|
|
400 |
render_analytics()
|
401 |
elif page == "Brainstorm":
|
402 |
render_brainstorm_page()
|
|
|
|
|
403 |
elif page == "Chat":
|
404 |
-
render_chat()
|
405 |
-
|
406 |
-
if __name__ == "__main__":
|
407 |
-
main()
|
|
|
5 |
import numpy as np
|
6 |
from datetime import datetime, timedelta
|
7 |
from typing import Dict, List, Any
|
8 |
+
from render_ai_assistant import render_ai_assistant
|
9 |
|
10 |
# --- Data Processing Class ---
|
11 |
class DataProcessor:
|
|
|
378 |
st.markdown(response)
|
379 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
380 |
|
381 |
+
def load_huggingface_model(model_name="google/flan-t5-base"):
    """
    Load a pre-trained model and tokenizer from Hugging Face.

    Args:
        model_name (str): Hugging Face model identifier.

    Returns:
        tuple: (model, tokenizer) on success, or (None, None) on failure
        (the error is shown in the Streamlit UI instead of raising).
    """
    # NOTE(review): this runs on every Streamlit rerun; consider caching
    # with @st.cache_resource so each model is only loaded once.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if "t5" in model_name.lower():
            # Flan-T5 is an encoder-decoder (seq2seq) model; loading it
            # with AutoModelForCausalLM raises, so use the seq2seq class.
            from transformers import AutoModelForSeq2SeqLM
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name)
        return model, tokenizer
    except Exception as e:
        # Broad catch is deliberate: any load failure becomes a UI error
        # plus a (None, None) sentinel that callers check for.
        st.error(f"Error loading model: {e}")
        return None, None
|
398 |
+
|
399 |
+
def generate_text(model, tokenizer, prompt, max_length=200):
    """
    Generate a text completion for *prompt* with the given model.

    Args:
        model: Loaded Hugging Face model.
        tokenizer: Tokenizer matching the model.
        prompt (str): Input text prompt.
        max_length (int): Maximum length of the generated sequence.

    Returns:
        str: The decoded generated text.
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # Inference only: disable gradient tracking while sampling.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
        )

    # Single returned sequence; strip special tokens when decoding.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
424 |
+
|
425 |
+
def render_ai_assistant():
    """
    Render the "AI Assistant" page: choose a model, enter a prompt,
    and display the generated response.

    Side effects: draws Streamlit widgets and loads the selected model
    (which may be slow the first time a model is chosen).
    """
    st.title("🤖 Business AI Assistant")

    # Display name -> Hugging Face model identifier.
    model_options = {
        "Google Flan-T5": "google/flan-t5-base",
        "DialoGPT": "microsoft/DialoGPT-medium",
        "GPT-2 Small": "gpt2",
    }

    selected_model = st.selectbox(
        "Choose AI Model",
        list(model_options.keys())
    )

    # load_huggingface_model returns (None, None) on failure and has
    # already shown the error in the UI.
    model_name = model_options[selected_model]
    model, tokenizer = load_huggingface_model(model_name)

    if model and tokenizer:
        user_prompt = st.text_area(
            "Enter your business query",
            placeholder="Ask about business strategy, product analysis, etc."
        )

        if st.button("Generate Response"):
            # Guard: don't generate from an empty/whitespace-only prompt.
            if not user_prompt.strip():
                st.warning("Please enter a query first.")
            else:
                with st.spinner("Generating response..."):
                    response = generate_text(model, tokenizer, user_prompt)
                    st.success("AI Response:")
                    st.write(response)
    else:
        st.error("Failed to load model")
|
458 |
+
|
459 |
def main():
|
460 |
st.set_page_config(
|
461 |
page_title="Prospira",
|
|
|
470 |
|
471 |
page = st.radio(
|
472 |
"Navigation",
|
473 |
+
["Dashboard", "Analytics", "Brainstorm", "AI Assistant", "Chat"] # Added AI Assistant
|
474 |
)
|
475 |
|
476 |
if page == "Dashboard":
|
|
|
479 |
render_analytics()
|
480 |
elif page == "Brainstorm":
|
481 |
render_brainstorm_page()
|
482 |
+
elif page == "AI Assistant": # New condition
|
483 |
+
render_ai_assistant()
|
484 |
elif page == "Chat":
|
485 |
+
render_chat()
|
|
|
|
|
|