import re
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import torch
import torch.nn as nn
from transformers import ElectraModel, AutoTokenizer
import numpy as np
from sklearn.linear_model import LinearRegression
from collections import defaultdict
import base64
from io import BytesIO

# Font setup: register NanumGothic so Korean text renders correctly in matplotlib plots
font_path = './NanumGothic.ttf'
fm.fontManager.addfont(font_path)
plt.rcParams['font.family'] = fm.FontProperties(fname=font_path).get_name()
plt.rcParams['axes.unicode_minus'] = False

# ๋ผ๋ฒจ ์ •์˜
LABELS = ['๋ถˆํ‰/๋ถˆ๋งŒ', 'ํ™˜์˜/ํ˜ธ์˜', '๊ฐ๋™/๊ฐํƒ„', '์ง€๊ธ‹์ง€๊ธ‹', '๊ณ ๋งˆ์›€', '์Šฌํ””', 'ํ™”๋‚จ/๋ถ„๋…ธ', '์กด๊ฒฝ', '๊ธฐ๋Œ€๊ฐ', '์šฐ์ญ๋Œ/๋ฌด์‹œํ•จ', 
          '์•ˆํƒ€๊นŒ์›€/์‹ค๋ง', '๋น„์žฅํ•จ', '์˜์‹ฌ/๋ถˆ์‹ ', '๋ฟŒ๋“ฏํ•จ', 'ํŽธ์•ˆ/์พŒ์ ', '์‹ ๊ธฐํ•จ/๊ด€์‹ฌ', '์•„๊ปด์ฃผ๋Š”', '๋ถ€๋„๋Ÿฌ์›€', '๊ณตํฌ/๋ฌด์„œ์›€', 
          '์ ˆ๋ง', 'ํ•œ์‹ฌํ•จ', '์—ญ๊ฒจ์›€/์ง•๊ทธ๋Ÿฌ์›€', '์งœ์ฆ', '์–ด์ด์—†์Œ', '์—†์Œ', 'ํŒจ๋ฐฐ/์ž๊ธฐํ˜์˜ค', '๊ท€์ฐฎ์Œ', 'ํž˜๋“ฆ/์ง€์นจ', '์ฆ๊ฑฐ์›€/์‹ ๋‚จ', 
          '๊นจ๋‹ฌ์Œ', '์ฃ„์ฑ…๊ฐ', '์ฆ์˜ค/ํ˜์˜ค', 'ํ๋ญ‡ํ•จ(๊ท€์—ฌ์›€/์˜ˆ์จ)', '๋‹นํ™ฉ/๋‚œ์ฒ˜', '๊ฒฝ์•…', '๋ถ€๋‹ด/์•ˆ_๋‚ดํ‚ด', '์„œ๋Ÿฌ์›€', '์žฌ๋ฏธ์—†์Œ', 
          '๋ถˆ์Œํ•จ/์—ฐ๋ฏผ', '๋†€๋žŒ', 'ํ–‰๋ณต', '๋ถˆ์•ˆ/๊ฑฑ์ •', '๊ธฐ์จ', '์•ˆ์‹ฌ/์‹ ๋ขฐ']
# Subset of negative emotions that are tracked and forecast below
NEGATIVE_EMOTIONS = [
    '๋ถˆํ‰/๋ถˆ๋งŒ', '์ง€๊ธ‹์ง€๊ธ‹', '์Šฌํ””', 'ํ™”๋‚จ/๋ถ„๋…ธ', '์˜์‹ฌ/๋ถˆ์‹ ', '๊ณตํฌ/๋ฌด์„œ์›€', '์ ˆ๋ง', 'ํ•œ์‹ฌํ•จ', '์—ญ๊ฒจ์›€/์ง•๊ทธ๋Ÿฌ์›€', '์งœ์ฆ', '์–ด์ด์—†์Œ',
    'ํŒจ๋ฐฐ/์ž๊ธฐํ˜์˜ค', '๊ท€์ฐฎ์Œ', 'ํž˜๋“ฆ/์ง€์นจ', '์ฃ„์ฑ…๊ฐ', '์ฆ์˜ค/ํ˜์˜ค', '๋‹นํ™ฉ/๋‚œ์ฒ˜', '๋ถ€๋‹ด/์•ˆ_๋‚ดํ‚ด', '์„œ๋Ÿฌ์›€', '์žฌ๋ฏธ์—†์Œ'
]

# ๋””๋ฐ”์ด์Šค
device = "cuda" if torch.cuda.is_available() else "cpu"

# ๋ชจ๋ธ ์ •์˜
class KOTEtagger(nn.Module):
    def __init__(self):
        super().__init__()
        self.electra = ElectraModel.from_pretrained("beomi/KcELECTRA-base", revision='v2021').to(device)
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base", revision='v2021')
        self.classifier = nn.Linear(self.electra.config.hidden_size, 44).to(device)
    
    def forward(self, text):
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)
        output = self.electra(encoding["input_ids"], attention_mask=encoding["attention_mask"])
        output = output.last_hidden_state[:, 0, :]
        output = self.classifier(output)
        return torch.sigmoid(output)

# ๋ชจ๋ธ ๋กœ๋“œ
trained_model = KOTEtagger()
trained_model.load_state_dict(torch.load("kote_pytorch_lightning.bin", map_location=device), strict=False)
trained_model.eval()

# Helper functions
def parse_dialogue(text):
    """Split raw text into (speaker, utterance) pairs from lines of the form 'Speaker: utterance'."""
    lines = text.strip().split('\n')
    return [
        (match.group(1).strip(), match.group(2).strip())
        for line in lines
        if (match := re.match(r"([^:]+):(.+)", line.strip()))
    ]

def adjusted_score(raw_score, k=5):
    """Rescale a raw sigmoid probability (0-1) to a 0-100 score via a logistic curve centered at 0.5."""
    return 100 / (1 + math.exp(-k * (raw_score - 0.5)))
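# Illustration: adjusted_score(0.5) == 50.0; adjusted_score(0.9) ≈ 88.1; adjusted_score(0.1) ≈ 11.9.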

def apply_ema(scores, alpha=0.4):
    """Smooth a score sequence with an exponential moving average (higher alpha = less smoothing)."""
    if not scores:
        return []
    smoothed = [scores[0]]
    for s in scores[1:]:
        smoothed.append(alpha * s + (1 - alpha) * smoothed[-1])
    return smoothed
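# Illustration: apply_ema([10, 50, 30]) -> [10, 26.0, 27.6] with alpha=0.4
# (each new point blends 40% of the current score with 60% of the previous smoothed value).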

# ๋ฉ”์ธ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def predict_and_plot(raw_text):
    dialogue = parse_dialogue(raw_text)
    emotion_scores = defaultdict(lambda: defaultdict(list))

    # Score every utterance; keep only negative emotions, rescaled to 0-100
    for speaker, sentence in dialogue:
        with torch.no_grad():
            preds = trained_model(sentence)[0]
        for label, score in zip(LABELS, preds):
            if label in NEGATIVE_EMOTIONS:
                adjusted = adjusted_score(score.item())
                emotion_scores[speaker][label].append(adjusted)

    html_output = ""
    for speaker in emotion_scores:
        html_output += f"<h3>{speaker} ๊ฐ์ • ์˜ˆ์ธก ๊ฒฐ๊ณผ:</h3>"
        fig, ax = plt.subplots(figsize=(10, 4))
        max_y = 0
        plotted = False
        predicted_scores = {}

        for label in NEGATIVE_EMOTIONS:
            raw_scores = emotion_scores[speaker].get(label, [])
            scores = apply_ema(raw_scores)
            # Plot only emotions with at least two points and a peak of 40 or more
            if len(scores) >= 2 and max(scores) >= 40:
                # Fit a linear trend over utterance index and extrapolate one step ahead
                X = np.arange(len(scores)).reshape(-1, 1)
                y = np.array(scores)
                model = LinearRegression().fit(X, y)
                predicted = model.predict([[len(scores)]])[0]
                predicted_scores[label] = predicted
                line, = ax.plot(scores, label=label)
                color = line.get_color()
                # Dashed segment from the last observed point to the forecast
                ax.plot([len(scores)-1, len(scores)], [scores[-1], predicted], linestyle='--', color=color)
                plotted = True
                max_y = max(max_y, predicted, *scores)
                html_output += f"<p>- {label}: ์˜ˆ์ธก ์ ์ˆ˜ {predicted:.2f}"
                if predicted >= 80:
                    html_output += f" <b style='color:red'>โš ๏ธ ๊ฒฝ๊ณ !</b>"
                html_output += "</p>"

        if plotted:
            ax.set_title(f"{speaker}์˜ ๋ถ€์ • ๊ฐ์ • ๋ณ€ํ™” ๋ฐ ์˜ˆ์ธก")
            ax.set_xlabel("๋ฐœํ™” ์ˆœ์„œ")
            ax.set_ylabel("๊ฐ์ • ์ ์ˆ˜")
            ax.set_ylim(0, max(100, max_y + 10))
            ax.legend()
            ax.grid(True)
            buf = BytesIO()
            plt.tight_layout()
            plt.savefig(buf, format='png')
            plt.close(fig)
            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            html_output += f"<img src='data:image/png;base64,{img_base64}'/><hr/>"
        else:
            html_output += "<p>โš ๏ธ ์‹œ๊ฐํ™”ํ•  ์ˆ˜ ์žˆ๋Š” ๊ฐ์ •์ด ์—†์Šต๋‹ˆ๋‹ค.</p><hr/>"

    return html_output
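
# Minimal usage sketch (assumptions: run as a script; the sample dialogue and output filename are illustrative only).
# predict_and_plot() expects newline-separated "Speaker: utterance" lines and returns an HTML report string
# with per-speaker negative-emotion trends, one-step forecasts, and base64-embedded plots.
if __name__ == "__main__":
    sample_dialogue = (
        "A: ์˜ค๋Š˜ ์ •๋ง ํž˜๋“ค์—ˆ์–ด.\n"
        "B: ๋ฌด์Šจ ์ผ ์žˆ์—ˆ์–ด?\n"
        "A: ํšŒ์‚ฌ์—์„œ ๊ณ„์† ์งœ์ฆ๋‚˜๋Š” ์ผ๋งŒ ์ƒ๊ฒจ."
    )
    html = predict_and_plot(sample_dialogue)
    with open("emotion_report.html", "w", encoding="utf-8") as f:
        f.write(html)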