shrish191 committed on
Commit
ee2a5a1
·
verified ·
1 Parent(s): 396109f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -1
app.py CHANGED
@@ -359,7 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
362
-
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
@@ -534,7 +534,159 @@ demo = gr.TabbedInterface(
534
  )
535
 
536
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
 
540
 
 
359
 
360
  demo.launch()
361
  '''
362
+ '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
537
+ '''
538
+ import gradio as gr
539
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
540
+ import torch
541
+ from scipy.special import softmax
542
+ import praw
543
+ import os
544
+ import pytesseract
545
+ from PIL import Image
546
+ import cv2
547
+ import numpy as np
548
+ import re
549
+ import matplotlib.pyplot as plt
550
+ import pandas as pd
551
+
552
# Inference device shared by both models: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Primary classifier. Its two outputs are mapped to Negative/Positive by the
# prediction functions in this file.
main_model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(main_model_name)
model = AutoModelForSequenceClassification.from_pretrained(main_model_name)
model.to(device)

# Secondary classifier used by fallback_classifier(), which maps its three
# outputs to Negative/Neutral/Positive.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
fallback_model.to(device)
563
+
564
# Reddit API client. The client id and secret are read from the environment
# (no default); the user agent falls back to a fixed project string when
# REDDIT_USER_AGENT is not set.
reddit = praw.Reddit(
    user_agent=os.environ.get(
        "REDDIT_USER_AGENT",
        "sentiment-classifier-ui-finalyear2025-shrish191",
    ),
    client_id=os.environ.get("REDDIT_CLIENT_ID"),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET"),
)
570
+
571
def fetch_reddit_text(reddit_url):
    """Return the title and self-text of the Reddit post at ``reddit_url``.

    The title and body are joined by a blank line. Any exception raised
    while looking up or reading the submission is caught and reported as a
    string starting with ``"Error fetching Reddit post: "`` instead of being
    propagated, so callers must inspect the returned text for that prefix.
    """
    try:
        post = reddit.submission(url=reddit_url)
        # Attribute access stays inside the try block so that failures while
        # reading the post are reported the same way as lookup failures.
        title, body = post.title, post.selftext
    except Exception as exc:
        return f"Error fetching Reddit post: {exc}"
    return f"{title}\n\n{body}"
577
+
578
def fallback_classifier(text):
    """Classify ``text`` with the fallback model.

    Returns ``"Prediction: <label>"`` where ``<label>`` is Negative, Neutral
    or Positive, picked as the highest softmax score over the logits of the
    first (only) sequence in the batch.
    """
    sentiment_names = ('Negative', 'Neutral', 'Positive')

    batch = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    batch = batch.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = fallback_model(**batch).logits

    probabilities = softmax(logits.cpu().numpy()[0])
    best = probabilities.argmax()
    return f"Prediction: {sentiment_names[best]}"
585
+
586
def clean_ocr_text(text):
    """Normalize raw OCR output into single-spaced ASCII text.

    Steps, in order:
      1. Collapse every run of whitespace (including Unicode whitespace such
         as non-breaking spaces) into one space, so words separated only by
         such characters are not glued together in step 2.
      2. Drop all non-ASCII characters.
      3. Collapse the space runs that step 2 can leave behind and trim the
         ends, so the result never has doubled, leading or trailing spaces.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string; empty if nothing printable remains.
    """
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\x00-\x7F]+', '', text)
    # Removing a non-ASCII token that sat between two spaces (or at either
    # end) leaves stray spaces; tidy them up last.
    return re.sub(r' {2,}', ' ', text).strip()
