shrish191 committed on
Commit 5d14718 · verified · 1 Parent(s): 36b4452

Update app.py

Files changed (1)
  1. app.py +184 -0
app.py CHANGED
@@ -692,6 +692,7 @@ demo = gr.TabbedInterface(
 
 demo.launch()
 '''
+'''
 import gradio as gr
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import torch
@@ -858,6 +859,189 @@ demo = gr.TabbedInterface(
 )
 
 demo.launch()
+'''
+import gradio as gr
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch
+from scipy.special import softmax
+import praw
+import os
+import pytesseract
+from PIL import Image
+import cv2
+import numpy as np
+import re
+import matplotlib.pyplot as plt
+import pandas as pd
+from langdetect import detect
+
+# Install tesseract OCR (only runs once in Hugging Face Spaces)
+os.system("apt-get update && apt-get install -y tesseract-ocr")
+
+# Load main lightweight model (English)
+main_model_name = "distilbert-base-uncased-finetuned-sst-2-english"
+model = AutoModelForSequenceClassification.from_pretrained(main_model_name)
+tokenizer = AutoTokenizer.from_pretrained(main_model_name)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# Load multilingual fallback model (global languages)
+multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
+multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
+multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
+multi_labels = ['Negative', 'Neutral', 'Positive']
+
+# Load Hinglish/Hindi fallback model
+hinglish_model_name = "iisc-dsi/hinglish-sentiment-model"
+hinglish_tokenizer = AutoTokenizer.from_pretrained(hinglish_model_name)
+hinglish_model = AutoModelForSequenceClassification.from_pretrained(hinglish_model_name).to(device)
+hinglish_labels = ['Negative', 'Neutral', 'Positive']
+
+# Reddit API setup
+reddit = praw.Reddit(
+    client_id=os.getenv("REDDIT_CLIENT_ID"),
+    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
+    user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-ui-finalyear2025-shrish191")
+)
+
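+# Fetch a Reddit post's title and body; failures come back as strings so the UI can display them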
+def fetch_reddit_text(reddit_url):
+    try:
+        submission = reddit.submission(url=reddit_url)
+        return f"{submission.title}\n\n{submission.selftext}"
+    except Exception as e:
+        return f"Error fetching Reddit post: {str(e)}"
+
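+# Both fallback classifiers return "Prediction: <label>" strings; callers split on ": " to recover the bare label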
+def multilingual_classifier(text):
+    encoded_input = multi_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
+    with torch.no_grad():
+        output = multi_model(**encoded_input)
+    scores = softmax(output.logits.cpu().numpy()[0])
+    return f"Prediction: {multi_labels[np.argmax(scores)]}"
+
+def hinglish_classifier(text):
+    encoded_input = hinglish_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
+    with torch.no_grad():
+        output = hinglish_model(**encoded_input)
+    scores = softmax(output.logits.cpu().numpy()[0])
+    return f"Prediction: {hinglish_labels[np.argmax(scores)]}"
+
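+# Normalize OCR output: collapse whitespace and strip non-ASCII characters that Tesseract tends to misread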
+def clean_ocr_text(text):
+    text = text.strip()
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'[^\x00-\x7F]+', '', text)
+    return text
+
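+# Input priority: Reddit URL first, then uploaded image (via OCR), then raw text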
+def classify_sentiment(text_input, reddit_url, image):
+    if reddit_url.strip():
+        text = fetch_reddit_text(reddit_url)
+    elif image is not None:
+        try:
+            img_array = np.array(image)
+            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
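+            # Adaptive mean thresholding generally helps Tesseract on screenshots with uneven lighting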
+            thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2)
+            text = pytesseract.image_to_string(thresh)
+            text = clean_ocr_text(text)
+        except Exception as e:
+            return f"[!] OCR failed: {str(e)}"
+    elif text_input.strip():
+        text = text_input
+    else:
+        return "[!] Please enter some text, upload an image, or provide a Reddit URL."
+
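+    # Propagate fetch/OCR error messages instead of classifying them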
+    if text.lower().startswith("error") or "Unable to extract" in text:
+        return f"[!] {text}"
+
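+    # Keep only the first 400 words; the tokenizers truncate to the model limit anyway, this just bounds the work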
+    text = ' '.join(text.split()[:400])
+
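+    # Route by detected language: English -> binary SST-2, Hindi -> Hinglish model, anything else -> multilingual XLM-R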
+    try:
+        lang = detect(text)
+        if lang == 'en':
+            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            scores = softmax(outputs.logits.cpu().numpy()[0])
+            labels = ['Negative', 'Positive']
+            label = labels[scores.argmax()]
+        elif lang == 'hi':
+            label = hinglish_classifier(text).split(": ")[-1]
+        else:
+            label = multilingual_classifier(text).split(": ")[-1]
+
+        return f"🌐 Detected Language: {lang.upper()} | Prediction: {label}"
+    except Exception as e:
+        return f"[!] Prediction error: {str(e)}"
+
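+# Classify the 20 hottest posts in a subreddit and chart the label counts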
+def analyze_subreddit(subreddit_name):
+    try:
+        subreddit = reddit.subreddit(subreddit_name)
+        posts = list(subreddit.hot(limit=20))
+
+        sentiments = []
+        titles = []
+
+        for post in posts:
+            text = f"{post.title}\n{post.selftext}"
+            text = ' '.join(text.split()[:400])
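+            # langdetect can raise on empty or very short text; mark such posts as "Error"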
+            try:
+                lang = detect(text)
+                if lang == 'en':
+                    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+                    with torch.no_grad():
+                        outputs = model(**inputs)
+                    scores = softmax(outputs.logits.cpu().numpy()[0])
+                    labels = ['Negative', 'Positive']
+                    sentiment = labels[scores.argmax()]
+                elif lang == 'hi':
+                    sentiment = hinglish_classifier(text).split(": ")[-1]
+                else:
+                    sentiment = multilingual_classifier(text).split(": ")[-1]
+            except:
+                sentiment = "Error"
+            sentiments.append(sentiment)
+            titles.append(post.title)
+
+        df = pd.DataFrame({"Title": titles, "Sentiment": sentiments})
+        sentiment_counts = df["Sentiment"].value_counts()
+
+        fig, ax = plt.subplots()
+        sentiment_counts.plot(kind="bar", ax=ax)
+        ax.set_title(f"Sentiment Distribution in r/{subreddit_name}")
+        ax.set_xlabel("Sentiment")
+        ax.set_ylabel("Number of Posts")
+
+        return fig, df
+    except Exception as e:
+        return f"[!] Error: {str(e)}", pd.DataFrame()
+
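+# Gradio UI: one tab for single-item analysis, one for subreddit-level analysis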
+main_interface = gr.Interface(
+    fn=classify_sentiment,
+    inputs=[
+        gr.Textbox(label="Text Input", placeholder="Paste content here...", lines=4),
+        gr.Textbox(label="Reddit Post URL", placeholder="Optional", lines=1),
+        gr.Image(label="Upload Image (optional)", type="pil")
+    ],
+    outputs="text",
+    title="Sentiment Analyzer",
+    description="🔍 Analyze sentiment of any text, Reddit post URL, or image content."
+)
+
+subreddit_interface = gr.Interface(
+    fn=analyze_subreddit,
+    inputs=gr.Textbox(label="Subreddit Name", placeholder="e.g., AskReddit"),
+    outputs=[
+        gr.Plot(label="Sentiment Distribution"),
+        gr.Dataframe(label="Post Titles and Sentiments", wrap=True)
+    ],
+    title="Subreddit Sentiment Analysis",
+    description="📊 Analyze top 20 posts of any subreddit."
+)
+
+demo = gr.TabbedInterface(
+    interface_list=[main_interface, subreddit_interface],
+    tab_names=["General Sentiment Analysis", "Subreddit Analysis"]
+)
+
+demo.launch()