# KH_media_chatbot / app1.py
# jonghhhh's picture
# Rename app.py to app1.py
# 0c83479 verified
import os
import chromadb
from sentence_transformers import SentenceTransformer
from google import genai
import gradio as gr
# === Environment configuration ===
# Directory of the persistent Chroma database. The CHROMA_DB_DIR environment
# variable overrides the default "chromadb_KH_media" folder under the current
# working directory.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
# Write the resolved path back so the environment variable and DB_DIR agree.
os.environ["CHROMA_DB_DIR"] = DB_DIR
# SECURITY: the API key is read from the environment only (e.g. a deployment
# secret) and must never be committed to source control. A key used to be
# embedded here as a fallback default; treat that key as leaked and revoke it.
# API_KEY is None when GOOGLE_API_KEY is not set.
API_KEY = os.getenv("GOOGLE_API_KEY")
# === Simple RAG system ===
class SimpleRAGSystem:
    """Look up stored documents in a Chroma collection for a text query.

    Instance attributes set by ``__init__``:
        encoder: the SentenceTransformer model used to embed queries.
        client: the persistent Chroma client.
        collection: the Chroma collection that is queried.
        available: True when the collection reports at least one item.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Load the embedding model and open the Chroma collection.

        Args:
            db_path: Directory of the Chroma database; a falsy value selects
                the module-level ``DB_DIR``.
            collection_name: Name of the collection to open.
        """
        storage_path = db_path if db_path else DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=storage_path)
        # NOTE(review): get_collection presumably raises when the collection
        # does not exist, in which case `available` is never set -- confirm.
        self.collection = self.client.get_collection(name=collection_name)
        stored_items = self.collection.count()
        # An empty collection is treated as "retrieval unavailable".
        self.available = stored_items > 0

    def search(self, query, top_k=10):
        """Return the documents Chroma reports for ``query``.

        Args:
            query: Text to embed and look up.
            top_k: Number of results requested from the collection.

        Returns:
            The first entry of the result's "documents" field (the list that
            belongs to the single query embedding), or ``[]`` when the system
            is not available.
        """
        if self.available:
            query_vector = self.encoder.encode(query).tolist()
            hits = self.collection.query(
                query_embeddings=[query_vector],
                n_results=top_k,
                include=["documents"],
            )
            per_query_documents = hits.get("documents", [[]])
            return per_query_documents[0]
        return []
