SHIROI-07 commited on
Commit
010ed04
·
verified ·
1 Parent(s): e9b203d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -57
model.py CHANGED
@@ -1,57 +1,20 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
- import torch
3
- import numpy as np
4
-
5
- # Load model and tokenizer
6
- model_name = "SamLowe/roberta-base-go_emotions"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
-
10
- #Define distress weights for emotions
11
- DISTRESS_WEIGHTS = {
12
- "grief": 1.0,
13
- "sadness": 0.9,
14
- "fear": 0.9,
15
- "disgust": 0.9,
16
- "anger": 0.9,
17
- "remorse": 0.8,
18
- "disappointment": 0.8,
19
- "nervousness": 0.8,
20
- "disapproval": 0.7,
21
- "embarrassment": 0.7,
22
- "annoyance": 0.6,
23
- "confusion": 0.6,
24
- "surprise": 0.4,
25
- "desire": 0.4,
26
- "love": 0.3,
27
- "excitement": 0.3,
28
- "pride": 0.3,
29
- "optimism": 0.3,
30
- "admiration": 0.2,
31
- "gratitude": 0.2,
32
- "relief": 0.2,
33
- "joy": 0.2,
34
- "amusement": 0.2,
35
- "neutral": 0.1,
36
- }
37
- # Labels from GoEmotions
38
- EMOTION_LABELS = tokenizer.convert_ids_to_tokens(list(range(model.config.num_labels)))
39
-
40
- # Ensure the model is in evaluation mode
41
- model.eval()
42
-
43
- def predict_emotions(text: str):
44
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
45
- outputs = model(**inputs)
46
- probs = torch.sigmoid(outputs.logits)[0].detach().numpy()
47
-
48
- threshold = 0.3 # You can tune this
49
- predicted = {label: float(prob) for label, prob in zip(model.config.id2label.values(), probs) if prob > threshold}
50
- return predicted
51
-
52
- def calculate_distress(emotions: dict):
53
- distress_score = sum(
54
- emotions.get(emotion, 0) * DISTRESS_WEIGHTS.get(emotion, 0)
55
- for emotion in emotions
56
- )
57
- return round(distress_score, 3)
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+
4
+ class EmotionModel:
5
+ def __init__(self):
6
+ self.model_name = "SamLowe/roberta-base-go_emotions"
7
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
8
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
9
+ self.labels = self.model.config.id2label
10
+
11
+ def predict(self, text):
12
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
13
+ with torch.no_grad():
14
+ logits = self.model(**inputs).logits
15
+ probs = torch.sigmoid(logits)[0]
16
+ return {
17
+ self.labels[i]: float(probs[i])
18
+ for i in range(len(probs)) if probs[i] > 0.3
19
+ }
20
+