leewatson committed (verified)
Commit ea3c1f8 · 1 Parent(s): ddbf59e

Upload 4 files
Files changed (5)
  1. .gitattributes +1 -0
  2. NanumGothic.ttf +3 -0
  3. app.py +16 -0
  4. emotion_predictor.py +132 -0
  5. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ NanumGothic.ttf filter=lfs diff=lfs merge=lfs -text
NanumGothic.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48a28e97b34fc8e5b157657633670cd1b7de126cfc414da65ce9c3d5bc8be733
+ size 4691820
app.py ADDED
@@ -0,0 +1,16 @@
+ import gradio as gr
+ from emotion_predictor import predict_and_plot
+
+ def analyze_dialogue(text):
+     return predict_and_plot(text)
+
+ iface = gr.Interface(
+     fn=analyze_dialogue,
+     inputs=gr.Textbox(lines=15, label="Dialogue input (format: Speaker: Utterance)"),
+     outputs="html",
+     title="KOTE Emotion Prediction and Visualization",
+     description="Enter a dialogue in the format above to see per-speaker negative-emotion predictions and a visualization of the results."
+ )
+
+ iface.launch()
+ #11
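For reference, a minimal usage sketch of the interface function above; the sample dialogue, speaker names, and variable names are made-up examples, and the call assumes emotion_predictor.py and its model weights are available locally:

sample_dialogue = (
    "Alice: Why are you late again?\n"
    "Bob: I'm sorry, the meeting ran long."
)
html = analyze_dialogue(sample_dialogue)  # HTML string with per-speaker negative-emotion plots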
emotion_predictor.py ADDED
@@ -0,0 +1,132 @@
+ import re
+ import math
+ import matplotlib.pyplot as plt
+ import matplotlib.font_manager as fm
+ import torch
+ import torch.nn as nn
+ from transformers import ElectraModel, AutoTokenizer
+ import numpy as np
+ from sklearn.linear_model import LinearRegression
+ from collections import defaultdict
+ import base64
+ from io import BytesIO
+
+ # Font configuration (Korean plot labels need NanumGothic)
+ font_path = './NanumGothic.ttf'
+ fm.fontManager.addfont(font_path)
+ plt.rcParams['font.family'] = fm.FontProperties(fname=font_path).get_name()
+ plt.rcParams['axes.unicode_minus'] = False
+
+ # Label definitions
+ LABELS = [ ... ]  # insert the complete LABELS list here, with no omissions
+ NEGATIVE_EMOTIONS = [ ... ]  # insert the complete negative-emotion list here, with no omissions
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Model definition
+ class KOTEtagger(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.electra = ElectraModel.from_pretrained("beomi/KcELECTRA-base", revision='v2021').to(device)
+         self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base", revision='v2021')
+         self.classifier = nn.Linear(self.electra.config.hidden_size, 44).to(device)
+
+     def forward(self, text):
+         encoding = self.tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=512,
+             return_token_type_ids=False,
+             padding="max_length",
+             return_attention_mask=True,
+             return_tensors='pt',
+         ).to(device)
+         output = self.electra(encoding["input_ids"], attention_mask=encoding["attention_mask"])
+         output = output.last_hidden_state[:, 0, :]  # [CLS] representation
+         output = self.classifier(output)
+         return torch.sigmoid(output)
+
+ # Load model weights
+ trained_model = KOTEtagger()
+ trained_model.load_state_dict(torch.load("kote_pytorch_lightning.bin", map_location=device), strict=False)
+ trained_model.eval()
+
+ # Helper functions
+ def parse_dialogue(text):
+     lines = text.strip().split('\n')
+     return [
+         (match.group(1).strip(), match.group(2).strip())
+         for line in lines
+         if (match := re.match(r"([^:]+):(.+)", line.strip()))
+     ]
+
+ def adjusted_score(raw_score, k=5):
+     return 100 / (1 + math.exp(-k * (raw_score - 0.5)))
+
+ def apply_ema(scores, alpha=0.4):
+     if not scores:
+         return []
+     smoothed = [scores[0]]
+     for s in scores[1:]:
+         smoothed.append(alpha * s + (1 - alpha) * smoothed[-1])
+     return smoothed
+
+ # Main processing function
+ def predict_and_plot(raw_text):
+     dialogue = parse_dialogue(raw_text)
+     emotion_scores = defaultdict(lambda: defaultdict(list))
+
+     # Prediction
+     for speaker, sentence in dialogue:
+         preds = trained_model(sentence)[0]
+         for label, score in zip(LABELS, preds):
+             if label in NEGATIVE_EMOTIONS:
+                 adjusted = adjusted_score(score.item())
+                 emotion_scores[speaker][label].append(adjusted)
+
+     html_output = ""
+     for speaker in emotion_scores:
+         html_output += f"<h3>Emotion prediction results for {speaker}:</h3>"
+         fig, ax = plt.subplots(figsize=(10, 4))
+         max_y = 0
+         plotted = False
+         predicted_scores = {}
+
+         for label in NEGATIVE_EMOTIONS:
+             raw_scores = emotion_scores[speaker].get(label, [])
+             scores = apply_ema(raw_scores)
+             if len(scores) >= 2 and max(scores) >= 40:
+                 X = np.arange(len(scores)).reshape(-1, 1)
+                 y = np.array(scores)
+                 model = LinearRegression().fit(X, y)
+                 predicted = model.predict([[len(scores)]])[0]
+                 predicted_scores[label] = predicted
+                 line, = ax.plot(scores, label=label)
+                 color = line.get_color()
+                 ax.plot([len(scores)-1, len(scores)], [scores[-1], predicted], linestyle='--', color=color)
+                 plotted = True
+                 max_y = max(max_y, predicted, *scores)
+                 html_output += f"<p>- {label}: predicted score {predicted:.2f}"
+                 if predicted >= 80:
+                     html_output += " <b style='color:red'>⚠️ Warning!</b>"
+                 html_output += "</p>"
+
+         if plotted:
+             ax.set_title(f"Negative-emotion trend and forecast for {speaker}")
+             ax.set_xlabel("Utterance index")
+             ax.set_ylabel("Emotion score")
+             ax.set_ylim(0, max(100, max_y + 10))
+             ax.legend()
+             ax.grid(True)
+             buf = BytesIO()
+             plt.tight_layout()
+             plt.savefig(buf, format='png')
+             plt.close(fig)
+             img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
+             html_output += f"<img src='data:image/png;base64,{img_base64}'/><hr/>"
+         else:
+             html_output += "<p>⚠️ No emotions to visualize.</p><hr/>"
+
+     return html_output
+ #22
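As a sanity check on the scoring chain in this file, here is a minimal sketch (plain Python, no model required) of how raw sigmoid outputs pass through adjusted_score, apply_ema, and the one-step linear forecast used in predict_and_plot; the raw_scores values below are hypothetical and chosen only for illustration:

import math
import numpy as np
from sklearn.linear_model import LinearRegression

def adjusted_score(raw_score, k=5):
    # Logistic rescaling: 0.5 maps to 50; higher raw scores saturate toward 100
    return 100 / (1 + math.exp(-k * (raw_score - 0.5)))

def apply_ema(scores, alpha=0.4):
    smoothed = [scores[0]]
    for s in scores[1:]:
        smoothed.append(alpha * s + (1 - alpha) * smoothed[-1])
    return smoothed

raw_scores = [0.55, 0.70, 0.85]                     # hypothetical sigmoid outputs for one label
adjusted = [adjusted_score(s) for s in raw_scores]  # approx. [56.2, 73.1, 85.2]
smoothed = apply_ema(adjusted)                      # approx. [56.2, 63.0, 71.9]

# One-step-ahead forecast of the next utterance's score, as in predict_and_plot
X = np.arange(len(smoothed)).reshape(-1, 1)
model = LinearRegression().fit(X, np.array(smoothed))
print(model.predict([[len(smoothed)]])[0])          # estimated score for the next utterance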
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ transformers
+ matplotlib
+ scikit-learn
+ gradio