shrish191 committed
Commit 79a81bc · verified · 1 parent: 31d12dd

Update app.py

Files changed (1):
  1. app.py +172 -0
app.py CHANGED
@@ -535,6 +535,7 @@ demo = gr.TabbedInterface(
 
 demo.launch()
 '''
+'''
 import gradio as gr
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import torch
@@ -690,6 +691,177 @@ demo = gr.TabbedInterface(
 )
 
 demo.launch()
+'''
+import gradio as gr
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch
+from scipy.special import softmax
+import praw
+import os
+import pytesseract
+from PIL import Image
+import cv2
+import numpy as np
+import re
+import matplotlib.pyplot as plt
+import pandas as pd
+from langdetect import detect
+
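+# NOTE: langdetect is non-deterministic by default; seeding it (an optional
+# tweak, not part of this commit) makes language detection reproducible:
+# from langdetect import DetectorFactory; DetectorFactory.seed = 0
+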
+# Install the Tesseract OCR binary (runs on each startup in Hugging Face Spaces)
+os.system("apt-get update && apt-get install -y tesseract-ocr")
+
+# Load the main lightweight English model
+main_model_name = "distilbert-base-uncased-finetuned-sst-2-english"
+model = AutoModelForSequenceClassification.from_pretrained(main_model_name)
+tokenizer = AutoTokenizer.from_pretrained(main_model_name)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# Load the fallback multilingual model
+multi_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
+multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
+
+# Reddit API setup; credentials are read from environment variables
+reddit = praw.Reddit(
+    client_id=os.getenv("REDDIT_CLIENT_ID"),
+    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
+    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui-finalyear2025-shrish191")
+)
+
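+# With only client_id/client_secret/user_agent supplied, PRAW operates in
+# read-only mode, which is all the fetch and subreddit helpers below need.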
+def fetch_reddit_text(reddit_url):
+    try:
+        submission = reddit.submission(url=reddit_url)
+        return f"{submission.title}\n\n{submission.selftext}"
+    except Exception as e:
+        return f"Error fetching Reddit post: {str(e)}"
+
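+# The multilingual model predicts a 1-5 star rating; stars 1-2 map to
+# Negative, 3 to Neutral, and 4-5 to Positive.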
+def multilingual_classifier(text):
+    encoded_input = multi_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
+    with torch.no_grad():
+        output = multi_model(**encoded_input)
+    scores = softmax(output.logits.cpu().numpy()[0])
+    stars = np.argmax(scores) + 1
+
+    if stars in [1, 2]:
+        return "Prediction: Negative"
+    elif stars == 3:
+        return "Prediction: Neutral"
+    else:
+        return "Prediction: Positive"
+
+def clean_ocr_text(text):
+    text = text.strip()
+    text = re.sub(r'\s+', ' ', text)           # collapse runs of whitespace
+    text = re.sub(r'[^\x00-\x7F]+', '', text)  # strip non-ASCII OCR artifacts
+    return text
+
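+# classify_sentiment resolves its input in priority order (Reddit URL, then
+# uploaded image via OCR, then raw text) and routes English text to the
+# SST-2 model and all other languages to the multilingual fallback.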
+def classify_sentiment(text_input, reddit_url, image):
+    if reddit_url.strip():
+        text = fetch_reddit_text(reddit_url)
+    elif image is not None:
+        try:
+            img_array = np.array(image)
+            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+            thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2)
+            text = pytesseract.image_to_string(thresh)
+            text = clean_ocr_text(text)
+        except Exception as e:
+            return f"[!] OCR failed: {str(e)}"
+    elif text_input.strip():
+        text = text_input
+    else:
+        return "[!] Please enter some text, upload an image, or provide a Reddit URL."
+
+    if text.lower().startswith("error") or "Unable to extract" in text:
+        return f"[!] {text}"
+
+    # Truncate to the first 400 words to stay within the models' input limits
+    text = ' '.join(text.split()[:400])
+
+    try:
+        lang = detect(text)
+        if lang == 'en':
+            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            scores = softmax(outputs.logits.cpu().numpy()[0])
+            labels = ['Negative', 'Positive']  # SST-2 order: index 0 = negative, 1 = positive
+            return f"Prediction: {labels[scores.argmax()]}"
+        else:
+            return multilingual_classifier(text)
+    except Exception as e:
+        return f"[!] Prediction error: {str(e)}"
+
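+# analyze_subreddit classifies the 20 hottest posts of a subreddit and
+# returns a bar chart of sentiment counts plus a per-post table.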
+def analyze_subreddit(subreddit_name):
+    try:
+        subreddit = reddit.subreddit(subreddit_name)
+        posts = list(subreddit.hot(limit=20))
+
+        sentiments = []
+        titles = []
+
+        for post in posts:
+            text = f"{post.title}\n{post.selftext}"
+            text = ' '.join(text.split()[:400])
+            try:
+                lang = detect(text)
+                if lang == 'en':
+                    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+                    with torch.no_grad():
+                        outputs = model(**inputs)
+                    scores = softmax(outputs.logits.cpu().numpy()[0])
+                    labels = ['Negative', 'Positive']
+                    sentiment = labels[scores.argmax()]
+                else:
+                    sentiment = multilingual_classifier(text).split(": ")[-1]
+            except Exception:
+                sentiment = "Error"
+            sentiments.append(sentiment)
+            titles.append(post.title)
+
+        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
+        sentiment_counts = df["Sentiment"].value_counts()
+
+        fig, ax = plt.subplots()
+        sentiment_counts.plot(kind="bar", ax=ax)
+        ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
+        ax.set_xlabel("Sentiment")
+        ax.set_ylabel("Number of Posts")
+
+        return fig, df
+    except Exception as e:
+        # Return an empty plot and surface the error in the dataframe instead
+        # of passing a string to gr.Plot
+        return None, pd.DataFrame({"Error": [str(e)]})
+
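+# UI definitions: one Interface per task, combined below in a TabbedInterface.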
+main_interface = gr.Interface(
+    fn=classify_sentiment,
+    inputs=[
+        gr.Textbox(label="Text Input", placeholder="Paste content here...", lines=4),
+        gr.Textbox(label="Reddit Post URL", placeholder="Optional", lines=1),
+        gr.Image(label="Upload Image (optional)", type="pil")
+    ],
+    outputs="text",
+    title="Sentiment Analyzer",
+    description="🔍 Analyze sentiment of any text, Reddit post URL, or image content."
+)
+
+subreddit_interface = gr.Interface(
+    fn=analyze_subreddit,
+    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
+    outputs=[
+        gr.Plot(label="Sentiment Distribution"),
+        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
+    ],
+    title="Subreddit Sentiment Analysis",
+    description="📊 Analyze the top 20 posts of any subreddit."
+)
+
+demo = gr.TabbedInterface(
+    interface_list=[main_interface, subreddit_interface],
+    tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
+)
+
+demo.launch()
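
Quick smoke test of the English model path above (a standalone sketch; it assumes only the transformers/scipy/torch pieces from this commit and needs no Reddit credentials):

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from scipy.special import softmax
    import torch

    name = "distilbert-base-uncased-finetuned-sst-2-english"
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)

    inputs = tok("I love this!", return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = mdl(**inputs).logits
    scores = softmax(logits.numpy()[0])
    print(["Negative", "Positive"][scores.argmax()])  # expected: Positive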