shrish191 committed
Commit 80c934d · verified · 1 Parent(s): 105f551

Update app.py

Files changed (1):
  1. app.py  +13 −10
app.py CHANGED
@@ -717,11 +717,14 @@ tokenizer = AutoTokenizer.from_pretrained(main_model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# Load fallback multilingual model
-multi_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+# Load fallback multilingual model (direct sentiment labels)
+multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
 multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
 multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
 
+# Labels for multilingual model
+multi_labels = ['Negative', 'Neutral', 'Positive']
+
 # Reddit API setup
 reddit = praw.Reddit(
     client_id=os.getenv("REDDIT_CLIENT_ID"),
@@ -741,14 +744,7 @@ def multilingual_classifier(text):
     with torch.no_grad():
         output = multi_model(**encoded_input)
     scores = softmax(output.logits.cpu().numpy()[0])
-    stars = np.argmax(scores) + 1
-
-    if stars in [1, 2]:
-        return "Prediction: Negative"
-    elif stars == 3:
-        return "Prediction: Neutral"
-    else:
-        return "Prediction: Positive"
+    return f"Prediction: {multi_labels[np.argmax(scores)]}"
 
 def clean_ocr_text(text):
     text = text.strip()
@@ -867,6 +863,13 @@ demo.launch()
 
 
 
+
+
+
+
+
+
+
 
 
 
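
For context, a minimal sketch of how the fallback classifier reads after this commit. The model name, label list, and return line come from the diff; the imports and the tokenizer call (truncation/padding settings) are outside the shown hunks and are assumptions here, not the app's exact code.

import numpy as np
import torch
from scipy.special import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fallback multilingual model with direct Negative/Neutral/Positive outputs (from this commit)
multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)

# Label order matches the model's three output logits
multi_labels = ['Negative', 'Neutral', 'Positive']

def multilingual_classifier(text):
    # Encoding step is outside the diff hunks; this tokenizer call is an assumption
    encoded_input = multi_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        output = multi_model(**encoded_input)
    scores = softmax(output.logits.cpu().numpy()[0])
    # New in this commit: map the top logit straight to a label instead of
    # bucketing 1-5 star scores into Negative/Neutral/Positive
    return f"Prediction: {multi_labels[np.argmax(scores)]}"

The practical effect is that the star-to-sentiment mapping is gone: the previous nlptown model predicted 1-5 stars that had to be bucketed, while the new model emits the three sentiment classes directly, so the argmax index can be looked up in multi_labels.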