"""Gradio FAQ chatbot with embedding-based retrieval plus a model-comparison tab.

Knowledge bases are plain-text files of "Q: ... / A: ..." blocks. A user
question is matched to the closest stored question by sentence-transformer
embeddings and cosine similarity, and the paired answer is returned.
"""

import os
import re
import shutil
import tempfile

import gradio as gr
import pandas as pd
import torch
import yaml
from sentence_transformers import SentenceTransformer, util

# ----- File path constants -----
GLOSSARY_FILE = "glossary.md"
INFO_FILE = "info.md"
PERSONA_FILE = "persona.yaml"
CHITCHAT_FILE = "chitchat.yaml"
CEO_VIDEO_FILE = "ceo_video.mp4"


# ----- Utility functions -----
def load_yaml(file_path, default_data=None):
    """Load a YAML file; return ``default_data`` (or ``[]``) if it cannot be read.

    Only file-system and YAML-parse errors are swallowed — a bare ``except``
    here would also hide KeyboardInterrupt/SystemExit.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        return default_data if default_data is not None else []


def parse_knowledge_base(file_path):
    """Parse a "Q: ...\\nA: ..." text file into ``{"question", "answer"}`` dicts.

    Each block is a question line followed by an answer that runs until the
    next blank-line-separated "Q:" block or end of file. Returns ``[]`` when
    the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    blocks = re.findall(
        r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))", content, re.DOTALL
    )
    return [{"question": q.strip(), "answer": a.strip()} for q, a, _ in blocks]


# ----- Data loading -----
persona = load_yaml(PERSONA_FILE, {})
chitchat_map = load_yaml(CHITCHAT_FILE, [])
glossary_base = parse_knowledge_base(GLOSSARY_FILE)
info_base = parse_knowledge_base(INFO_FILE)
glossary_qs = [x["question"] for x in glossary_base]
glossary_as = [x["answer"] for x in glossary_base]
info_qs = [x["question"] for x in info_base]
info_as = [x["answer"] for x in info_base]

# ----- Chatbot logic -----
model_cache = {}  # model name -> loaded SentenceTransformer instance


def get_model(name):
    """Return a cached SentenceTransformer, loading it on first use."""
    if name not in model_cache:
        model_cache[name] = SentenceTransformer(name)
    return model_cache[name]


def best_faq_answer(user_question, kb_type, model_name):
    """Return the KB answer whose stored question best matches ``user_question``.

    ``kb_type`` selects the glossary KB when it equals "용어", otherwise the
    info KB. ``model_name`` picks the embedding model. Returns a fallback
    message when the selected KB is empty (``torch.argmax`` on an empty
    similarity tensor would raise).
    """
    model = get_model(model_name)
    if kb_type == "용어":
        kb_qs, kb_as = glossary_qs, glossary_as
    else:
        kb_qs, kb_as = info_qs, info_as
    if not kb_qs:
        return "관련된 답변을 찾을 수 없습니다."
    emb = model.encode(kb_qs, convert_to_tensor=True)
    q_emb = model.encode([user_question], convert_to_tensor=True)
    scores = util.cos_sim(q_emb, emb)[0]
    return kb_as[int(torch.argmax(scores))]


def find_chitchat(uq):
    """Return the canned chit-chat answer whose keywords appear in ``uq``, else None."""
    for chat in chitchat_map:
        if any(kw in uq.lower() for kw in chat.get("keywords", [])):
            return chat["answer"]
    return None


def chat_interface(message, history, kb_type, model_name):
    """Handle one chat turn.

    Appends the user/assistant messages to ``history`` and returns
    ``(history, cleared_textbox, video)`` — three values, matching the three
    Gradio outputs wired below (the original early return produced only two,
    which breaks the output binding).
    """
    history = history or []
    if not message.strip():
        # gr.update() leaves the video player unchanged for blank input.
        return history, "", gr.update()
    if chit := find_chitchat(message):
        resp = chit
    else:
        resp = best_faq_answer(message, kb_type, model_name)
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": resp})
    # Serve a fresh copy of the video each turn so autoplay re-triggers.
    # Close the handle before copying over its name — required on Windows
    # and avoids an fd leak elsewhere.
    # NOTE(review): delete=False leaks one temp file per message; consider
    # removing stale copies periodically.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    shutil.copyfile(CEO_VIDEO_FILE, tmp.name)
    video = gr.Video(value=tmp.name, autoplay=True, interactive=False)
    return history, "", video


