Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from transformers import AutoTokenizer, AutoModel | |
import streamlit as st | |
import os | |
import numpy as np | |
st.markdown(""" | |
<style> | |
.big-title { | |
font-size: 1.8em; | |
font-weight: 800; | |
margin-bottom: 0.2em; | |
} | |
.sub-info { | |
font-size: 1.1em; | |
color: #666; | |
margin-bottom: 1.2em; | |
} | |
.card { | |
background-color: #f1f3f6; | |
padding: 1.2em; | |
border-left: 5px solid #3366cc; | |
border-radius: 6px; | |
margin-bottom: 1em; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
KOTE_LABELS = [ | |
'๋ถํ/๋ถ๋ง', 'ํ์/ํธ์', '๊ฐ๋/๊ฐํ', '์ง๊ธ์ง๊ธ', '๊ณ ๋ง์', '์ฌํ', 'ํ๋จ/๋ถ๋ ธ', '์กด๊ฒฝ', | |
'๊ธฐ๋๊ฐ', '์ฐ์ญ๋/๋ฌด์ํจ', '์ํ๊น์/์ค๋ง', '๋น์ฅํจ', '์์ฌ/๋ถ์ ', '๋ฟ๋ฏํจ', 'ํธ์/์พ์ ', | |
'์ ๊ธฐํจ/๊ด์ฌ', '์๊ปด์ฃผ๋', '๋ถ๋๋ฌ์', '๊ณตํฌ/๋ฌด์์', '์ ๋ง', 'ํ์ฌํจ', '์ญ๊ฒจ์/์ง๊ทธ๋ฌ์', | |
'์ง์ฆ', '์ด์ด์์', '์์', 'ํจ๋ฐฐ/์๊ธฐํ์ค', '๊ท์ฐฎ์', 'ํ๋ฆ/์ง์นจ', '์ฆ๊ฑฐ์/์ ๋จ', '๊นจ๋ฌ์', | |
'์ฃ์ฑ ๊ฐ', '์ฆ์ค/ํ์ค', 'ํ๋ญํจ(๊ท์ฌ์/์์จ)', '๋นํฉ/๋์ฒ', '๊ฒฝ์ ', '๋ถ๋ด/์_๋ดํด', '์๋ฌ์', | |
'์ฌ๋ฏธ์์', '๋ถ์ํจ/์ฐ๋ฏผ', '๋๋', 'ํ๋ณต', '๋ถ์/๊ฑฑ์ ', '๊ธฐ์จ', '์์ฌ/์ ๋ขฐ' | |
] | |
class MLPClassifier(nn.Module): | |
def __init__(self, input_dim=1024, num_labels=44): | |
super(MLPClassifier, self).__init__() | |
self.mlp = nn.Sequential( | |
nn.Linear(input_dim, 512), | |
nn.BatchNorm1d(512), | |
nn.ReLU(), | |
nn.Dropout(0.3), | |
nn.Linear(512, 256), | |
nn.BatchNorm1d(256), | |
nn.ReLU(), | |
nn.Dropout(0.3), | |
nn.Linear(256, num_labels) | |
) | |
def forward(self, x): | |
return self.mlp(x) | |
def load_model(): | |
device = torch.device("cpu") | |
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large") | |
base_model = AutoModel.from_pretrained("klue/roberta-large").eval() | |
mlp_model = MLPClassifier().eval() | |
ckpt_path = os.path.join("checkpoints", "mlp_model.pth") | |
mlp_model.load_state_dict(torch.load(ckpt_path, map_location=device)) | |
return tokenizer, base_model, mlp_model | |
tokenizer, base_model, mlp_model = load_model() | |
def predict_emotion(text, top_k=5): | |
encoded = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512) | |
with torch.no_grad(): | |
outputs = base_model(**encoded) | |
cls_emb = outputs.last_hidden_state[:, 0, :] | |
logits = mlp_model(cls_emb) | |
probs = torch.sigmoid(logits).squeeze(0).numpy() | |
result = sorted(zip(KOTE_LABELS, probs), key=lambda x: x[1], reverse=True) | |
return result[:top_k], probs | |
tabs = st.tabs(["๊ฐ์ ๋ถ์ ์ฒดํ", "AI๋ ์ด๋ป๊ฒ ๊ฐ์ ์ ์ดํดํ ๊น?", "Few-shot Fine-tuning์ด๋?", "ํ์ฉ๊ณผ ์์", "๊ธฐํ ์๋ฃ๋ฃ"]) | |
with tabs[0]: | |
st.markdown('<div class="big-title">๐ญ ํ๊ตญ์ด ๊ฐ์ ๋ถ์ AI ์ฒดํ</div>', unsafe_allow_html=True) | |
st.markdown('<div class="sub-info">2025๋ 1ํ๊ธฐ ๋์งํธ ์ธ๋ฌธํ ์ ๋ฌธ (SLA23501) ยท <b>Team ์๋ฌ๋๋ณผ</b><br>๊ฐ์ํ ยท ๊น๋์ฐ ยท ์ ์์ ยท ์ ํ์ฑ ยท ์ต์ข ์ค</div>', unsafe_allow_html=True) | |
st.markdown(""" | |
<div class="card"> | |
์ธ๊ณต์ง๋ฅ์ ์ ๋ ฅ๋ ๋ฌธ์ฅ์ ๋ถ์ํด ๊ฐ์ ์ด ์ด๋ป๊ฒ ํํ๋์๋์ง๋ฅผ ์์ธกํฉ๋๋ค.<br> | |
์๋์ ๋ฌธ์ฅ์ ์ ๋ ฅํ๊ฑฐ๋ ์์ ๋ฌธ์ฅ์ ๋ถ๋ฌ์จ ํ, ๊ฐ์ ์์ธก ๋ฒํผ์ ๋๋ฌ ์ฒดํํด๋ณด์ธ์. | |
</div> | |
""", unsafe_allow_html=True) | |
if "text_input" not in st.session_state: | |
st.session_state.text_input = "" | |
col1, col2 = st.columns([1, 1]) | |
with col1: | |
if st.button("๐ ์์ ๋ฌธ์ฅ ๋ถ๋ฌ์ค๊ธฐ"): | |
st.session_state.text_input = "๊ทธ๊ฑธ ์ด์ ๋งํด์ค? ์น์ ํ๋ค ์ ๋ง" | |
with col2: | |
predict_clicked = st.button("๐ ๊ฐ์ ์์ธกํ๊ธฐ") | |
text = st.text_area( | |
"โ๏ธ ๋ฌธ์ฅ์ ์ ๋ ฅํ์ธ์:", | |
value=st.session_state.text_input, | |
height=120, | |
placeholder="์: ์ค๋ ํ๋ฃจ ์ ๋ง ํ๋ณตํ์ด์." | |
) | |
st.session_state.text_input = text | |
if predict_clicked: | |
if text.strip(): | |
with st.spinner("AI๊ฐ ๊ฐ์ ์ ๋ถ์ ์ค์ ๋๋ค..."): | |
results, full_probs = predict_emotion(text) | |
top_emotion, top_prob = results[0] | |
st.markdown( | |
f'<div class="card"><div class="highlight">โ ๊ฐ์ฅ ๊ฐํ๊ฒ ํํ๋ ๊ฐ์ : <b>{top_emotion}</b> ({top_prob:.2f})</div>', | |
unsafe_allow_html=True | |
) | |
st.subheader("๐ ์์ ๊ฐ์ ๊ฒฐ๊ณผ") | |
for label, prob in results: | |
st.markdown(f"- **{label}**: `{prob:.3f}`") | |
st.subheader("๐ ํ๋ฅ ๋ถํฌ (Top 5)") | |
st.bar_chart({label: prob for label, prob in results}) | |
st.markdown("</div>", unsafe_allow_html=True) | |
else: | |
st.warning("๋ฌธ์ฅ์ ๋จผ์ ์ ๋ ฅํด์ฃผ์ธ์.") | |
with tabs[1]: | |
st.markdown('<div class="big-title">๐ค ์ธ๊ณต์ง๋ฅ์ ๊ฐ์ ์ ์ด๋ป๊ฒ ์ดํดํ ๊น์?</div>', unsafe_allow_html=True) | |
st.markdown(""" | |
<div class="card"> | |
์ธ๊ณต์ง๋ฅ์ ๋ฌธ์ฅ์ '์ซ์์ ๋ฒกํฐ'๋ก ๋ฐ๊พธ์ด ์ดํดํฉ๋๋ค.<br><br> | |
์ด ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ์ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง๋๋ค: | |
<ol> | |
<li><b>์ฌ์ ํ์ต ์ธ์ด ๋ชจ๋ธ</b>(KLUE-RoBERTa)์ด ๋ฌธ์ฅ์ ์ฝ๊ณ ํต์ฌ ์๋ฏธ๋ฅผ ์ถ์ถํฉ๋๋ค.</li> | |
<li>์ด ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก <b>๊ฐ์ ๋ถ๋ฅ๊ธฐ</b>(MLP)๊ฐ ๊ฐ์ ์ ์์ธกํฉ๋๋ค.</li> | |
<li>๊ฐ ๊ฐ์ ์ ๋ํ ๊ฐ๋ฅ์ฑ์ <b>ํ๋ฅ ๋ก</b> ๋ณด์ฌ์ค๋๋ค.</li> | |
</ol> | |
</div> | |
""", unsafe_allow_html=True) | |
with tabs[2]: | |
st.markdown('<div class="big-title">๐ง Few-shot Fine-tuning์ด๋?</div>', unsafe_allow_html=True) | |
st.markdown(""" | |
<div class="card"> | |
์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ๋ KLUE-RoBERTa ๋ชจ๋ธ์ ์ด๋ฏธ ์๋ง์ ๋ฌธ์ฅ์ ํ์ตํ ๊ฑฐ๋ํ ์ธ์ด ๋ชจ๋ธ์ ๋๋ค.<br><br> | |
ํ์ง๋ง ๊ฐ์ ๋ถ์์ด๋ผ๋ ํน์ ํ ์์ ์ ๋ง๊ฒ ์กฐ๊ธ ๋ ํ์ต์ํค๋ ๊ณผ์ ์ด ํ์ํฉ๋๋ค.<br><br> | |
์ด๋ ์ ์ฒด ๋ชจ๋ธ์ ๋ค์ ํ์ตํ์ง ์๊ณ , ๋ง์ง๋ง ๋ถ๋ฅ๊ธฐ(MLP)๋ง ํ์ตํ๋ ๋ฐฉ์์ด ๋ฐ๋ก | |
<b>Few-shot Fine-tuning</b>์ ๋๋ค.<br><br> | |
์ด ๋ฐฉ๋ฒ์ ํตํด ์ ์ ์์ ๊ฐ์ ๋ฐ์ดํฐ๋ง์ผ๋ก๋ ๋์ ์ฑ๋ฅ์ ๋ฌ์ฑํ ์ ์์ต๋๋ค. | |
</div> | |
""", unsafe_allow_html=True) | |
with tabs[3]: | |
st.markdown('<div class="big-title">๐ ์ด ๊ธฐ์ ์ ์ด๋์ ์ฐ์ผ ์ ์์๊น์?</div>', unsafe_allow_html=True) | |
st.markdown(""" | |
<div class="card"> | |
์ด ๊ฐ์ ๋ถ์ ๊ธฐ์ ์ ๋จ์ํ ๋ฌธ์ฅ์ ๊ฐ์ ์ ๋ถ๋ฅํ๋ ๋ฐ ๊ทธ์น์ง ์๊ณ , | |
<b>๋์งํธ ์ฌํ์์์ ๊ฐ์ ํ๋ฆ</b>๊ณผ <b>๊ณต๋ก ์ฅ์ ์ ์์ ๊ตฌ์กฐ</b>๋ฅผ ์ดํดํ๋ ๋ฐ๊น์ง ํ์ฅ๋ ์ ์์ต๋๋ค.<br><br> | |
ํนํ ๋ค์๊ณผ ๊ฐ์ ๋ถ์ผ์ ํ์ฉ๋ ์ ์์ต๋๋ค: | |
<ul> | |
<li>๐ก <b>์ ์ฑ ๋๊ธ ํ์ง</b>: ์ ํด ํํ, ํ์ค ํํ์ ์กฐ๊ธฐ์ ๊ฐ์งํ๊ณ ํํฐ๋ง</li> | |
<li>๐ <b>๋๊ธ ๊ฐ์ ํ๋ฆ ์๊ฐํ</b>: ์ ํ๋ธ๋ ๋ด์ค ๋๊ธ์์ ๊ฐ์ ์ ํ ๊ตฌ์กฐ ๋ถ์</li> | |
<li>๐ฐ <b>์ฌํ ์ด์ ๊ณต๊ฐ/ํ์ค ๋ฐ์ ์ถ์ </b>: ํน์ ์ฌ๊ฑด์ ๋ํ ๊ฐ์ ๋ฐ์ ๋ชจ๋ํฐ๋ง</li> | |
<li>๐ฌ <b>์จ๋ผ์ธ ๊ณต๋ก ์ฅ ๊ฐ์ ์ ์ผ ์ฐ๊ตฌ</b>: ๊ฐ์ ์ด ๋๊ธ-๋๋๊ธ๋ก ์ด๋ป๊ฒ ํ์ฐ๋๋์ง ์ ๋์ ๋ถ์</li> | |
</ul> | |
๋์๊ฐ ์ด ๊ธฐ์ ์ <b>๋์งํธ ์ธ๋ฌธํ์ ์๋ก์ด ๋ถ์ ๋๊ตฌ</b>๋ก ํ์ฉ๋์ด, | |
ํ ์คํธ ๊ธฐ๋ฐ ์ฌ๋ก ์ ์ ์์ ๊ตฌ์กฐ๋ฅผ ๋ณด๋ค ๊น์ด ์๊ฒ ์ดํดํ๋ ๊ธฐ๋ฐ์ด ๋ ์ ์์ต๋๋ค. | |
</div> | |
""", unsafe_allow_html=True) | |
with tabs[4]: | |
st.image("image/clustering-2.png", use_column_width=True) | |
st.image("image/clustering-10.png", use_column_width=True) | |
st.image("image/clustering-plutchiks.png", use_column_width=True) | |
st.image("image/clustering-plutchiks-bert.png", use_column_width=True) | |