Spaces:
Sleeping
Sleeping
File size: 6,772 Bytes
2e6852b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import os
import yaml
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import torch
import shutil
import tempfile
# File paths for the app's data and media assets.
KNOWLEDGE_FILE = "company_knowledge.md"  # markdown FAQ parsed by parse_knowledge_base (Q:/A: blocks)
PERSONA_FILE = "persona.yaml"            # persona config; supplies the 'unknown_answer' fallback
CHITCHAT_FILE = "chitchat.yaml"          # small-talk entries: keywords -> canned answer
KEYWORD_MAP_FILE = "keyword_map.yaml"    # keyword -> canonical knowledge-base question map
CEO_VIDEO_FILE = "ceo_video.mp4"         # CEO video re-copied and replayed on each answer
CEO_IMG_FILE = "ceo.jpg" # original note: "use when needed" — not referenced in this chunk
def load_yaml(file_path, default_data=None):
    """Load a YAML file, returning *default_data* when it cannot be read.

    Args:
        file_path: Path to the YAML document.
        default_data: Value returned on failure; an empty list when omitted.

    Returns:
        The parsed YAML document (may be ``None`` for an empty file), or the
        fallback value if the file is missing or malformed.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except (OSError, yaml.YAMLError):
        # Narrowed from a bare `except Exception`: only I/O and YAML parse
        # failures are expected here; anything else should surface as a bug
        # instead of being silently swallowed.
        return default_data if default_data is not None else []
def parse_knowledge_base(file_path):
    """Parse a markdown FAQ file of ``Q: .../A: ...`` blocks.

    Returns a list of ``{"question": ..., "answer": ...}`` dicts, or an empty
    list when the file does not exist.
    """
    import re

    if not os.path.exists(file_path):
        return []

    with open(file_path, encoding="utf-8") as f:
        text = f.read()

    # Each block: a Q: line, an A: body, terminated by a blank line before
    # the next Q: or by end of file.
    pattern = r"Q:\s*(.*?)\nA:\s*(.*?)(?=(\n{2,}Q:|\Z))"
    return [
        {"question": question.strip(), "answer": answer.strip()}
        for question, answer, _ in re.findall(pattern, text, re.DOTALL)
    ]
# Load data files at import time (each falls back to its default on failure).
persona = load_yaml(PERSONA_FILE, {})
chitchat_map = load_yaml(CHITCHAT_FILE, [])
keyword_map = load_yaml(KEYWORD_MAP_FILE, [])
knowledge_base = parse_knowledge_base(KNOWLEDGE_FILE)
kb_questions = [item['question'] for item in knowledge_base]
kb_answers = [item['answer'] for item in knowledge_base]
# Free multilingual embedding model, used for soft question matching.
model = SentenceTransformer('distilbert-base-multilingual-cased')
# Pre-compute question embeddings once; None when the knowledge base is empty.
if kb_questions:
    kb_embeddings = model.encode(kb_questions, convert_to_tensor=True)
else:
    kb_embeddings = None
def apply_strike(text, del_section="6000~6500๋ง์, ์ฑ๊ณผ๊ธ 1800~2400๋ง์"):
    """Wrap every occurrence of *del_section* in HTML strike-through tags.

    Only answers that actually contain the salary section are modified;
    all other text passes through unchanged.
    """
    if del_section not in text:
        return text
    return text.replace(del_section, f"<s>{del_section}</s>")
def find_chitchat(user_question):
    """Return a canned small-talk answer whose keywords appear in the question.

    Matching is case-insensitive on the question side; returns None when no
    chitchat entry matches.
    """
    lowered = user_question.lower()
    for entry in chitchat_map:
        for keyword in entry.get('keywords', []):
            if keyword in lowered:
                return entry['answer']
    return None
def map_user_question_to_knowledge(user_question):
    """Map a free-form question onto a canonical knowledge-base question.

    Scans the keyword map in order and returns the first entry whose keyword
    list hits the (lower-cased) question; None when nothing matches.
    """
    lowered = user_question.lower()
    for entry in keyword_map:
        if any(kw in lowered for kw in entry.get('keywords', [])):
            return entry['question']
    return None
def find_answer_by_question(q):
    """Return the stored answer for an exact question match, else None."""
    hits = (entry['answer'] for entry in knowledge_base if entry['question'] == q)
    return next(hits, None)
def find_answer_by_keywords(user_question):
    """Return the first knowledge-base answer whose keywords hit the question.

    NOTE(review): entries produced by parse_knowledge_base carry only
    'question'/'answer' keys, so `.get('keywords', [])` is always empty unless
    entries gain a 'keywords' field elsewhere — confirm the data actually
    provides keywords.
    """
    lowered = user_question.lower()
    for entry in knowledge_base:
        if any(kw in lowered for kw in entry.get('keywords', [])):
            return entry['answer']
    return None
def best_faq_answer(user_question):
    """Return the best answer for *user_question*.

    Resolution order:
      1. small-talk (chitchat) keyword match
      2. keyword-map lookup to a canonical knowledge-base question
      3. direct keyword match against knowledge-base entries
      4. embedding cosine-similarity soft match
      5. persona-configured fallback message

    Fixes vs. original: the fallback string literal was broken across two
    lines by extraction (syntax error) and has been rejoined; the two local
    keyword lists used identifiers containing replacement characters (invalid
    in Python identifiers) and are renamed to English locals.
    """
    uq = user_question.strip()
    if not uq:
        return "๋ฌด์์ด ๊ถ๊ธํ์ ์ง ๋ง์ํด ์ฃผ์ธ์!"
    # Small talk (greetings etc.) takes priority over FAQ lookup.
    chit = find_chitchat(uq)
    if chit:
        return chit
    # (1) keyword-map lookup first (keeps welfare vs. salary answers separate)
    mapped_q = map_user_question_to_knowledge(uq)
    if mapped_q:
        answer = find_answer_by_question(mapped_q)
        if answer:
            # Strike-through is applied only to salary-type answers.
            if "์ฐ๋ด" in mapped_q:
                return apply_strike(answer)
            return answer
    # (2) direct keyword matching against knowledge_base entries
    answer = find_answer_by_keywords(uq)
    if answer:
        return answer
    # (3) embedding-similarity soft matching
    if kb_embeddings is not None and len(kb_answers) > 0:
        q_emb = model.encode([uq], convert_to_tensor=True)
        scores = util.cos_sim(q_emb, kb_embeddings)[0]
        best_idx = int(torch.argmax(scores))
        best_question = kb_questions[best_idx]
        # When a welfare-flavored question accidentally soft-matches a salary
        # entry, prefer returning the matched answer without strike-through.
        welfare_words = ["๋ณต์ง", "ํด๊ฐ", "๊ต์ก", "ํ์ฌ", "๋ํธํ", "๋ณต๋ฆฌํ์", "์ ๋"]
        salary_words = ["์ฐ๋ด", "๊ธ์ฌ", "์๊ธ", "์๊ธ", "๋ณด์", "๋ด๊ธ", "์ฒ์ฐ"]
        if any(w in uq for w in welfare_words) and not any(w in best_question for w in salary_words):
            return kb_answers[best_idx]
        # Strike-through only on salary answers.
        if "์ฐ๋ด" in best_question or "๊ธ์ฌ" in best_question:
            return apply_strike(kb_answers[best_idx])
        return kb_answers[best_idx]
    # (4) fallback — default literal rejoined from the extraction-broken line.
    return persona.get('style', {}).get('unknown_answer', "์์ง ์ค๋น๋์ง ์์ ์ง๋ฌธ์๋๋ค. ๋ค๋ฅธ ์ง๋ฌธ๋ ํด์ฃผ์ธ์!")
def get_temp_video_copy():
    """Copy the CEO video into a fresh temporary .mp4 and return its path.

    Per the original note, a new file per question is what re-triggers
    autoplay in the Gradio video component — presumably it ignores updates
    with an unchanged path (confirm against the Gradio version in use).
    The file is intentionally not auto-deleted; Gradio serves it afterwards.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # we only need the path; copyfile reopens it
    shutil.copyfile(CEO_VIDEO_FILE, tmp_path)
    return tmp_path