591
+
592
def classify_sentiment(text_input, reddit_url, image):
    """Predict the sentiment of a Reddit post, an uploaded image, or raw text.

    Exactly one source is used, in this priority order: a non-blank
    ``reddit_url``, then ``image`` (text extracted with OCR), then a
    non-blank ``text_input``.

    Args:
        text_input: Free-form text; ``None`` or blank when unused.
        reddit_url: URL of a Reddit submission; ``None`` or blank when unused.
        image: Image to run OCR on (a PIL image when called from the UI in
            this file), or ``None``.

    Returns:
        ``"Prediction: Negative"`` or ``"Prediction: Positive"`` on success;
        otherwise a human-readable message prefixed with ``"[!] "``. The
        function reports failures through that return value; only an
        exception from ``fetch_reddit_text`` itself would propagate.
    """
    reddit_error_prefix = "Error fetching Reddit post:"
    reddit_url = (reddit_url or "").strip()
    text_input = text_input or ""

    if reddit_url:
        text = fetch_reddit_text(reddit_url)
        # Only the Reddit fetch signals failure through an error string, so
        # the sentinel is checked here and not on user or OCR text, which may
        # legitimately start with the word "Error".
        if text.startswith(reddit_error_prefix):
            return f"[!] {text}"
    elif image is not None:
        try:
            # COLOR_RGB2GRAY needs a 3-channel array; normalise RGBA,
            # grayscale and palette uploads first when given a PIL image.
            rgb_image = image.convert("RGB") if hasattr(image, "convert") else image
            img_array = np.array(rgb_image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            thresh = cv2.adaptiveThreshold(
                gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2
            )
            text = clean_ocr_text(pytesseract.image_to_string(thresh))
        except Exception as e:
            return f"[!] OCR failed: {str(e)}"
        if not text:
            # Nothing to classify; do not feed an empty string to the model.
            return "[!] No readable text found in the image."
    elif text_input.strip():
        text = text_input
    else:
        return "[!] Please enter some text, upload an image, or provide a Reddit URL."

    # Cheap pre-limit to the first 400 words; the tokenizer call below also
    # truncates (truncation=True).
    text = ' '.join(text.split()[:400])

    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        scores = softmax(outputs.logits.cpu().numpy()[0])
        labels = ['Negative', 'Positive']
        return f"Prediction: {labels[scores.argmax()]}"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"
624
+
625
def analyze_subreddit(subreddit_name):
    """Classify a subreddit's hot posts and chart how many fall in each label.

    Up to 20 posts from the subreddit's "hot" listing are classified with the
    main model, using each post's title and self-text (first 400 words).

    Args:
        subreddit_name: Subreddit name such as ``"AskReddit"``. Surrounding
            whitespace and a leading ``r/`` or ``/r/`` are ignored.

    Returns:
        A ``(figure, dataframe)`` pair. On success ``figure`` is a Matplotlib
        bar chart of label counts and ``dataframe`` has ``Title`` and
        ``Sentiment`` columns. On failure ``figure`` is ``None`` and
        ``dataframe`` has a single ``Error`` column holding the message; a
        string is never returned in the figure slot because a plot output
        cannot render one.
    """
    name = (subreddit_name or "").strip()
    lowered = name.lower()
    for prefix in ("/r/", "r/"):
        if lowered.startswith(prefix):
            name = name[len(prefix):].strip()
            break
    if not name:
        return None, pd.DataFrame({"Error": ["[!] Please enter a subreddit name."]})

    labels = ('Negative', 'Positive')
    try:
        titles = []
        sentiments = []
        for post in reddit.subreddit(name).hot(limit=20):
            text = ' '.join(f"{post.title}\n{post.selftext}".split()[:400])
            try:
                inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
                with torch.no_grad():
                    outputs = model(**inputs)
                scores = softmax(outputs.logits.cpu().numpy()[0])
                sentiment = labels[scores.argmax()]
            except Exception:
                # Best effort: one post that cannot be classified is tagged
                # with a placeholder label instead of aborting the analysis.
                sentiment = "Fallback"
            sentiments.append(sentiment)
            titles.append(post.title)

        if not titles:
            # An empty count series cannot be drawn as a bar chart.
            return None, pd.DataFrame({"Error": [f"[!] No posts found in r/{name}."]})

        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
        sentiment_counts = df["Sentiment"].value_counts()

        # Build the figure without pyplot so it is not held in pyplot's
        # global figure registry; otherwise every request would leave an
        # unclosed figure behind in this long-running app.
        from matplotlib.figure import Figure

        fig = Figure()
        ax = fig.subplots()
        sentiment_counts.plot(kind="bar", ax=ax)
        ax.set_title(f"Sentiment Distribution in r/{name}")
        ax.set_xlabel("Sentiment")
        ax.set_ylabel("Number of Posts")

        return fig, df
    except Exception as e:
        return None, pd.DataFrame({"Error": [f"[!] Error: {str(e)}"]})
660
+
661
# Tab 1: classify a single piece of content. The three inputs are passed
# positionally to classify_sentiment(text_input, reddit_url, image).
main_interface = gr.Interface(
    fn=classify_sentiment,
    inputs=[
        gr.Textbox(label="Text Input", placeholder="Paste content here...", lines=4),
        gr.Textbox(label="Reddit Post URL", placeholder="Optional", lines=1),
        gr.Image(label="Upload Image (optional)", type="pil"),
    ],
    outputs="text",
    title="Sentiment Analyzer",
    # The leading emoji was mis-encoded ("πŸ”"); restored to the intended
    # magnifying-glass character.
    description="🔍 Analyze sentiment of any text, Reddit post URL, or image content.",
)

# Tab 2: classify a batch of posts from one subreddit. analyze_subreddit
# returns (figure, dataframe), matching the two outputs in order.
subreddit_interface = gr.Interface(
    fn=analyze_subreddit,
    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
    outputs=[
        gr.Plot(label="Sentiment Distribution"),
        gr.Dataframe(label="Post Titles and Sentiments", wrap=True),
    ],
    title="Subreddit Sentiment Analysis",
    # The leading emoji was mis-encoded ("πŸ“Š"); restored to the intended
    # bar-chart character.
    description="📊 Analyze top 20 posts of any subreddit.",
)

# Combine both interfaces into one tabbed app and start serving it.
demo = gr.TabbedInterface(
    interface_list=[main_interface, subreddit_interface],
    tab_names=["General Sentiment Analysis", "Subreddit Analysis"],
)

demo.launch()
690
 
691
 
692