# kshs33_emotion_predict / emotion_predictor.py
# leewatson's picture
# Update emotion_predictor.py
# f8d9884 verified
import re
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import torch
import torch.nn as nn
from transformers import ElectraModel, AutoTokenizer
import numpy as np
from sklearn.linear_model import LinearRegression
from collections import defaultdict
import base64
from io import BytesIO
# Font setup
# Register the bundled NanumGothic font file with matplotlib and make it the
# default font family for every figure created by this module.
font_path = './NanumGothic.ttf'
fm.fontManager.addfont(font_path)
plt.rcParams.update({
    'font.family': fm.FontProperties(fname=font_path).get_name(),
    # Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
    'axes.unicode_minus': False,
})
# Label definitions
# Emotion labels, one per classifier output unit. predict_and_plot pairs this
# list positionally with the model's scores, so the order must not change.
# The text was restored from mis-decoded (mojibake) UTF-8 Korean.
LABELS = [
    '불평/불만', '환영/호의', '감동/감탄', '지긋지긋', '고마움', '슬픔', '화남/분노', '존경', '기대감', '우쭐댐/무시함',
    '안타까움/실망', '비장함', '의심/불신', '뿌듯함', '편안/쾌적', '신기함/관심', '아껴주는', '부끄러움', '공포/무서움',
    '절망', '한심함', '역겨움/징그러움', '짜증', '어이없음', '없음', '패배/자기혐오', '귀찮음', '힘듦/지침', '즐거움/신남',
    '깨달음', '죄책감', '증오/혐오', '흐뭇함(귀여움/예쁨)', '당황/난처', '경악', '부담/안_내킴', '서러움', '재미없음',
    '불쌍함/연민', '놀람', '행복', '불안/걱정', '기쁨', '안심/신뢰',
]

# Subset of LABELS treated as negative emotions; only these labels are
# collected, smoothed and plotted. The list order is the plotting order.
NEGATIVE_EMOTIONS = [
    '불평/불만', '지긋지긋', '슬픔', '화남/분노', '의심/불신', '공포/무서움', '절망', '한심함', '역겨움/징그러움', '짜증', '어이없음',
    '패배/자기혐오', '귀찮음', '힘듦/지침', '죄책감', '증오/혐오', '당황/난처', '부담/안_내킴', '서러움', '재미없음',
]
# Device
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Model definition
class KOTEtagger(nn.Module):
    """KcELECTRA encoder with a linear multi-label head.

    The module owns its tokenizer, so ``forward`` takes raw text and returns
    one sigmoid score per label.
    """

    def __init__(self, num_labels=44):
        """Load the pretrained encoder/tokenizer and create the classifier head.

        Args:
            num_labels: Number of output units of the linear head. The default
                of 44 is the size this file's label list is paired with.
        """
        super().__init__()
        self.electra = ElectraModel.from_pretrained("beomi/KcELECTRA-base", revision='v2021').to(device)
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base", revision='v2021')
        self.classifier = nn.Linear(self.electra.config.hidden_size, num_labels).to(device)

    def forward(self, text):
        """Tokenize ``text`` and return per-label sigmoid scores.

        Args:
            text: The utterance to score (callers in this file pass one string).

        Returns:
            A tensor of sigmoid scores with one row per encoded input and one
            column per label; for a single string that is one row.
        """
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            # Without explicit truncation, padding="max_length" only pads:
            # inputs longer than max_length would be passed on at full length.
            # NOTE(review): presumably 512 is the encoder's position limit.
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)
        output = self.electra(encoding["input_ids"], attention_mask=encoding["attention_mask"])
        # Use the hidden state of the first token as the sequence representation.
        output = output.last_hidden_state[:, 0, :]
        output = self.classifier(output)
        return torch.sigmoid(output)
# Model loading
# Build the tagger, then overlay the weights stored in the local checkpoint.
trained_model = KOTEtagger()
# NOTE(review): strict=False does not raise on missing or unexpected keys, so
# a checkpoint whose key names do not match would load silently -- confirm the
# classifier weights are actually restored.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
trained_model.load_state_dict(
    state_dict=torch.load("kote_pytorch_lightning.bin", map_location=device),
    strict=False,
)
# Switch to evaluation mode for inference.
trained_model.eval()
# Helper functions
def parse_dialogue(text):
    """Split dialogue text into ``(speaker, utterance)`` pairs.

    Each line is stripped and matched against ``speaker:utterance``; the
    speaker is everything before the first colon. Lines without a colon, with
    nothing before it, or with nothing after it are skipped.

    Args:
        text: Dialogue text with one entry per line.

    Returns:
        A list of ``(speaker, utterance)`` tuples, both whitespace-stripped,
        in input order.
    """
    turns = []
    for raw_line in text.strip().split('\n'):
        found = re.match(r"([^:]+):(.+)", raw_line.strip())
        if not found:
            continue
        speaker, utterance = found.groups()
        turns.append((speaker.strip(), utterance.strip()))
    return turns