# ----- Model comparison / evaluation -----
def compare_models(kb_type, selected_models):
    """Evaluate each selected embedding model against the chosen KB.

    Uses the KB questions with trailing "#..." tags stripped as test queries
    and checks whether retrieval recovers the matching answer in the top-1
    and top-3 results. Returns a pandas DataFrame of per-model accuracy.
    """
    if kb_type == "용어":
        qs, ans = glossary_qs, glossary_as
    else:
        qs, ans = info_qs, info_as
    total = len(qs)
    if total == 0:
        # Empty KB: avoid ZeroDivisionError in the accuracy formatting below.
        return pd.DataFrame()
    # Strip "#tag" suffixes so queries differ slightly from the corpus.
    qs_clean = [re.sub(r"#.*", "", q).strip() for q in qs]
    records = []
    for m in selected_models:
        model = get_model(m)
        emb = model.encode(qs, convert_to_tensor=True)  # corpus embeddings
        test_emb = model.encode(qs_clean, convert_to_tensor=True)
        sims = util.cos_sim(test_emb, emb)  # [N, N] similarity matrix
        top1 = torch.argmax(sims, dim=1).tolist()
        top3 = torch.topk(sims, k=3, dim=1).indices.tolist()
        c1 = c3 = 0
        for i in range(total):
            # Compare answers (not indices) so duplicate answers still count.
            if ans[top1[i]] == ans[i]:
                c1 += 1
            if ans[i] in {ans[idx] for idx in top3[i]}:
                c3 += 1
        records.append({
            "모델": m,
            "Top‑1 맞은 수": c1,
            "Top‑1 정확도": f"{c1}/{total} ({c1/total:.2%})",
            "Top‑3 맞은 수": c3,
            "Top‑3 정확도": f"{c3}/{total} ({c3/total:.2%})",
        })
    return pd.DataFrame(records)


# ----- Gradio UI -----
model_choices = [
    "sentence-transformers/LaBSE",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/bert-base-nli-mean-tokens",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
    "bert-base-uncased",
    "distilbert-base-multilingual-cased",  # included as an expectedly weaker baseline
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("💬 챗봇"):
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                video_player = gr.Video(
                    value=CEO_VIDEO_FILE, autoplay=False,
                    interactive=False, height=360,
                )
                kb_type = gr.Radio(["용어", "정보"], value="정보", label="검색 유형")
                model_name = gr.Dropdown(
                    model_choices, value=model_choices[0], label="모델 선택"
                )
                user_q = gr.Textbox(lines=2, placeholder="질문을 입력하세요")
                send = gr.Button("전송")
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(type="messages", height=360)
        # Both the button click and textbox Enter submit a chat turn.
        send.click(
            chat_interface,
            inputs=[user_q, chatbot, kb_type, model_name],
            outputs=[chatbot, user_q, video_player],
        )
        user_q.submit(
            chat_interface,
            inputs=[user_q, chatbot, kb_type, model_name],
            outputs=[chatbot, user_q, video_player],
        )
    with gr.Tab("🛠 모델 비교"):
        cmp_type = gr.Radio(["용어", "정보"], value="용어", label="평가할 KB")
        cmp_models = gr.CheckboxGroup(
            model_choices, value=[model_choices[0]], label="비교할 모델들"
        )
        run_cmp = gr.Button("비교 실행")
        cmp_table = gr.DataFrame(interactive=False)
        run_cmp.click(compare_models, inputs=[cmp_type, cmp_models], outputs=[cmp_table])

if __name__ == "__main__":
    demo.launch()