# project11 / testapp / app7.py
# (Hugging Face Spaces page residue preserved as comments — not Python code:
#  "yeongininin's picture", commit "Add test01 file", revision acb3057)
import os
import yaml
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import torch
import shutil
import tempfile
import re
import pandas as pd
# ----- ํŒŒ์ผ ๊ฒฝ๋กœ ์ƒ์ˆ˜ -----
GLOSSARY_FILE = "glossary.md"
INFO_FILE = "info.md"
PERSONA_FILE = "persona.yaml"
CHITCHAT_FILE = "chitchat.yaml"
CEO_VIDEO_FILE = "ceo_video.mp4"
# ----- ์œ ํ‹ธ ํ•จ์ˆ˜ -----
def load_yaml(file_path, default_data=None):
try:
with open(file_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
except:
return default_data if default_data is not None else []
def parse_knowledge_base(file_path):
    """Parse a markdown knowledge base made of "Q: ... / A: ..." blocks.

    Args:
        file_path: Path to the markdown file.

    Returns:
        list[dict]: [{"question": str, "answer": str}, ...] in file order;
        an empty list when the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    # Use a context manager so the handle is closed deterministically
    # (the original relied on GC to close an anonymous open()).
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    # A block is "Q: ...\nA: ..." ending at a blank line before the next "Q:"
    # or at end-of-input; DOTALL lets answers span multiple lines. The third
    # (lookahead) group is discarded.
    blocks = re.findall(r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))", content, re.DOTALL)
    return [{"question": q.strip(), "answer": a.strip()} for q, a, _ in blocks]
# ----- ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ -----
persona = load_yaml(PERSONA_FILE, {})
chitchat_map = load_yaml(CHITCHAT_FILE, [])
glossary_base = parse_knowledge_base(GLOSSARY_FILE)
info_base = parse_knowledge_base(INFO_FILE)
glossary_qs = [x["question"] for x in glossary_base]
glossary_as = [x["answer"] for x in glossary_base]
info_qs = [x["question"] for x in info_base]
info_as = [x["answer"] for x in info_base]
# ----- ์ฑ—๋ด‡ ๋กœ์ง (๋ณ€๊ฒฝ ์—†์Œ) -----
model_cache = {}
def get_model(name):
if name not in model_cache:
model_cache[name] = SentenceTransformer(name)
return model_cache[name]
_corpus_emb_cache = {}  # (model_name, kb_type) -> precomputed corpus embeddings

def best_faq_answer(user_question, kb_type, model_name):
    """Return the KB answer whose question best matches *user_question*.

    Args:
        user_question: The user's query string.
        kb_type: "용어" selects the glossary KB; anything else the info KB.
        model_name: SentenceTransformer model identifier.

    Returns:
        The best-matching answer string, or a fallback message when the
        selected knowledge base is empty.
    """
    model = get_model(model_name)
    if kb_type == "용어":
        kb_qs, kb_as = glossary_qs, glossary_as
    else:
        kb_qs, kb_as = info_qs, info_as
    if not kb_qs:
        # Empty KB: argmax over an empty score vector would raise.
        return "지식 베이스가 비어 있습니다."
    # The corpus is fixed after import, so encode each (model, KB) pair once
    # instead of re-embedding every question on every user message.
    cache_key = (model_name, kb_type)
    if cache_key not in _corpus_emb_cache:
        _corpus_emb_cache[cache_key] = model.encode(kb_qs, convert_to_tensor=True)
    emb = _corpus_emb_cache[cache_key]
    q_emb = model.encode([user_question], convert_to_tensor=True)
    scores = util.cos_sim(q_emb, emb)[0]
    return kb_as[int(torch.argmax(scores))]
def find_chitchat(uq, chat_map=None):
    """Return the canned small-talk answer whose keywords appear in *uq*.

    Args:
        uq: User query string (matched case-insensitively).
        chat_map: Optional list of {keywords, answer} dicts; defaults to the
            module-level ``chitchat_map`` (backward-compatible).

    Returns:
        The matching answer string, or None when no entry matches.
    """
    if chat_map is None:
        chat_map = chitchat_map
    uq_lower = uq.lower()  # hoisted out of the loop
    for chat in chat_map:
        # BUG FIX: keywords are lowercased too — previously an uppercase
        # keyword could never match the lowercased query.
        if any(kw.lower() in uq_lower for kw in chat.get("keywords", [])):
            return chat.get("answer")
    return None
def chat_interface(message, history, kb_type, model_name):
    """Handle one chat turn: answer the message and restart the CEO video.

    Args:
        message: Raw user input from the textbox.
        history: Chat history in Gradio "messages" format (list of role dicts).
        kb_type: Knowledge-base selector ("용어" or "정보").
        model_name: SentenceTransformer model used for FAQ retrieval.

    Returns:
        (updated history, cleared textbox value, video component update)
    """
    history = history or []
    if not message.strip():
        # BUG FIX: the UI wires three outputs, so the empty-message path must
        # also return three values (gr.update() leaves the video untouched).
        return history, "", gr.update()
    # Canned small-talk wins; otherwise fall back to semantic FAQ retrieval.
    if chit := find_chitchat(message):
        resp = chit
    else:
        resp = best_faq_answer(message, kb_type, model_name)
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": resp})
    # Serve a fresh copy of the video each turn so autoplay retriggers.
    # NOTE(review): delete=False leaks one temp file per message — consider
    # removing the previous copy; kept as-is since Gradio must still serve it.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()  # BUG FIX: release the open handle so copyfile can write the path on all platforms
    shutil.copyfile(CEO_VIDEO_FILE, tmp.name)
    video = gr.Video(value=tmp.name, autoplay=True, interactive=False)
    return history, "", video
# ----- ๋ชจ๋ธ ๋น„๊ต ํ‰๊ฐ€ ํ•จ์ˆ˜ -----
def compare_models(kb_type, selected_models):
# ์งˆ๋ฌธ/์ •๋‹ต ์…‹
if kb_type=="์šฉ์–ด":
qs, ans = glossary_qs, glossary_as
else:
qs, ans = info_qs, info_as
# ํƒœ๊ทธ ์ œ๊ฑฐ
qs_clean = [re.sub(r"#.*","",q).strip() for q in qs]
records = []
total = len(qs)
# ๊ฐ ๋ชจ๋ธ๋งˆ๋‹ค
for m in selected_models:
model = get_model(m)
emb = model.encode(qs, convert_to_tensor=True) # corpus ์ž„๋ฒ ๋”ฉ
test_emb = model.encode(qs_clean, convert_to_tensor=True)
sims = util.cos_sim(test_emb, emb) # [N,N]
top1 = torch.argmax(sims, dim=1).tolist()
top3 = torch.topk(sims, k=3, dim=1).indices.tolist()
c1=c3=0
for i in range(total):
if ans[top1[i]]==ans[i]: c1+=1
if ans[i] in {ans[idx] for idx in top3[i]}: c3+=1
records.append({
"๋ชจ๋ธ": m,
"Topโ€‘1 ๋งž์€ ์ˆ˜": c1,
"Topโ€‘1 ์ •ํ™•๋„": f"{c1}/{total} ({c1/total:.2%})",
"Topโ€‘3 ๋งž์€ ์ˆ˜": c3,
"Topโ€‘3 ์ •ํ™•๋„": f"{c3}/{total} ({c3/total:.2%})",
})
return pd.DataFrame(records)
# ----- Gradio UI -----
# Candidate embedding models offered in both the chat tab and the
# comparison tab. Mostly multilingual sentence encoders.
model_choices = [
    "sentence-transformers/LaBSE",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/bert-base-nli-mean-tokens",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
    "bert-base-uncased",
    "distilbert-base-multilingual-cased"  # included as an example of a weaker performer
]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Chat tab: video + controls on the left, transcript on the right ---
    with gr.Tab("💬 챗봇"):
        with gr.Row():
            with gr.Column(scale=1, min_width=400):
                # CEO video; chat_interface swaps in an autoplaying copy per reply.
                video_player = gr.Video(value=CEO_VIDEO_FILE, autoplay=False, interactive=False, height=360)
                kb_type = gr.Radio(["용어","정보"], value="정보", label="검색 유형")
                model_name = gr.Dropdown(model_choices, value=model_choices[0], label="모델 선택")
                user_q = gr.Textbox(lines=2, placeholder="질문을 입력하세요")
                send = gr.Button("전송")
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(type="messages", height=360)
        # Button click and textbox Enter both route to the same handler.
        send.click(chat_interface,
                   inputs=[user_q, chatbot, kb_type, model_name],
                   outputs=[chatbot, user_q, video_player])
        user_q.submit(chat_interface,
                      inputs=[user_q, chatbot, kb_type, model_name],
                      outputs=[chatbot, user_q, video_player])
    # --- Model comparison tab: batch-evaluates selected models on one KB ---
    with gr.Tab("🛠 모델 비교"):
        cmp_type = gr.Radio(["용어","정보"], value="용어", label="평가할 KB")
        cmp_models = gr.CheckboxGroup(model_choices, value=[model_choices[0]], label="비교할 모델들")
        run_cmp = gr.Button("비교 실행")
        cmp_table = gr.DataFrame(interactive=False)
        run_cmp.click(compare_models,
                      inputs=[cmp_type, cmp_models],
                      outputs=[cmp_table])

# Launch the app only when executed as a script (not on import).
if __name__=="__main__":
    demo.launch()