def adjusted_score(raw_score, k=5):
    """Map a raw score onto a 0-100 scale with a logistic curve centred at 0.5.

    Args:
        raw_score: Score to rescale; 0.5 maps to exactly 50.
        k: Steepness of the curve; larger values sharpen the transition.

    Returns:
        A float in the range [0, 100].
    """
    try:
        return 100 / (1 + math.exp(-k * (raw_score - 0.5)))
    except OverflowError:
        # exp() overflowed, i.e. the denominator is effectively infinite, so
        # return the limit value instead of crashing.
        return 0.0
def apply_ema(scores, alpha=0.4):
    """Smooth a score sequence with an exponential moving average.

    The first value is kept as is; each later value becomes
    ``alpha * value + (1 - alpha) * previous_smoothed``.

    Args:
        scores: Sequence of numeric scores in order.
        alpha: Weight given to the newest value.

    Returns:
        A new list of smoothed values with the same length as ``scores``;
        an empty list when ``scores`` is empty.
    """
    if not scores:
        return []
    carry_weight = 1 - alpha
    result = []
    previous = None
    for position, value in enumerate(scores):
        if position == 0:
            previous = value
        else:
            previous = alpha * value + carry_weight * previous
        result.append(previous)
    return result
# Main processing function
def predict_and_plot(raw_text):
    """Build an HTML report of each speaker's negative-emotion trend.

    Every ``speaker: utterance`` line found by ``parse_dialogue`` is scored by
    ``trained_model``. Scores for labels in ``NEGATIVE_EMOTIONS`` are rescaled
    with ``adjusted_score`` and smoothed per speaker with ``apply_ema``. A
    label is reported when it has at least two smoothed scores and their
    maximum is 40 or more; a linear regression over the utterance index then
    extrapolates the next score.

    Args:
        raw_text: Dialogue text with one ``speaker: utterance`` entry per line.

    Returns:
        An HTML fragment (str). Per speaker it holds a heading, one paragraph
        per reported label (with a red warning when the extrapolated score is
        80 or more), and either a base64-embedded PNG chart or a notice that
        nothing could be plotted. Empty when no dialogue line was parsed.
    """
    import html  # local import: only used to escape speaker names below

    negative_labels = frozenset(NEGATIVE_EMOTIONS)
    emotion_scores = defaultdict(lambda: defaultdict(list))

    # Prediction. This is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        for speaker, sentence in parse_dialogue(raw_text):
            preds = trained_model(sentence)[0]
            for label, score in zip(LABELS, preds):
                if label in negative_labels:
                    emotion_scores[speaker][label].append(adjusted_score(score.item()))

    parts = []
    for speaker, label_scores in emotion_scores.items():
        # Speaker names come straight from the input text, so escape them
        # before embedding them in markup.
        parts.append(f"<h3>{html.escape(speaker)} 감정 예측 결과:</h3>")
        fig, ax = plt.subplots(figsize=(10, 4))
        try:
            max_y = 0
            plotted = False
            for label in NEGATIVE_EMOTIONS:
                scores = apply_ema(label_scores.get(label, []))
                if len(scores) < 2 or max(scores) < 40:
                    continue

                # Fit score ~ utterance index and extrapolate one step ahead.
                count = len(scores)
                X = np.arange(count).reshape(-1, 1)
                regression = LinearRegression().fit(X, np.array(scores))
                predicted = regression.predict([[count]])[0]

                line, = ax.plot(scores, label=label)
                # Dashed segment from the last observed point to the forecast,
                # drawn in the same colour as the observed series.
                ax.plot([count - 1, count], [scores[-1], predicted],
                        linestyle='--', color=line.get_color())
                plotted = True
                max_y = max(max_y, predicted, *scores)

                warning = " <b style='color:red'>⚠️ 경고!</b>" if predicted >= 80 else ""
                parts.append(f"<p>- {label}: 예측 점수 {predicted:.2f}{warning}</p>")

            if plotted:
                ax.set_title(f"{speaker}의 부정 감정 변화 및 예측")
                ax.set_xlabel("발화 순서")
                ax.set_ylabel("감정 점수")
                # Keep the 0-100 scale unless a forecast exceeds it.
                ax.set_ylim(0, max(100, max_y + 10))
                ax.legend()
                ax.grid(True)
                buf = BytesIO()
                fig.tight_layout()
                fig.savefig(buf, format='png')
                img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
                parts.append(f"<img src='data:image/png;base64,{img_base64}'/><hr/>")
            else:
                parts.append("<p>⚠️ 시각화할 수 있는 감정이 없습니다.</p><hr/>")
        finally:
            # Always release the figure, including when nothing was plotted.
            plt.close(fig)
    return "".join(parts)