Pranav0111 commited on
Commit
af686c2
·
verified ·
1 Parent(s): a8c313a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import numpy as np
6
  from datetime import datetime, timedelta
7
  from typing import Dict, List, Any
 
8
 
9
  # --- Data Processing Class ---
10
  class DataProcessor:
@@ -377,6 +378,84 @@ def render_chat():
377
  st.markdown(response)
378
  st.session_state.messages.append({"role": "assistant", "content": response})
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  def main():
381
  st.set_page_config(
382
  page_title="Prospira",
@@ -391,7 +470,7 @@ def main():
391
 
392
  page = st.radio(
393
  "Navigation",
394
- ["Dashboard", "Analytics", "Brainstorm", "Chat"]
395
  )
396
 
397
  if page == "Dashboard":
@@ -400,8 +479,7 @@ def main():
400
  render_analytics()
401
  elif page == "Brainstorm":
402
  render_brainstorm_page()
 
 
403
  elif page == "Chat":
404
- render_chat()
405
-
406
- if __name__ == "__main__":
407
- main()
 
5
  import numpy as np
6
  from datetime import datetime, timedelta
7
  from typing import Dict, List, Any
8
+ from render_ai_assistant import render_ai_assistant
9
 
10
  # --- Data Processing Class ---
11
  class DataProcessor:
 
378
  st.markdown(response)
379
  st.session_state.messages.append({"role": "assistant", "content": response})
380
 
381
+ def load_huggingface_model(model_name="google/flan-t5-base"):
382
+ """
383
+ Load a pre-trained model from Hugging Face
384
+
385
+ Args:
386
+ model_name (str): Hugging Face model identifier
387
+
388
+ Returns:
389
+ tuple: Loaded model and tokenizer
390
+ """
391
+ try:
392
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
393
+ model = AutoModelForCausalLM.from_pretrained(model_name)
394
+ return model, tokenizer
395
+ except Exception as e:
396
+ st.error(f"Error loading model: {e}")
397
+ return None, None
398
+
399
+ def generate_text(model, tokenizer, prompt, max_length=200):
400
+ """
401
+ Generate text based on input prompt
402
+
403
+ Args:
404
+ model: Loaded Hugging Face model
405
+ tokenizer: Model's tokenizer
406
+ prompt (str): Input text prompt
407
+ max_length (int): Maximum generated text length
408
+
409
+ Returns:
410
+ str: Generated text
411
+ """
412
+ inputs = tokenizer(prompt, return_tensors="pt")
413
+
414
+ with torch.no_grad():
415
+ outputs = model.generate(
416
+ **inputs,
417
+ max_length=max_length,
418
+ num_return_sequences=1,
419
+ do_sample=True,
420
+ temperature=0.7
421
+ )
422
+
423
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
424
+
425
+ def render_ai_assistant():
426
+ st.title("🤖 Business AI Assistant")
427
+
428
+ # Model Selection
429
+ model_options = {
430
+ "Google Flan-T5": "google/flan-t5-base",
431
+ "DialoGPT": "microsoft/DialoGPT-medium",
432
+ "GPT-2 Small": "gpt2"
433
+ }
434
+
435
+ selected_model = st.selectbox(
436
+ "Choose AI Model",
437
+ list(model_options.keys())
438
+ )
439
+
440
+ # Load Selected Model
441
+ model_name = model_options[selected_model]
442
+ model, tokenizer = load_huggingface_model(model_name)
443
+
444
+ if model and tokenizer:
445
+ # Prompt Input
446
+ user_prompt = st.text_area(
447
+ "Enter your business query",
448
+ placeholder="Ask about business strategy, product analysis, etc."
449
+ )
450
+
451
+ if st.button("Generate Response"):
452
+ with st.spinner("Generating response..."):
453
+ response = generate_text(model, tokenizer, user_prompt)
454
+ st.success("AI Response:")
455
+ st.write(response)
456
+ else:
457
+ st.error("Failed to load model")
458
+
459
  def main():
460
  st.set_page_config(
461
  page_title="Prospira",
 
470
 
471
  page = st.radio(
472
  "Navigation",
473
+ ["Dashboard", "Analytics", "Brainstorm", "AI Assistant", "Chat"] # Added AI Assistant
474
  )
475
 
476
  if page == "Dashboard":
 
479
  render_analytics()
480
  elif page == "Brainstorm":
481
  render_brainstorm_page()
482
+ elif page == "AI Assistant": # New condition
483
+ render_ai_assistant()
484
  elif page == "Chat":
485
+ render_chat()