# NOTE: "Spaces: Sleeping" is a Hugging Face Spaces status banner captured
# during page extraction — it is not part of the program.
import os
import re
import shutil
import tempfile

import gradio as gr
import pandas as pd
import torch
import yaml
from sentence_transformers import SentenceTransformer, util
# ----- File path constants -----
GLOSSARY_FILE = "glossary.md"     # Q/A knowledge base: terminology entries
INFO_FILE = "info.md"             # Q/A knowledge base: general information
PERSONA_FILE = "persona.yaml"     # persona config (loaded below; schema not visible here)
CHITCHAT_FILE = "chitchat.yaml"   # small-talk entries: {"keywords": [...], "answer": ...}
CEO_VIDEO_FILE = "ceo_video.mp4"  # video shown next to the chatbot; re-served each turn
# ----- Utility functions -----
def load_yaml(file_path, default_data=None):
    """Load a YAML file, falling back to *default_data* on failure.

    Args:
        file_path: Path of the YAML file to read.
        default_data: Value returned when the file is missing, unreadable,
            malformed, or empty. Defaults to an empty list.

    Returns:
        The parsed YAML document, or the fallback value.
    """
    fallback = default_data if default_data is not None else []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    # Narrowed from a bare `except:` so programming errors, SystemExit and
    # KeyboardInterrupt are no longer silently swallowed.
    except (OSError, yaml.YAMLError):
        return fallback
    # An empty YAML file parses to None; the callers expect a list/dict,
    # so treat "no data" the same as a failed load.
    return data if data is not None else fallback