def chat_interface(message, history):
    """Gradio event handler: answer *message* and restart the CEO video.

    Args:
        message: The user's question from the textbox.
        history: The chatbot's (user, bot) tuple list; mutated in place.

    Returns:
        (updated history, "" to clear the textbox, video component update).

    Fix vs. original: an inline comment was split across two lines by
    extraction, leaving a bare non-comment line (syntax error); it is
    rejoined here. The original note said: if the chat box can render
    HTML, keep the strike-through HTML in the answer.
    """
    bot_response = best_faq_answer(message)
    history.append((message, bot_response))
    # Fresh temp copy so the video component re-triggers autoplay.
    temp_video_path = get_temp_video_copy()
    return history, "", gr.update(value=temp_video_path, autoplay=True, interactive=False, elem_id="ceo-video-panel")
# UI layout: CEO video on the left, chatbot + input on the right.
# Fix vs. original: the second Examples string was split across two lines by
# extraction (unterminated literal, syntax error) and has been rejoined.
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Row(elem_id="main-row"):
        with gr.Column(scale=1, min_width=350):
            video_player = gr.Video(
                value=CEO_VIDEO_FILE,
                autoplay=False, loop=False, interactive=False,
                height=350, elem_id="ceo-video-panel"
            )
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="",
                height=350,
                elem_id="chatbot-box",
                show_copy_button=True
            )
            with gr.Row():
                msg_input = gr.Textbox(placeholder="๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์.", scale=4, show_label=False)
                send_btn = gr.Button("์ ์ก", scale=1, min_width=80)
            gr.Examples(
                examples=["๋ณต์ง ๋ญ ์์ด?", "ํด๊ฐ ์ ๋ ์ค๋ชํด์ค", "์ฐ๋ด ์๋ ค์ค", "๋ํธํ ํ์ฌ?", "์์ฌ์ ๊ณต?", "์ฃผ๋ ฅ์ ํ", "์กฐ์ง๋ฌธํ"],
                inputs=msg_input
            )

    # Wire both Enter-to-submit and the send button to the same handler.
    outputs_list = [chatbot, msg_input, video_player]
    msg_input.submit(chat_interface, [msg_input, chatbot], outputs_list)
    send_btn.click(chat_interface, [msg_input, chatbot], outputs_list)
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|