import re
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import torch
import torch.nn as nn
from transformers import ElectraModel, AutoTokenizer
import numpy as np
from sklearn.linear_model import LinearRegression
from collections import defaultdict
import base64
from io import BytesIO
# Font setup: register a Korean font so Hangul labels render correctly in matplotlib
font_path = './NanumGothic.ttf'
fm.fontManager.addfont(font_path)
plt.rcParams['font.family'] = fm.FontProperties(fname=font_path).get_name()
plt.rcParams['axes.unicode_minus'] = False
# Label definitions
LABELS = ['불평/불만', '환영/호의', '감동/감탄', '지긋지긋', '고마움', '슬픔', '화남/분노', '존경', '기대감', '우쭐댐/무시함',
          '안타까움/실망', '비장함', '의심/불신', '뿌듯함', '편안/쾌적', '신기함/관심', '아껴주는', '부끄러움', '공포/무서움',
          '절망', '한심함', '역겨움/징그러움', '짜증', '어이없음', '없음', '패배/자기혐오', '귀찮음', '힘듦/지침', '즐거움/신남',
          '깨달음', '죄책감', '증오/혐오', '흐뭇함(귀여움/예쁨)', '당황/난처', '경악', '부담/안_내킴', '서러움', '재미없음',
          '불쌍함/연민', '놀람', '행복', '불안/걱정', '기쁨', '안심/신뢰']

NEGATIVE_EMOTIONS = [
    '불평/불만', '지긋지긋', '슬픔', '화남/분노', '의심/불신', '공포/무서움', '절망', '한심함', '역겨움/징그러움', '짜증', '어이없음',
    '패배/자기혐오', '귀찮음', '힘듦/지침', '죄책감', '증오/혐오', '당황/난처', '부담/안_내킴', '서러움', '재미없음'
]
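# The two lists above: LABELS holds the 44 emotion tags of the KOTE dataset, in the order
# produced by the classifier head below (output dimension 44); NEGATIVE_EMOTIONS is the
# 20-tag subset that predict_and_plot() tracks, smooths, and plots.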
# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model definition
class KOTEtagger(nn.Module):
    def __init__(self):
        super().__init__()
        self.electra = ElectraModel.from_pretrained("beomi/KcELECTRA-base", revision='v2021').to(device)
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base", revision='v2021')
        self.classifier = nn.Linear(self.electra.config.hidden_size, 44).to(device)

    def forward(self, text):
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            truncation=True,  # avoid over-length inputs instead of raising a warning/error
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)
        output = self.electra(encoding["input_ids"], attention_mask=encoding["attention_mask"])
        output = output.last_hidden_state[:, 0, :]  # [CLS] token representation
        output = self.classifier(output)
        return torch.sigmoid(output)  # independent per-label probabilities (multi-label)
# Load the trained weights and switch to inference mode
trained_model = KOTEtagger()
trained_model.load_state_dict(torch.load("kote_pytorch_lightning.bin", map_location=device), strict=False)
trained_model.eval()
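# Note: strict=False makes load_state_dict skip missing or unexpected keys instead of raising,
# which tolerates naming differences if the checkpoint was exported from a PyTorch Lightning
# training run (as the filename suggests).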
# Helper functions
def parse_dialogue(text):
    # Split "speaker: utterance" lines into (speaker, utterance) pairs
    lines = text.strip().split('\n')
    return [
        (match.group(1).strip(), match.group(2).strip())
        for line in lines
        if (match := re.match(r"([^:]+):(.+)", line.strip()))
    ]

def adjusted_score(raw_score, k=5):
    return 100 / (1 + math.exp(-k * (raw_score - 0.5)))
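# Worked example of the rescaling above (k=5): a raw sigmoid score of 0.5 maps to 50,
# 0.9 maps to about 88, and 0.1 maps to about 12, so mid-range scores are spread out
# while the extremes saturate toward 0 and 100.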

def apply_ema(scores, alpha=0.4):
    if not scores:
        return []
    smoothed = [scores[0]]
    for s in scores[1:]:
        smoothed.append(alpha * s + (1 - alpha) * smoothed[-1])
    return smoothed
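# Worked example of the smoothing above (alpha=0.4): apply_ema([10, 50, 50]) returns
# [10, 26.0, 35.6] -- each point is 40% of the new score plus 60% of the previous
# smoothed value, which damps single-utterance spikes.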

# Main processing function
def predict_and_plot(raw_text):
    dialogue = parse_dialogue(raw_text)
    emotion_scores = defaultdict(lambda: defaultdict(list))

    # Prediction: collect per-utterance scores of the negative emotions for each speaker
    with torch.no_grad():
        for speaker, sentence in dialogue:
            preds = trained_model(sentence)[0]
            for label, score in zip(LABELS, preds):
                if label in NEGATIVE_EMOTIONS:
                    adjusted = adjusted_score(score.item())
                    emotion_scores[speaker][label].append(adjusted)

    # Per-speaker report: headline, per-emotion trend lines, and a base64-embedded plot
    html_output = ""
    for speaker in emotion_scores:
        html_output += f"<h3>{speaker} 감정 예측 결과:</h3>"
        fig, ax = plt.subplots(figsize=(10, 4))
        max_y = 0
        plotted = False
        predicted_scores = {}
        for label in NEGATIVE_EMOTIONS:
            raw_scores = emotion_scores[speaker].get(label, [])
            scores = apply_ema(raw_scores)
            # Only plot emotions with at least two points and a meaningful peak
            if len(scores) >= 2 and max(scores) >= 40:
                # Linear extrapolation of the smoothed trend to the next utterance
                X = np.arange(len(scores)).reshape(-1, 1)
                y = np.array(scores)
                model = LinearRegression().fit(X, y)
                predicted = model.predict([[len(scores)]])[0]
                predicted_scores[label] = predicted
                line, = ax.plot(scores, label=label)
                color = line.get_color()
                ax.plot([len(scores) - 1, len(scores)], [scores[-1], predicted], linestyle='--', color=color)
                plotted = True
                max_y = max(max_y, predicted, *scores)
                html_output += f"<p>- {label}: 예측 점수 {predicted:.2f}"
                if predicted >= 80:
                    html_output += " <b style='color:red'>⚠️ 경고!</b>"
                html_output += "</p>"
        if plotted:
            ax.set_title(f"{speaker}의 부정 감정 변화 및 예측")
            ax.set_xlabel("발화 순서")
            ax.set_ylabel("감정 점수")
            ax.set_ylim(0, max(100, max_y + 10))
            ax.legend()
            ax.grid(True)
            buf = BytesIO()
            plt.tight_layout()
            plt.savefig(buf, format='png')
            plt.close(fig)
            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            html_output += f"<img src='data:image/png;base64,{img_base64}'/><hr/>"
        else:
            plt.close(fig)  # release the unused figure
            html_output += "<p>⚠️ 시각화할 수 있는 감정이 없습니다.</p><hr/>"
    return html_output
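
# The UI wiring is not part of this section. A minimal sketch, assuming the function is
# exposed through a Gradio interface (the usual setup for a Hugging Face Space); the
# component labels and the sample dialogue below are purely illustrative.
import gradio as gr

sample_dialogue = "철수: 요즘 일이 너무 많아서 진짜 지친다.\n영희: 괜찮아? 너무 무리하지 마."

demo = gr.Interface(
    fn=predict_and_plot,
    inputs=gr.Textbox(lines=10, value=sample_dialogue, label="대화 입력 (이름: 발화 형식)"),
    outputs=gr.HTML(label="분석 결과"),
    title="부정 감정 추이 분석",
)

if __name__ == "__main__":
    demo.launch()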