# NOTE: "Spaces / Sleeping" below was Hugging Face Spaces page status text
# captured during extraction; it is not part of the program. Kept as a
# comment so the file remains valid Python.
# Spaces: Sleeping
import os | |
import yaml | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer, util | |
import torch | |
import shutil | |
import tempfile | |
# ํ์ผ ๊ฒฝ๋ก | |
KNOWLEDGE_FILE = "company_knowledge.md" | |
PERSONA_FILE = "persona.yaml" | |
CHITCHAT_FILE = "chitchat.yaml" | |
KEYWORD_MAP_FILE = "keyword_map.yaml" | |
CEO_VIDEO_FILE = "ceo_video.mp4" | |
CEO_IMG_FILE = "ceo.jpg" # ํ์์ ์ฌ์ฉ | |
def load_yaml(file_path, default_data=None):
    # Load a YAML document; on any failure return default_data (or []).
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except Exception:
        # NOTE(review): broad catch also hides bugs; narrow if possible.
        return default_data if default_data is not None else []
def parse_knowledge_base(file_path): | |
import re | |
faqs = [] | |
if not os.path.exists(file_path): | |
return [] | |
with open(file_path, encoding="utf-8") as f: | |
content = f.read() | |
# Q:\s*(...) \nA:\s*(...)\n{2,} ๋๋ ๋ | |
blocks = re.findall(r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))", content, re.DOTALL) | |
for q, a, _ in blocks: | |
faqs.append({"question": q.strip(), "answer": a.strip()}) | |
return faqs | |
# ๋ฐ์ดํฐ ๋ก๋ | |
persona = load_yaml(PERSONA_FILE, {}) | |
chitchat_map = load_yaml(CHITCHAT_FILE, []) | |
keyword_map = load_yaml(KEYWORD_MAP_FILE, []) | |
knowledge_base = parse_knowledge_base(KNOWLEDGE_FILE) | |
kb_questions = [item['question'] for item in knowledge_base] | |
kb_answers = [item['answer'] for item in knowledge_base] | |
# ๋ฌด๋ฃ ์๋ฒ ๋ฉ ๋ชจ๋ธ | |
model = SentenceTransformer('distilbert-base-multilingual-cased') | |
if kb_questions: | |
kb_embeddings = model.encode(kb_questions, convert_to_tensor=True) | |
else: | |
kb_embeddings = None | |
# ์ญ์ ์ (์ทจ์์ ) ์ ์ฉ ํจ์ | |
def apply_strike(text, del_section="6000~6500๋ง์, ์ฑ๊ณผ๊ธ 1800~2400๋ง์"): | |
# ๊ธ์ฌ ์ ๋ณด๊ฐ ํฌํจ๋ ๋ต๋ณ์ผ ๋๋ง strike-through | |
if del_section in text: | |
return text.replace( | |
del_section, | |
f"<s>{del_section}</s>" | |
) | |
return text | |
# Chitchat(์ธ์ฌ ๋ฑ) ๋งค์นญ | |
def find_chitchat(user_question): | |
uq = user_question.lower() | |
for chat in chitchat_map: | |
if any(kw in uq for kw in chat.get('keywords', [])): | |
return chat['answer'] | |
return None | |
# ํค์๋ ๊ธฐ๋ฐ Q ๋งคํ (๋ณต์ง: ํด๊ฐ ์ ๋, ๊ต์ก, ๋ณต๋ฆฌํ์ ๋ฑ ๊ฐํ) | |
def map_user_question_to_knowledge(user_question): | |
uq = user_question.lower() | |
for item in keyword_map: | |
for kw in item.get('keywords', []): | |
if kw in uq: | |
return item['question'] | |
return None | |
def find_answer_by_question(q): | |
for item in knowledge_base: | |
if item['question'] == q: | |
return item['answer'] | |
return None | |
def find_answer_by_keywords(user_question): | |
uq = user_question.lower() | |
for item in knowledge_base: | |
for kw in item.get('keywords', []): | |
if kw in uq: | |
return item['answer'] | |
return None | |
def best_faq_answer(user_question): | |
uq = user_question.strip() | |
if not uq: | |
return "๋ฌด์์ด ๊ถ๊ธํ์ ์ง ๋ง์ํด ์ฃผ์ธ์!" | |
chit = find_chitchat(uq) | |
if chit: | |
return chit | |
# (1) ํค์๋๋งต ์ฐ์ ๋งคํ (๋ณต์ง/๊ธ์ฌ ๊ฐ๊ฐ ๋ถ๋ฆฌ) | |
mapped_q = map_user_question_to_knowledge(uq) | |
if mapped_q: | |
answer = find_answer_by_question(mapped_q) | |
if answer: | |
# ๋ณต์ง ๋ถ์ผ: '์ฐ๋ด ์์ค' ๋ต๋ณ ์๋ ๊ฒฝ์ฐ์๋ ์ญ์ ์ ์์ | |
if "์ฐ๋ด" in mapped_q: | |
return apply_strike(answer) | |
return answer | |
# (2) knowledge_base ์ง์ ํค์๋ ๋งค์นญ (๋ณต์ง ๊ด๋ จ ํค์๋ ๊ฐํ๋์ด์ผ ํจ!) | |
answer = find_answer_by_keywords(uq) | |
if answer: | |
return answer | |
# (3) ์๋ฒ ๋ฉ ์ ์ฌ๋ ๊ธฐ๋ฐ soft-matching | |
if kb_embeddings is not None and len(kb_answers) > 0: | |
q_emb = model.encode([uq], convert_to_tensor=True) | |
scores = util.cos_sim(q_emb, kb_embeddings)[0] | |
best_idx = int(torch.argmax(scores)) | |
best_question = kb_questions[best_idx] | |
# ๋ณต์ง์ง๋ฌธ์ธ๋ฐ ์ฐ๋ดํค์๋ ๋งค์นญ๋๋ ๊ฒฝ์ฐ, ๋ณต์ง ์ฐ์ ๋ต๋ณ์ ์ ํํ๋๋ก | |
# ์๋ if์์ ์ค์ ๋ณต์ง ๋ต๋ณ ์ฐ์ ์ฝ๋ | |
๋ณต์ง๊ฐ๋ฅ = ["๋ณต์ง", "ํด๊ฐ", "๊ต์ก", "ํ์ฌ", "๋ํธํ", "๋ณต๋ฆฌํ์", "์ ๋"] | |
์ฐ๋ด๊ฐ๋ฅ = ["์ฐ๋ด", "๊ธ์ฌ", "์๊ธ", "์๊ธ", "๋ณด์", "๋ด๊ธ", "์ฒ์ฐ"] | |
if any(w in uq for w in ๋ณต์ง๊ฐ๋ฅ) and not any(w in best_question for w in ์ฐ๋ด๊ฐ๋ฅ): | |
return kb_answers[best_idx] | |
# ์ญ์ ์ ์ ์ฐ๋ด ๋ต๋ณ์๋ง | |
if "์ฐ๋ด" in best_question or "๊ธ์ฌ" in best_question: | |
return apply_strike(kb_answers[best_idx]) | |
return kb_answers[best_idx] | |
# (4) fallback | |
return persona.get('style', {}).get('unknown_answer', "์์ง ์ค๋น๋์ง ์์ ์ง๋ฌธ์ ๋๋ค. ๋ค๋ฅธ ์ง๋ฌธ๋ ํด์ฃผ์ธ์!") | |
# ์ง๋ฌธ ๋ฐ์ ๋๋ง๋ค CEO ์์ ๋ณต์ฌ๋ณธ ์์ํ์ผ๋ก ์์ฑ โ autoplay ํ์ค | |
def get_temp_video_copy(): | |
temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) | |
temp_filepath = temp_file.name | |
temp_file.close() | |
shutil.copyfile(CEO_VIDEO_FILE, temp_filepath) | |
return temp_filepath | |
def chat_interface(message, history): | |
bot_response = best_faq_answer(message) | |
history.append((message, bot_response)) | |
temp_video_path = get_temp_video_copy() | |
# ํ ์คํธ์ html๊ฐ๋ฅํ๋ฉด answer์ html์ญ์ ์ ์ ์ง | |
return history, "", gr.update(value=temp_video_path, autoplay=True, interactive=False, elem_id="ceo-video-panel") | |
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo: | |
with gr.Row(elem_id="main-row"): | |
with gr.Column(scale=1, min_width=350): | |
video_player = gr.Video( | |
value=CEO_VIDEO_FILE, | |
autoplay=False, loop=False, interactive=False, | |
height=350, elem_id="ceo-video-panel" | |
) | |
with gr.Column(scale=2): | |
chatbot = gr.Chatbot( | |
label="", | |
height=350, | |
elem_id="chatbot-box", | |
show_copy_button=True | |
) | |
with gr.Row(): | |
msg_input = gr.Textbox(placeholder="๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์.", scale=4, show_label=False) | |
send_btn = gr.Button("์ ์ก", scale=1, min_width=80) | |
gr.Examples( | |
examples=["๋ณต์ง ๋ญ ์์ด?", "ํด๊ฐ ์ ๋ ์ค๋ช ํด์ค", "์ฐ๋ด ์๋ ค์ค", "๋ํธํ ํ์ฌ?", "์์ฌ์ ๊ณต?", "์ฃผ๋ ฅ์ ํ", "์กฐ์ง๋ฌธํ"], | |
inputs=msg_input | |
) | |
# ์ฐ๊ฒฐ | |
outputs_list = [chatbot, msg_input, video_player] | |
msg_input.submit(chat_interface, [msg_input, chatbot], outputs_list) | |
send_btn.click(chat_interface, [msg_input, chatbot], outputs_list) | |
if __name__ == "__main__": | |
demo.launch() | |