# Humanity / app.py
# JeongHyunsung's picture
# Update app.py
# 8d4d16e verified
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import streamlit as st
import os
import numpy as np
# Page-wide CSS for the custom elements the tabs render through raw HTML:
# the large title, the grey sub-heading and the bordered "card" container.
_PAGE_STYLE = """
<style>
.big-title {
font-size: 1.8em;
font-weight: 800;
margin-bottom: 0.2em;
}
.sub-info {
font-size: 1.1em;
color: #666;
margin-bottom: 1.2em;
}
.card {
background-color: #f1f3f6;
padding: 1.2em;
border-left: 5px solid #3366cc;
border-radius: 6px;
margin-bottom: 1em;
}
</style>
"""

# unsafe_allow_html is required so the <style> tag is emitted instead of escaped.
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
# The 44 KOTE emotion label names, written as proper UTF-8 Hangul (the
# previous literals were mis-decoded byte sequences).
# Order matters: predict_emotion pairs these positionally with the
# classifier's output probabilities, so never reorder or deduplicate.
KOTE_LABELS = [
    '불평/불만', '환영/호의', '감동/감탄', '지긋지긋', '고마움', '슬픔', '화남/분노', '존경',
    '기대감', '우쭐댐/무시함', '안타까움/실망', '비장함', '의심/불신', '뿌듯함', '편안/쾌적',
    '신기함/관심', '아껴주는', '부끄러움', '공포/무서움', '절망', '한심함', '역겨움/징그러움',
    '짜증', '어이없음', '없음', '패배/자기혐오', '귀찮음', '힘듦/지침', '즐거움/신남', '깨달음',
    '죄책감', '증오/혐오', '흐뭇함(귀여움/예쁨)', '당황/난처', '경악', '부담/안_내킴', '서러움',
    '재미없음', '불쌍함/연민', '놀람', '행복', '불안/걱정', '기쁨', '안심/신뢰',
]
class MLPClassifier(nn.Module):
    """Feed-forward classification head over fixed-size embedding vectors.

    Every hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout, followed
    by a final Linear layer that emits one raw logit per label (no
    activation is applied here).

    Args:
        input_dim: Size of each input embedding vector.
        num_labels: Number of output logits.
        hidden_dims: Widths of the hidden stages, in order. The default
            ``(512, 256)`` reproduces the previously hard-coded architecture.
        dropout: Dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim=1024, num_labels=44, hidden_dims=(512, 256), dropout=0.3):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_features = width
        layers.append(nn.Linear(in_features, num_labels))
        # Keep the attribute name ``mlp`` and a single flat Sequential: the
        # state-dict keys (``mlp.0.weight`` ...) are derived from both, and
        # saved checkpoints are loaded by those keys.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_labels)`` for input ``x``."""
        return self.mlp(x)
@st.cache_resource
def load_model():
    """Build the tokenizer, the KLUE-RoBERTa encoder and the MLP head.

    The encoder and tokenizer come from the ``klue/roberta-large``
    pretrained weights; the MLP head is restored from
    ``checkpoints/mlp_model.pth`` onto the CPU. Both models are switched to
    eval mode before being returned.

    Returns:
        A ``(tokenizer, base_model, mlp_model)`` tuple.
    """
    encoder_name = "klue/roberta-large"
    tokenizer = AutoTokenizer.from_pretrained(encoder_name)
    base_model = AutoModel.from_pretrained(encoder_name)
    base_model.eval()

    mlp_model = MLPClassifier()
    mlp_model.eval()
    # map_location keeps the checkpoint tensors on CPU regardless of the
    # device they were saved from.
    state_dict = torch.load(
        os.path.join("checkpoints", "mlp_model.pth"),
        map_location=torch.device("cpu"),
    )
    mlp_model.load_state_dict(state_dict)
    return tokenizer, base_model, mlp_model


