Spaces:
Sleeping
Sleeping
File size: 5,965 Bytes
acb3057 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
import yaml
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import torch
import shutil
import tempfile
import re
import pandas as pd
# ----- File path constants -----
# NOTE(review): Korean text in this file appears mojibake/garbled by an
# encoding round-trip; all literals are kept byte-identical.
GLOSSARY_FILE = "glossary.md"     # Q/A knowledge base: terminology entries
INFO_FILE = "info.md"             # Q/A knowledge base: general information
PERSONA_FILE = "persona.yaml"     # persona config (loaded below; not referenced elsewhere in this view)
CHITCHAT_FILE = "chitchat.yaml"   # list of {"keywords": [...], "answer": ...} entries
CEO_VIDEO_FILE = "ceo_video.mp4"  # video clip replayed alongside each chatbot reply
# ----- Utility functions -----
def load_yaml(file_path, default_data=None):
    """Load a YAML file, falling back to *default_data* (or []) on failure.

    A missing/unreadable file, malformed YAML, or an empty document (which
    ``yaml.safe_load`` parses to ``None``) all yield the fallback, so callers
    always receive a usable container (list/dict).
    """
    fallback = default_data if default_data is not None else []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        return fallback
    # Empty YAML file -> None; normalize to the fallback container.
    return data if data is not None else fallback
