Update app.py
app.py
CHANGED
@@ -717,11 +717,14 @@ tokenizer = AutoTokenizer.from_pretrained(main_model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# Load fallback multilingual model
-multi_model_name = "
+# Load fallback multilingual model (direct sentiment labels)
+multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
 multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
 multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
 
+# Labels for multilingual model
+multi_labels = ['Negative', 'Neutral', 'Positive']
+
 # Reddit API setup
 reddit = praw.Reddit(
     client_id=os.getenv("REDDIT_CLIENT_ID"),
@@ -741,14 +744,7 @@ def multilingual_classifier(text):
     with torch.no_grad():
         output = multi_model(**encoded_input)
     scores = softmax(output.logits.cpu().numpy()[0])
-
-
-    if stars in [1, 2]:
-        return "Prediction: Negative"
-    elif stars == 3:
-        return "Prediction: Neutral"
-    else:
-        return "Prediction: Positive"
+    return f"Prediction: {multi_labels[np.argmax(scores)]}"
 
 def clean_ocr_text(text):
     text = text.strip()
@@ -867,6 +863,13 @@ demo.launch()
 
 
 
+
+
+
+
+
+
+
 
 
 
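For reference, here is a minimal sketch of how the updated fallback path fits together after this commit, assuming the scipy softmax and transformers imports already used in app.py; the tokenizer call, truncation settings, and the example input are illustrative assumptions, since the diff only shows the lines above.

import numpy as np
import torch
from scipy.special import softmax  # assumed to be the softmax used in app.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fallback multilingual model that predicts sentiment labels directly,
# so the old star-to-sentiment mapping is no longer needed.
multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
multi_labels = ['Negative', 'Neutral', 'Positive']

def multilingual_classifier(text):
    # Tokenize and move the tensors to the model's device (assumed call shape).
    encoded_input = multi_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        output = multi_model(**encoded_input)
    # Softmax over the three logits, then report the highest-probability label.
    scores = softmax(output.logits.cpu().numpy()[0])
    return f"Prediction: {multi_labels[np.argmax(scores)]}"

# Example (hypothetical input):
# print(multilingual_classifier("¡Me encanta este producto!"))  # e.g. "Prediction: Positive"

Returning the argmax of the three class scores replaces the removed 1-5 star mapping, since this model's label set already matches the Negative/Neutral/Positive strings the app displays.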