# Module-level handles used by predict_emotion; the st.cache_resource
# decorator is what is expected to keep reruns from rebuilding them.
tokenizer, base_model, mlp_model = load_model()
def predict_emotion(text, top_k=5):
    """Score ``text`` against every KOTE emotion label.

    The text is encoded with the module-level tokenizer, the first-token
    embedding of the encoder's last hidden state is fed to the MLP head, and
    the logits are squashed independently with a sigmoid (multi-label
    scores, so they need not sum to 1).

    Args:
        text: Sentence to analyse.
        top_k: Number of highest-scoring labels to return.

    Returns:
        A ``(top, probs)`` tuple. ``top`` is a list of at most ``top_k``
        ``(label, probability)`` pairs sorted by descending probability;
        ``probs`` is the NumPy array of all probabilities in ``KOTE_LABELS``
        order.
    """
    # padding=True pads only up to the longest sequence in this call, so a
    # single sentence is no longer inflated to 512 tokens before the
    # encoder runs. The attention mask is passed along in ``encoded``, so
    # padding positions are expected not to influence the first-token
    # embedding beyond floating-point noise.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = base_model(**encoded)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        logits = mlp_model(cls_emb)
        probs = torch.sigmoid(logits).squeeze(0).numpy()
    ranked = sorted(zip(KOTE_LABELS, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k], probs
tabs = st.tabs(["๊ฐ์ • ๋ถ„์„ ์ฒดํ—˜", "AI๋Š” ์–ด๋–ป๊ฒŒ ๊ฐ์ •์„ ์ดํ•ดํ• ๊นŒ?", "Few-shot Fine-tuning์ด๋ž€?", "ํ™œ์šฉ๊ณผ ์˜์˜", "๊ธฐํƒ€ ์ž๋ฃŒ๋ฃŒ"])
with tabs[0]:
st.markdown('<div class="big-title">๐ŸŽญ ํ•œ๊ตญ์–ด ๊ฐ์ • ๋ถ„์„ AI ์ฒดํ—˜</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-info">2025๋…„ 1ํ•™๊ธฐ ๋””์ง€ํ„ธ ์ธ๋ฌธํ•™ ์ž…๋ฌธ (SLA23501) ยท <b>Team ์ƒ๋Ÿฌ๋“œ๋ณผ</b><br>๊ฐ•์ˆ˜ํ˜„ ยท ๊น€๋™์šฐ ยท ์ •์˜ˆ์€ ยท ์ •ํ˜„์„ฑ ยท ์ตœ์ข…์œค</div>', unsafe_allow_html=True)
st.markdown("""
<div class="card">
์ธ๊ณต์ง€๋Šฅ์€ ์ž…๋ ฅ๋œ ๋ฌธ์žฅ์„ ๋ถ„์„ํ•ด ๊ฐ์ •์ด ์–ด๋–ป๊ฒŒ ํ‘œํ˜„๋˜์—ˆ๋Š”์ง€๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค.<br>
์•„๋ž˜์— ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๊ฑฐ๋‚˜ ์˜ˆ์‹œ ๋ฌธ์žฅ์„ ๋ถˆ๋Ÿฌ์˜จ ํ›„, ๊ฐ์ • ์˜ˆ์ธก ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ ์ฒดํ—˜ํ•ด๋ณด์„ธ์š”.
</div>
""", unsafe_allow_html=True)
if "text_input" not in st.session_state:
st.session_state.text_input = ""
col1, col2 = st.columns([1, 1])
with col1:
if st.button("๐Ÿ“Œ ์˜ˆ์‹œ ๋ฌธ์žฅ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ"):
st.session_state.text_input = "๊ทธ๊ฑธ ์ด์ œ ๋งํ•ด์ค˜? ์นœ์ ˆํ•˜๋„ค ์ •๋ง"
with col2:
predict_clicked = st.button("๐Ÿ” ๊ฐ์ • ์˜ˆ์ธกํ•˜๊ธฐ")
text = st.text_area(
"โœ๏ธ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”:",
value=st.session_state.text_input,
height=120,
placeholder="์˜ˆ: ์˜ค๋Š˜ ํ•˜๋ฃจ ์ •๋ง ํ–‰๋ณตํ–ˆ์–ด์š”."
)
st.session_state.text_input = text
if predict_clicked:
if text.strip():
with st.spinner("AI๊ฐ€ ๊ฐ์ •์„ ๋ถ„์„ ์ค‘์ž…๋‹ˆ๋‹ค..."):
results, full_probs = predict_emotion(text)
top_emotion, top_prob = results[0]
st.markdown(
f'<div class="card"><div class="highlight">โœ… ๊ฐ€์žฅ ๊ฐ•ํ•˜๊ฒŒ ํ‘œํ˜„๋œ ๊ฐ์ •: <b>{top_emotion}</b> ({top_prob:.2f})</div>',
unsafe_allow_html=True
)
st.subheader("๐Ÿ“Š ์ƒ์œ„ ๊ฐ์ • ๊ฒฐ๊ณผ")
for label, prob in results:
st.markdown(f"- **{label}**: `{prob:.3f}`")
st.subheader("๐Ÿ“ˆ ํ™•๋ฅ  ๋ถ„ํฌ (Top 5)")
st.bar_chart({label: prob for label, prob in results})
st.markdown("</div>", unsafe_allow_html=True)
else:
st.warning("๋ฌธ์žฅ์„ ๋จผ์ € ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")
# Tab 1: static explanation of the encoder -> classifier -> probability flow.
# The Korean text was reconstructed from a mis-decoded copy of the file.
with tabs[1]:
    st.markdown('<div class="big-title">🤖 인공지능은 감정을 어떻게 이해할까요?</div>', unsafe_allow_html=True)
    st.markdown("""
<div class="card">
인공지능은 문장을 '숫자의 벡터'로 바꾸어 이해합니다.<br><br>
이 과정은 다음과 같은 단계로 이루어집니다:
<ol>
<li><b>사전학습 언어 모델</b>(KLUE-RoBERTa)이 문장을 읽고 핵심 의미를 추출합니다.</li>
<li>이 결과를 바탕으로 <b>감정 분류기</b>(MLP)가 감정을 예측합니다.</li>
<li>각 감정에 대한 가능성을 <b>확률로</b> 보여줍니다.</li>
</ol>
</div>
""", unsafe_allow_html=True)
with tabs[2]:
st.markdown('<div class="big-title">๐Ÿง  Few-shot Fine-tuning์ด๋ž€?</div>', unsafe_allow_html=True)
st.markdown("""
<div class="card">
์šฐ๋ฆฌ๊ฐ€ ์‚ฌ์šฉํ•˜๋Š” KLUE-RoBERTa ๋ชจ๋ธ์€ ์ด๋ฏธ ์ˆ˜๋งŽ์€ ๋ฌธ์žฅ์„ ํ•™์Šตํ•œ ๊ฑฐ๋Œ€ํ•œ ์–ธ์–ด ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.<br><br>
ํ•˜์ง€๋งŒ ๊ฐ์ • ๋ถ„์„์ด๋ผ๋Š” ํŠน์ •ํ•œ ์ž‘์—…์— ๋งž๊ฒŒ ์กฐ๊ธˆ ๋” ํ•™์Šต์‹œํ‚ค๋Š” ๊ณผ์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.<br><br>
์ด๋•Œ ์ „์ฒด ๋ชจ๋ธ์„ ๋‹ค์‹œ ํ•™์Šตํ•˜์ง€ ์•Š๊ณ , ๋งˆ์ง€๋ง‰ ๋ถ„๋ฅ˜๊ธฐ(MLP)๋งŒ ํ•™์Šตํ•˜๋Š” ๋ฐฉ์‹์ด ๋ฐ”๋กœ
<b>Few-shot Fine-tuning</b>์ž…๋‹ˆ๋‹ค.<br><br>
์ด ๋ฐฉ๋ฒ•์„ ํ†ตํ•ด ์ ์€ ์–‘์˜ ๊ฐ์ • ๋ฐ์ดํ„ฐ๋งŒ์œผ๋กœ๋„ ๋†’์€ ์„ฑ๋Šฅ์„ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
</div>
""", unsafe_allow_html=True)
# Tab 3: static overview of where the emotion classifier could be applied.
# The Korean text was reconstructed from a mis-decoded copy of the file.
with tabs[3]:
    st.markdown('<div class="big-title">📌 이 기술은 어디에 쓰일 수 있을까요?</div>', unsafe_allow_html=True)
    st.markdown("""
<div class="card">
이 감정 분석 기술은 단순히 문장의 감정을 분류하는 데 그치지 않고,
<b>디지털 사회에서의 감정 흐름</b>과 <b>공론장의 정서적 구조</b>를 이해하는 데까지 확장될 수 있습니다.<br><br>
특히 다음과 같은 분야에 활용될 수 있습니다:
<ul>
<li>😡 <b>악성 댓글 탐지</b>: 유해 표현, 혐오 표현을 조기에 감지하고 필터링</li>
<li>📈 <b>댓글 감정 흐름 시각화</b>: 유튜브나 뉴스 댓글에서 감정 전파 구조 분석</li>
<li>📰 <b>사회 이슈 공감/혐오 반응 추적</b>: 특정 사건에 대한 감정 반응 모니터링</li>
<li>💬 <b>온라인 공론장 감정 전염 연구</b>: 감정이 댓글-대댓글로 어떻게 확산되는지 정량적 분석</li>
</ul>
나아가 이 기술은 <b>디지털 인문학의 새로운 분석 도구</b>로 활용되어,
텍스트 기반 여론의 정서적 구조를 보다 깊이 있게 이해하는 기반이 될 수 있습니다.
</div>
""", unsafe_allow_html=True)
# Tab 4: supplementary clustering figures, shown top to bottom in this order.
with tabs[4]:
    # NOTE(review): newer Streamlit releases reportedly deprecate
    # use_column_width in favour of use_container_width; check the pinned
    # Streamlit version before switching.
    for figure_path in (
        "image/clustering-2.png",
        "image/clustering-10.png",
        "image/clustering-plutchiks.png",
        "image/clustering-plutchiks-bert.png",
    ):
        st.image(figure_path, use_column_width=True)