# Module-level retrieval system, built once at import time with the default
# database path and collection name.
rag = SimpleRAGSystem()
# === Google GenAI client ===
# Shared client that respond() uses for every generation request.
client = genai.Client(api_key=API_KEY)
# === System message ===
# Multi-line system prompt text for the assistant. The literal below is
# runtime data and is kept exactly as-is.
SYSTEM_MSG = """
๋‹น์‹ ์€ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ์ „๋ฌธ ์ƒ๋‹ด AI์ž…๋‹ˆ๋‹ค.
# ์ฃผ์š” ์—ญํ• :
- ์ œ๊ณต๋œ ๋ฌธ์„œ ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋‹ต๋ณ€ ์ œ๊ณต
- ๋ฏธ๋””์–ดํ•™๊ณผ ๊ด€๋ จ ์งˆ๋ฌธ์— ์นœ์ ˆํ•˜๊ณ  ๊ตฌ์ฒด์ ์œผ๋กœ ์‘๋‹ต
- ๋ฌธ์„œ์— ์—†๋Š” ๋‚ด์šฉ์€ ์ผ๋ฐ˜ ์ง€์‹์œผ๋กœ ๋ณด์™„ (๋‹จ, ๋ช…์‹œ)
# ๋‹ต๋ณ€ ์Šคํƒ€์ผ:
- ์ž์„ธํ•˜๊ณ  ํ’๋ถ€ํ•œ ์„ค๋ช…์„ ํฌํ•จํ•˜์—ฌ ์ƒ์„ธํ•˜๊ณ  ๊ธธ๊ฒŒ ๋‹ต๋ณ€ ์ œ๊ณต
- ์นœ๊ทผํ•˜๊ณ  ๋„์›€์ด ๋˜๋Š” ์ƒ๋‹ด์‚ฌ ํ†ค
- ํ•ต์‹ฌ ์ •๋ณด๋ฅผ ๋ช…ํ™•ํ•˜๊ฒŒ ์ „๋‹ฌ
- ์ถ”๊ฐ€ ๊ถ๊ธˆํ•œ ์ ์ด ์žˆ์œผ๋ฉด ์–ธ์ œ๋“  ๋ฌผ์–ด๋ณด๋ผ๊ณ  ์•ˆ๋‚ด
# ์ฐธ๊ณ  ๋ฌธ์„œ ํ™œ์šฉ:
- ๋ฌธ์„œ ๋‚ด์šฉ์ด ์žˆ์œผ๋ฉด ๊ตฌ์ฒด์ ์œผ๋กœ ์ธ์šฉ
- ์—ฌ๋Ÿฌ ๋ฌธ์„œ์˜ ์ •๋ณด๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€ ์ž‘์„ฑ
- ์ •ํ™•ํ•˜์ง€ ์•Š์€ ์ •๋ณด๋Š” ์ถ”์ธกํ•˜์ง€ ๋ง๊ณ  ์†”์งํ•˜๊ฒŒ ๋ชจ๋ฅธ๋‹ค๊ณ  ๋‹ต๋ณ€
# ํ˜„์žฌ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ๊ต์ˆ˜์ง„:
์ด์ธํฌ, ๊น€ํƒœ์šฉ, ๋ฐ•์ข…๋ฏผ, ํ™์ง€์•„, ์ด์ •๊ต, ์ด๊ธฐํ˜•, ์ด์„ ์˜, ์กฐ์ˆ˜์˜, ์ด์ข…ํ˜, ์ด๋‘ํ™ฉ, ์ด์ƒ์›, ์ดํ›ˆ, ์ตœ์ˆ˜์ง„, ์ตœ๋ฏผ์•„, ๊น€๊ด€ํ˜ธ
"""
# === Response function ===
def _format_history(history):
    """Render earlier chat turns as transcript text for the prompt.

    Handles both history shapes a Gradio chat can deliver: a list of
    ``{"role": ..., "content": ...}`` dicts (``type="messages"``) and a list
    of ``(user, assistant)`` pairs. A user/assistant dict pair produces the
    same text as the equivalent pair entry.
    """
    parts = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role == "user":
                parts.append(f"์‚ฌ์šฉ์ž: {content}\n")
            elif role == "assistant":
                parts.append(f"AI: {content}\n")
            # Entries with any other role are left out of the transcript.
        else:
            user_text, ai_text = turn
            parts.append(f"์‚ฌ์šฉ์ž: {user_text}\nAI: {ai_text}\n")
    return "".join(parts)