def parse_knowledge_base(file_path):
    """Parse a markdown knowledge base of "Q: ... / A: ..." blocks.

    Blocks are separated by blank lines (two or more newlines before the next
    "Q:").  Returns a list of {"question": str, "answer": str} dicts, or []
    when the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    # `with` guarantees the handle is closed (the original leaked an open file).
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    pattern = r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))"
    return [
        {"question": q.strip(), "answer": a.strip()}
        for q, a, _ in re.findall(pattern, content, re.DOTALL)
    ]
# ----- Data loading (module level, runs at import time) -----
persona = load_yaml(PERSONA_FILE, {})        # persona config; not referenced below — TODO confirm use elsewhere
chitchat_map = load_yaml(CHITCHAT_FILE, [])  # keyword -> canned-answer entries for find_chitchat()
glossary_base = parse_knowledge_base(GLOSSARY_FILE)
info_base = parse_knowledge_base(INFO_FILE)
# Parallel question/answer lists consumed by the embedding search below;
# index i of the *_qs list corresponds to index i of the *_as list.
glossary_qs = [x["question"] for x in glossary_base]
glossary_as = [x["answer"] for x in glossary_base]
info_qs = [x["question"] for x in info_base]
info_as = [x["answer"] for x in info_base]
# ----- Chatbot logic (unchanged) -----
# Process-wide cache of loaded SentenceTransformer models, keyed by name.
model_cache = {}

def get_model(name):
    """Return the SentenceTransformer for *name*, loading it on first use."""
    cached = model_cache.get(name)
    if cached is None:
        cached = SentenceTransformer(name)
        model_cache[name] = cached
    return cached
# Corpus embeddings are invariant per (model, KB): the KB lists are built once
# at import time, so cache them instead of re-encoding the whole corpus on
# every chat turn (the original encoded all KB questions per call).
_corpus_emb_cache = {}

def best_faq_answer(user_question, kb_type, model_name):
    """Return the KB answer whose question embeds closest to *user_question*.

    kb_type "์ฉ์ด" selects the glossary KB; any other value selects the info KB.
    Similarity is cosine similarity between sentence embeddings.
    """
    model = get_model(model_name)
    if kb_type == "์ฉ์ด":
        kb_key, kb_qs, kb_as = "glossary", glossary_qs, glossary_as
    else:
        kb_key, kb_qs, kb_as = "info", info_qs, info_as
    if not kb_qs:
        # An empty KB would crash cos_sim/argmax; fail soft with a notice.
        return "Knowledge base is empty."
    cache_key = (model_name, kb_key)
    emb = _corpus_emb_cache.get(cache_key)
    if emb is None:
        emb = model.encode(kb_qs, convert_to_tensor=True)
        _corpus_emb_cache[cache_key] = emb
    q_emb = model.encode([user_question], convert_to_tensor=True)
    scores = util.cos_sim(q_emb, emb)[0]
    return kb_as[int(torch.argmax(scores))]
def find_chitchat(uq):
    """Return the canned answer of the first chitchat entry whose keyword
    occurs in *uq*, or None when nothing matches.

    Matching is case-insensitive on the query side only.
    NOTE(review): keywords are compared as-is, so mixed-case keywords in the
    YAML can never match — confirm the data is stored lowercase.
    """
    uq_lower = uq.lower()  # hoisted: was recomputed for every keyword
    for chat in chitchat_map:
        if any(kw in uq_lower for kw in chat.get("keywords", [])):
            return chat["answer"]
    return None
# Path of the previous per-turn video copy; deleted on the next turn so temp
# files do not accumulate (the original leaked one copy per message, as its
# own comment noted).
_last_temp_video = None

def chat_interface(message, history, kb_type, model_name):
    """Handle one chat turn.

    Appends the user message and the bot reply (chitchat keyword match first,
    else embedding FAQ search) to *history*, then returns a fresh Video
    component so the clip replays. Returns (history, "", video) matching the
    three bound outputs.
    """
    global _last_temp_video
    if not message.strip():
        # BUG FIX: the original returned only 2 values here while the click /
        # submit handlers bind 3 outputs; supply a non-autoplaying video too.
        return history, "", gr.Video(value=CEO_VIDEO_FILE, autoplay=False, interactive=False)
    if chit := find_chitchat(message):
        resp = chit
    else:
        resp = best_faq_answer(message, kb_type, model_name)
    history = history or []
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": resp})
    # A fresh copy forces the frontend to reload/replay the clip — presumably
    # why the original copied at all (TODO confirm). Remove the previous copy
    # first so the temp dir stays bounded.
    if _last_temp_video:
        try:
            os.remove(_last_temp_video)
        except OSError:
            pass  # best effort; the file may already be gone
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()  # we only need the path; the original leaked the open handle
    shutil.copyfile(CEO_VIDEO_FILE, tmp.name)
    _last_temp_video = tmp.name
    video = gr.Video(value=tmp.name, autoplay=True, interactive=False)
    return history, "", video
# ----- Model comparison / evaluation -----
def compare_models(kb_type, selected_models):
    """Self-retrieval benchmark over the selected KB.

    For each model: embed the raw KB questions as the corpus and the
    tag-stripped questions as queries, then count how often the correct
    answer is ranked Top-1 / within Top-3. Scoring compares answer strings,
    so duplicate questions sharing one answer still count as correct.

    Returns a pandas DataFrame with one row per model.
    """
    # Question/answer pair selection
    if kb_type == "์ฉ์ด":
        qs, ans = glossary_qs, glossary_as
    else:
        qs, ans = info_qs, info_as
    total = len(qs)
    if total == 0:
        # Empty KB: avoid encode([]) and division by zero below.
        return pd.DataFrame([])
    # Queries are the questions with "#tag" suffixes removed.
    qs_clean = [re.sub(r"#.*", "", q).strip() for q in qs]
    k = min(3, total)  # BUG FIX: topk with k > corpus size raises
    records = []
    for m in selected_models:
        model = get_model(m)
        emb = model.encode(qs, convert_to_tensor=True)       # corpus embeddings
        test_emb = model.encode(qs_clean, convert_to_tensor=True)
        sims = util.cos_sim(test_emb, emb)                   # [N, N]
        top1 = torch.argmax(sims, dim=1).tolist()
        top3 = torch.topk(sims, k=k, dim=1).indices.tolist()
        c1 = c3 = 0
        for i in range(total):
            if ans[top1[i]] == ans[i]:
                c1 += 1
            if ans[i] in {ans[idx] for idx in top3[i]}:
                c3 += 1
        records.append({
            "๋ชจ๋ธ": m,
            "Topโ€‘1 ๋ง์ ์": c1,
            "Topโ€‘1 ์ ํ๋": f"{c1}/{total} ({c1/total:.2%})",
            "Topโ€‘3 ๋ง์ ์": c3,
            "Topโ€‘3 ์ ํ๋": f"{c3}/{total} ({c3/total:.2%})",
        })
    return pd.DataFrame(records)
# ----- Gradio UI -----
# Candidate embedding models offered in both tabs.
model_choices = [
    "sentence-transformers/LaBSE",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/bert-base-nli-mean-tokens",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
    "bert-base-uncased",
    "distilbert-base-multilingual-cased",  # deliberately weaker baseline
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("๐ฌ ์ฑ๋ด"):
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                video_player = gr.Video(value=CEO_VIDEO_FILE, autoplay=False, interactive=False, height=360)
                kb_type = gr.Radio(["์ฉ์ด", "์ ๋ณด"], value="์ ๋ณด", label="๊ฒ์ ์ ํ")
                model_name = gr.Dropdown(model_choices, value=model_choices[0], label="๋ชจ๋ธ ์ ํ")
                # BUG FIX: the placeholder literal was split across two physical
                # lines (a syntax error); rejoined into one literal.
                # NOTE(review): the text is mojibake like the rest of the file —
                # confirm the intended Korean wording.
                user_q = gr.Textbox(lines=2, placeholder="์ง๋ฌธ์ ์๋ ฅํ์ธ์")
                send = gr.Button("์ ์ก")
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(type="messages", height=360)
        # Button click and Enter-in-textbox trigger the same handler.
        send.click(chat_interface,
                   inputs=[user_q, chatbot, kb_type, model_name],
                   outputs=[chatbot, user_q, video_player])
        user_q.submit(chat_interface,
                      inputs=[user_q, chatbot, kb_type, model_name],
                      outputs=[chatbot, user_q, video_player])
    with gr.Tab("๐ ๋ชจ๋ธ ๋น๊ต"):
        cmp_type = gr.Radio(["์ฉ์ด", "์ ๋ณด"], value="์ฉ์ด", label="ํ๊ฐํ KB")
        cmp_models = gr.CheckboxGroup(model_choices, value=[model_choices[0]], label="๋น๊ตํ ๋ชจ๋ธ๋ค")
        run_cmp = gr.Button("๋น๊ต ์คํ")
        cmp_table = gr.DataFrame(interactive=False)
        run_cmp.click(compare_models,
                      inputs=[cmp_type, cmp_models],
                      outputs=[cmp_table])

if __name__ == "__main__":
    demo.launch()
|