def parse_knowledge_base(file_path):
    """Parse a markdown FAQ file made of "Q: ...\\nA: ..." blocks.

    Blocks are separated by one or more blank lines. Returns a list of
    {"question": ..., "answer": ...} dicts (stripped of surrounding
    whitespace), or an empty list when the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    # `with` guarantees the handle is closed; the original leaked it by
    # calling open(...).read() without ever closing the file.
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    # Lazily capture the answer up to the next blank-line-separated "Q:" or EOF.
    blocks = re.findall(r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))", content, re.DOTALL)
    return [{"question": q.strip(), "answer": a.strip()} for q, a, _ in blocks]
# ----- Data loading (module-level side effects: reads data files once at import) -----
persona = load_yaml(PERSONA_FILE, {})        # persona config; not referenced elsewhere in this file
chitchat_map = load_yaml(CHITCHAT_FILE, [])  # entries with "keywords" list + "answer" (see find_chitchat)
glossary_base = parse_knowledge_base(GLOSSARY_FILE)
info_base = parse_knowledge_base(INFO_FILE)
# Parallel question/answer lists: index i of *_qs corresponds to index i of *_as.
glossary_qs = [x["question"] for x in glossary_base]
glossary_as = [x["answer"] for x in glossary_base]
info_qs = [x["question"] for x in info_base]
info_as = [x["answer"] for x in info_base]
# ----- Chatbot logic -----
model_cache = {}  # model name -> loaded SentenceTransformer (one instance per name)

def get_model(name):
    """Return a cached SentenceTransformer for *name*, loading it lazily on first use."""
    try:
        return model_cache[name]
    except KeyError:
        instance = model_cache[name] = SentenceTransformer(name)
        return instance
# Corpus-embedding cache keyed by (model_name, kb_type). The KB lists are
# built once at import time, so the embeddings never go stale.
_corpus_emb_cache = {}

def best_faq_answer(user_question, kb_type, model_name):
    """Return the answer whose stored question best matches *user_question*.

    Args:
        user_question: Raw user input.
        kb_type: The glossary type string selects the glossary KB; any other
            value selects the info KB.
        model_name: SentenceTransformer model id, resolved via get_model().

    Returns:
        The answer paired with the highest-cosine-similarity question, or a
        fallback message when the selected knowledge base is empty.
    """
    if kb_type == "์ฉ์ด":
        kb_qs, kb_as = glossary_qs, glossary_as
    else:
        kb_qs, kb_as = info_qs, info_as
    if not kb_qs:
        # The original called torch.argmax on an empty score tensor and crashed.
        return "(no FAQ data available)"
    model = get_model(model_name)
    # Encode the KB once per (model, KB) instead of on every single question.
    cache_key = (model_name, kb_type)
    if cache_key not in _corpus_emb_cache:
        _corpus_emb_cache[cache_key] = model.encode(kb_qs, convert_to_tensor=True)
    corpus_emb = _corpus_emb_cache[cache_key]
    q_emb = model.encode([user_question], convert_to_tensor=True)
    scores = util.cos_sim(q_emb, corpus_emb)[0]
    return kb_as[int(torch.argmax(scores))]
def find_chitchat(uq): | |
for chat in chitchat_map: | |
if any(kw in uq.lower() for kw in chat.get("keywords",[])): | |
return chat["answer"] | |
return None | |
def chat_interface(message, history, kb_type, model_name):
    """Handle one chat turn.

    Tries the small-talk map first, then falls back to FAQ retrieval, appends
    the user/assistant messages to the chat history, and refreshes the video.

    Returns:
        (updated history, cleared textbox value, video component/update) —
        matching the three outputs bound in the UI handlers.
    """
    if not message.strip():
        # BUG FIX: the original returned only two values here, but both
        # Gradio handlers bind three outputs — send a no-op video update too.
        return history or [], "", gr.update()
    # Falsy chit-chat result (None or empty answer) falls through to the FAQ,
    # exactly as the original walrus-truthiness check did.
    resp = find_chitchat(message) or best_faq_answer(message, kb_type, model_name)
    history = history or []
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": resp})
    # A fresh temp copy makes Gradio reload the video each turn (presumably so
    # autoplay re-triggers — confirm before removing the copy entirely).
    # Unlike the original, the previous turn's copy is deleted so temp files
    # no longer pile up for the lifetime of the process.
    prev = getattr(chat_interface, "_last_video_copy", None)
    if prev:
        try:
            os.remove(prev)
        except OSError:
            pass  # best-effort cleanup; never fail the chat turn over it
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()  # only the path is needed; close the handle promptly
    shutil.copyfile(CEO_VIDEO_FILE, tmp.name)
    chat_interface._last_video_copy = tmp.name
    video = gr.Video(value=tmp.name, autoplay=True, interactive=False)
    return history, "", video
# ----- Model comparison / evaluation -----
def compare_models(kb_type, selected_models):
    """Self-retrieval evaluation of embedding models over one knowledge base.

    For each selected model: embed the KB questions as the corpus, embed the
    same questions with their '#...' tag suffixes stripped as queries, and
    count how often a query's top-1 / top-3 retrieved questions carry the
    correct answer.

    Args:
        kb_type: The glossary type string selects the glossary KB; any other
            value selects the info KB.
        selected_models: List of SentenceTransformer model ids.

    Returns:
        A pandas DataFrame with one row per model (hit counts + accuracy),
        empty when the selected KB has no entries.
    """
    if kb_type == "์ฉ์ด":
        qs, ans = glossary_qs, glossary_as
    else:
        qs, ans = info_qs, info_as
    total = len(qs)
    if total == 0:
        # Avoid encode([]) and division by zero below.
        return pd.DataFrame([])
    # Strip '#'-tag suffixes so the queries differ slightly from the corpus.
    qs_clean = [re.sub(r"#.*", "", q).strip() for q in qs]
    # BUG FIX: torch.topk raises when k exceeds the corpus size (KB < 3 rows).
    k = min(3, total)
    records = []
    for m in selected_models:
        model = get_model(m)
        corpus_emb = model.encode(qs, convert_to_tensor=True)
        query_emb = model.encode(qs_clean, convert_to_tensor=True)
        sims = util.cos_sim(query_emb, corpus_emb)  # shape [N, N]
        top1 = torch.argmax(sims, dim=1).tolist()
        topk_idx = torch.topk(sims, k=k, dim=1).indices.tolist()
        c1 = c3 = 0
        for i in range(total):
            # Score by answer equality so duplicate questions sharing an
            # answer still count as hits.
            if ans[top1[i]] == ans[i]:
                c1 += 1
            if ans[i] in {ans[idx] for idx in topk_idx[i]}:
                c3 += 1
        records.append({
            "๋ชจ๋ธ": m,
            "Topโ1 ๋ง์ ์": c1,
            "Topโ1 ์ ํ๋": f"{c1}/{total} ({c1/total:.2%})",
            "Topโ3 ๋ง์ ์": c3,
            "Topโ3 ์ ํ๋": f"{c3}/{total} ({c3/total:.2%})",
        })
    return pd.DataFrame(records)
# ----- Gradio UI -----
# Embedding models offered in both tabs: multilingual sentence encoders plus
# plain BERT baselines for contrast in the comparison tab.
model_choices = [
    "sentence-transformers/LaBSE",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/bert-base-nli-mean-tokens",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
    "bert-base-uncased",
    "distilbert-base-multilingual-cased"  # intentionally weaker baseline
]
# Two-tab app wiring. The Blocks context is built at import time (Gradio
# convention); demo.launch() below runs only when executed as a script.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Tab 1: chat — left column holds the video + controls, right the chat log.
    with gr.Tab("๐ฌ ์ฑ๋ด"):
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                # Starts paused; chat_interface swaps in an autoplaying copy each turn.
                video_player = gr.Video(value=CEO_VIDEO_FILE, autoplay=False, interactive=False, height=360)
                kb_type = gr.Radio(["์ฉ์ด","์ ๋ณด"], value="์ ๋ณด", label="๊ฒ์ ์ ํ")
                model_name = gr.Dropdown(model_choices, value=model_choices[0], label="๋ชจ๋ธ ์ ํ")
                user_q = gr.Textbox(lines=2, placeholder="์ง๋ฌธ์ ์ ๋ ฅํ์ธ์")
                send = gr.Button("์ ์ก")
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(type="messages", height=360)
        # Both the send button and Enter in the textbox trigger the same handler
        # with identical input/output bindings.
        send.click(chat_interface,
                   inputs=[user_q, chatbot, kb_type, model_name],
                   outputs=[chatbot, user_q, video_player])
        user_q.submit(chat_interface,
                      inputs=[user_q, chatbot, kb_type, model_name],
                      outputs=[chatbot, user_q, video_player])
    # Tab 2: model comparison — runs compare_models over the chosen KB.
    with gr.Tab("๐ ๋ชจ๋ธ ๋น๊ต"):
        cmp_type = gr.Radio(["์ฉ์ด","์ ๋ณด"], value="์ฉ์ด", label="ํ๊ฐํ KB")
        cmp_models = gr.CheckboxGroup(model_choices, value=[model_choices[0]], label="๋น๊ตํ ๋ชจ๋ธ๋ค")
        run_cmp = gr.Button("๋น๊ต ์คํ")
        cmp_table = gr.DataFrame(interactive=False)
        run_cmp.click(compare_models,
                      inputs=[cmp_type, cmp_models],
                      outputs=[cmp_table])

# Standard script entry guard.
if __name__=="__main__":
    demo.launch()