def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Answer one chat message using retrieved documents and the GenAI model.

    Args:
        message: The latest user message.
        history: Earlier turns of the conversation (message dicts or
            (user, assistant) pairs).
        system_message: Base system prompt; retrieved documents are appended
            to it when any are found.
        max_tokens: Passed to the model as ``max_output_tokens``.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling value passed to the model.
        model_name: Name of the model to call.

    Returns:
        The model's response text, a fixed fallback string when the text is
        empty, or a user-facing error message when the request fails.
    """
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"์ฐธ๊ณ ๋ฌธ์„œ{i+1}: {d}" for i, d in enumerate(docs))
    sys_msg = system_message + ("\n# ์ฐธ๊ณ ๋ฌธ์„œ:\n" + ctx if ctx else "")
    # The interface is created with type="messages", so history entries are
    # role/content dicts; unpacking them as pairs would corrupt the prompt.
    convo = _format_history(history)
    prompt = f"{sys_msg}\n{convo}์‚ฌ์šฉ์ž: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            },
        )
        return response.text or "์‘๋‹ต์ด ์—†์Šต๋‹ˆ๋‹ค."
    except Exception as e:
        # UI boundary: every failure is turned into a chat reply instead of
        # propagating. The error text is matched on lowercase keywords.
        err = str(e).lower()
        if "quota" in err:
            return "API ํ• ๋‹น๋Ÿ‰์„ ์ดˆ๊ณผํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‚˜์ค‘์— ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
        if "authentication" in err:
            return "์ธ์ฆ ์˜ค๋ฅ˜: API ํ‚ค๋ฅผ ํ™•์ธํ•˜์„ธ์š”."
        return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}"
# === Gradio interface ===
def _respond_with_default_prompt(message, history, max_tokens, temperature, top_p, model_name):
    """Adapt the UI inputs to respond(), supplying SYSTEM_MSG as the system prompt.

    The interface below provides four additional inputs, while respond() also
    expects a ``system_message`` argument ahead of them; calling respond()
    directly would shift every value by one position and fail for a missing
    ``model_name``.
    """
    return respond(message, history, SYSTEM_MSG, max_tokens, temperature, top_p, model_name)


demo = gr.ChatInterface(
    fn=_respond_with_default_prompt,
    title="๐ŸŽฌ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ AI ์ƒ๋‹ด์‚ฌ",
    description="๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ์— ๋Œ€ํ•ด ๋ฌผ์–ด๋ณด์„ธ์š”!",
    # Order matters: these values are passed positionally after
    # (message, history) as max_tokens, temperature, top_p, model_name.
    additional_inputs=[
        gr.Slider(128, 2048, value=1024, step=64, label="์ตœ๋Œ€ ํ† ํฐ"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
            ],
            value="gemini-2.0-flash",
            label="๋ชจ๋ธ ์„ ํƒ"
        )
    ],
    additional_inputs_accordion="๐Ÿ”ง ๊ณ ๊ธ‰ ์„ค์ •",
    examples=[
        ["๋ฏธ๋””์–ดํ•™๊ณผ์—์„œ ๋ฐฐ์šฐ๋Š” ์ฃผ์š” ๊ณผ๋ชฉ๋“ค์€ ๋ฌด์—‡์ธ๊ฐ€์š”?"],
        ["๋ฏธ๋””์–ดํ•™๊ณผ ๊ต์ˆ˜์ง„์„ ์†Œ๊ฐœํ•ด์ฃผ์„ธ์š”."],
        ["๋ฏธ๋””์–ดํ•™๊ณผ ์กธ์—… ํ›„ ์ง„๋กœ๋Š” ์–ด๋–ป๊ฒŒ ๋˜๋‚˜์š”?"],
        ["๋ฏธ๋””์–ดํ•™๊ณผ ์ž…ํ•™ ์ „ํ˜•์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”."],
        ["๋ฏธ๋””์–ดํ•™๊ณผ ๋™์•„๋ฆฌ๋‚˜ ํ•™์ƒ ํ™œ๋™์€ ์–ด๋–ค ๊ฒƒ๋“ค์ด ์žˆ๋‚˜์š”?"]
    ],
    type="messages",
    theme="soft",
    analytics_enabled=False
)
# === System information panel ===
# Build a row/column holding a Markdown summary inside the `demo` context.
# The document count is only queried when retrieval is available.
with demo, gr.Row(), gr.Column():
    gr.Markdown(f"""
---
### โš™๏ธ ์‹œ์Šคํ…œ ์ •๋ณด
**์–ธ์–ด ๋ชจ๋ธ**: Google Gemini 2.0 Flash, Gemma 3 (4B/12B/27B) ์„ ํƒ ๊ฐ€๋Šฅ
**์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (ํ•œ๊ตญ์–ด ํŠนํ™”)
**RAG ์ƒํƒœ**: {"โœ… ํ™œ์„ฑํ™”" if rag.available else "โŒ ๋น„ํ™œ์„ฑํ™”"}
**๋ฌธ์„œ ์ˆ˜**: {rag.collection.count() if rag.available else "0"}๊ฐœ
""")

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=False)