# Source captured from a Hugging Face Spaces page (page status: "Sleeping").
import logging
import os

import chromadb
import gradio as gr
from google import genai
from sentence_transformers import SentenceTransformer
# === Environment configuration ===
# Chroma persistence directory: CHROMA_DB_DIR if set, otherwise
# ./chromadb_KH_media under the current working directory.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
# Write the resolved path back to the environment (presumably so any other
# code reading CHROMA_DB_DIR sees the same value).
os.environ["CHROMA_DB_DIR"] = DB_DIR
# The Gemini API key must come from the environment (e.g. a Space secret).
# A literal key used to be hard-coded here as the fallback; it was published
# with the source and should be revoked/rotated in the Google console.
API_KEY = os.getenv("GOOGLE_API_KEY")
# === Simple RAG system ===
class SimpleRAGSystem:
    """Minimal retrieval helper over a persistent Chroma collection.

    Queries are embedded with the ``snunlp/KR-SBERT-V40K-klueNLI-augSTS``
    sentence-transformer and matched against the collection's stored
    embeddings; only the document texts are returned.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Load the encoder and open the Chroma collection.

        Args:
            db_path: Chroma persistence directory; falls back to ``DB_DIR``
                when falsy.
            collection_name: Name of an existing collection to query.

        If the collection cannot be opened (for example it was never
        created), the instance is marked unavailable instead of raising, so
        the app can still start and answer without retrieved context.
        """
        path = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        try:
            self.collection = self.client.get_collection(name=collection_name)
        except Exception:
            # chromadb signals a missing collection with different exception
            # types across versions, so catch broadly, log, and degrade to
            # "no RAG" rather than crashing the whole app at import time.
            logging.getLogger(__name__).warning(
                "Chroma collection %r not available at %s; RAG disabled",
                collection_name,
                path,
                exc_info=True,
            )
            self.collection = None
        self.available = self.collection is not None and self.collection.count() > 0

    def search(self, query, top_k=10):
        """Return up to ``top_k`` document texts most similar to ``query``.

        Returns an empty list when retrieval is unavailable, the collection
        is (now) empty, or Chroma returns no documents.
        """
        if not self.available:
            return []
        # Never request more neighbours than stored documents: chromadb warns
        # about (and some versions reject) n_results larger than the index.
        n_results = min(top_k, self.collection.count())
        if n_results <= 0:
            return []
        embedding = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[embedding],
            n_results=n_results,
            include=["documents"],
        )
        # One list of documents per query embedding; we sent exactly one.
        documents = result.get("documents") or [[]]
        return documents[0]
# Module-wide retrieval backend; constructing it loads the embedding model
# and opens the Chroma collection at import time.
rag = SimpleRAGSystem(db_path=None, collection_name="KH_media_docs")

# === Google GenAI client ===
client = genai.Client(api_key=API_KEY)
# === System message ===
# System prompt for the media-department advisor chatbot; presumably meant to
# be supplied to `respond` as its `system_message` argument.
# NOTE(review): the Korean text in this literal appears to be mis-encoded
# (UTF-8 bytes rendered as Thai TIS-620/cp874 "mojibake", with some bytes
# lost), so the model would receive garbled instructions. Restore it from the
# original UTF-8 source; it cannot be reliably decoded from this copy.
SYSTEM_MSG = """
๋น์ ์ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ์ ๋ฌธ ์๋ด AI์ ๋๋ค.
# ์ฃผ์ ์ญํ :
- ์ ๊ณต๋ ๋ฌธ์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๋ต๋ณ ์ ๊ณต
- ๋ฏธ๋์ดํ๊ณผ ๊ด๋ จ ์ง๋ฌธ์ ์น์ ํ๊ณ ๊ตฌ์ฒด์ ์ผ๋ก ์๋ต
- ๋ฌธ์์ ์๋ ๋ด์ฉ์ ์ผ๋ฐ ์ง์์ผ๋ก ๋ณด์ (๋จ, ๋ช ์)
# ๋ต๋ณ ์คํ์ผ:
- ์์ธํ๊ณ ํ๋ถํ ์ค๋ช ์ ํฌํจํ์ฌ ์์ธํ๊ณ ๊ธธ๊ฒ ๋ต๋ณ ์ ๊ณต
- ์น๊ทผํ๊ณ ๋์์ด ๋๋ ์๋ด์ฌ ํค
- ํต์ฌ ์ ๋ณด๋ฅผ ๋ช ํํ๊ฒ ์ ๋ฌ
- ์ถ๊ฐ ๊ถ๊ธํ ์ ์ด ์์ผ๋ฉด ์ธ์ ๋ ๋ฌผ์ด๋ณด๋ผ๊ณ ์๋ด
# ์ฐธ๊ณ ๋ฌธ์ ํ์ฉ:
- ๋ฌธ์ ๋ด์ฉ์ด ์์ผ๋ฉด ๊ตฌ์ฒด์ ์ผ๋ก ์ธ์ฉ
- ์ฌ๋ฌ ๋ฌธ์์ ์ ๋ณด๋ฅผ ์ข ํฉํ์ฌ ๋ต๋ณ ์์ฑ
- ์ ํํ์ง ์์ ์ ๋ณด๋ ์ถ์ธกํ์ง ๋ง๊ณ ์์งํ๊ฒ ๋ชจ๋ฅธ๋ค๊ณ ๋ต๋ณ
# ํ์ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ๊ต์์ง:
์ด์ธํฌ, ๊นํ์ฉ, ๋ฐ์ข ๋ฏผ, ํ์ง์, ์ด์ ๊ต, ์ด๊ธฐํ, ์ด์ ์, ์กฐ์์, ์ด์ข ํ, ์ด๋ํฉ, ์ด์์, ์ดํ, ์ต์์ง, ์ต๋ฏผ์, ๊น๊ดํธ
"""
# === Response function ===
def _format_history(history):
    """Render earlier chat turns as a plain-text "user / AI" transcript.

    Accepts both Gradio history formats: ``type="messages"`` (a list of
    ``{"role": ..., "content": ...}`` dicts) and the legacy list of
    ``(user, assistant)`` pairs. Messages whose content is not text (files,
    components) and roles other than user/assistant are skipped.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if not isinstance(content, str):
                continue
            role = turn.get("role")
            if role == "user":
                lines.append(f"์ฌ์ฉ์: {content}\n")
            elif role == "assistant":
                lines.append(f"AI: {content}\n")
        else:
            user_text, ai_text = turn
            lines.append(f"์ฌ์ฉ์: {user_text}\nAI: {ai_text}\n")
    return "".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Answer one chat message with a Gemini/Gemma model, grounded on RAG hits.

    Args:
        message: The new user message.
        history: Prior turns from Gradio (messages dicts or legacy pairs).
        system_message: System prompt; ``None`` falls back to ``SYSTEM_MSG``.
        max_tokens: Maximum output tokens (coerced to ``int`` in case the
            slider delivers a float).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        model_name: Model id passed to ``client.models.generate_content``.

    Returns:
        The model's text, or a short user-facing error message when the API
        call fails (the full traceback is logged server-side).
    """
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"์ฐธ๊ณ ๋ฌธ์{i + 1}: {d}" for i, d in enumerate(docs))
    if system_message is None:
        system_message = SYSTEM_MSG
    sys_msg = system_message + ("\n# ์ฐธ๊ณ ๋ฌธ์:\n" + ctx if ctx else "")
    convo = _format_history(history)
    prompt = f"{sys_msg}\n{convo}์ฌ์ฉ์: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": int(max_tokens),
                "temperature": temperature,
                "top_p": top_p,
            },
        )
        return response.text or "์๋ต์ด ์์ต๋๋ค."
    except Exception as e:
        # UI boundary: an API failure must not crash the chat, but it should
        # not vanish either -- log the traceback, show the user a short note.
        logging.getLogger(__name__).exception(
            "Gemini request failed (model=%s)", model_name
        )
        err = str(e).lower()
        if "quota" in err or "resource_exhausted" in err:
            return "API ํ ๋น๋์ ์ด๊ณผํ์ต๋๋ค. ๋์ค์ ์๋ํด์ฃผ์ธ์."
        # Invalid keys typically surface as "API key not valid" /
        # API_KEY_INVALID rather than the word "authentication".
        if any(k in err for k in ("authentication", "unauthenticated", "api key", "api_key")):
            return "์ธ์ฆ ์ค๋ฅ: API ํค๋ฅผ ํ์ธํ์ธ์."
        return f"์ค๋ฅ ๋ฐ์: {e}"
# === Gradio interface ===
# Gradio calls `fn(message, history, *additional_inputs)` positionally, and
# `respond` expects (system_message, max_tokens, temperature, top_p,
# model_name) after the history. The system prompt must therefore be the
# first additional input; without it every value shifts one slot and each
# call fails with a TypeError for the missing `model_name`.
demo = gr.ChatInterface(
    fn=respond,
    title="๐ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ AI ์๋ด์ฌ",
    description="๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ์ ๋ํด ๋ฌผ์ด๋ณด์ธ์!",
    additional_inputs=[
        # Hidden so the prompt stays fixed to SYSTEM_MSG and is not user-editable.
        gr.Textbox(value=SYSTEM_MSG, visible=False),
        gr.Slider(128, 2048, value=1024, step=64, label="์ต๋ ํ ํฐ"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            # NOTE(review): Google has been retiring the gemini-1.5 models;
            # confirm these ids are still served before keeping them.
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it",
            ],
            value="gemini-2.0-flash",
            label="๋ชจ๋ธ ์ ํ",
        ),
    ],
    additional_inputs_accordion="๐ง ๊ณ ๊ธ ์ค์ ",
    examples=[
        ["๋ฏธ๋์ดํ๊ณผ์์ ๋ฐฐ์ฐ๋ ์ฃผ์ ๊ณผ๋ชฉ๋ค์ ๋ฌด์์ธ๊ฐ์?"],
        ["๋ฏธ๋์ดํ๊ณผ ๊ต์์ง์ ์๊ฐํด์ฃผ์ธ์."],
        ["๋ฏธ๋์ดํ๊ณผ ์กธ์ ํ ์ง๋ก๋ ์ด๋ป๊ฒ ๋๋์?"],
        ["๋ฏธ๋์ดํ๊ณผ ์ ํ ์ ํ์ ๋ํด ์๋ ค์ฃผ์ธ์."],
        ["๋ฏธ๋์ดํ๊ณผ ๋์๋ฆฌ๋ ํ์ ํ๋์ ์ด๋ค ๊ฒ๋ค์ด ์๋์?"],
    ],
    # History is delivered to `respond` as a list of {"role", "content"} dicts.
    type="messages",
    theme="soft",
    analytics_enabled=False,
)
# === System information panel ===
# Re-enter the `demo` Blocks context to append a status panel under the chat.
with demo:
    with gr.Row():
        with gr.Column():
            # Entries are separated by blank lines: Markdown joins consecutive
            # non-blank lines into one paragraph, which ran all fields together.
            gr.Markdown(f"""
---
### โ๏ธ ์์คํ ์ ๋ณด

**์ธ์ด ๋ชจ๋ธ**: Google Gemini 2.0 Flash, Gemma 3 (4B/12B/27B) ์ ํ ๊ฐ๋ฅ

**์๋ฒ ๋ฉ ๋ชจ๋ธ**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (ํ๊ตญ์ด ํนํ)

**RAG ์ํ**: {"โ ํ์ฑํ" if rag.available else "โ ๋นํ์ฑํ"}

**๋ฌธ์ ์**: {rag.collection.count() if rag.available else "0"}๊ฐ
""")
def main():
    """Start the Gradio server without creating a public share link."""
    demo.launch(share=False)


if __name__ == "__main__":
    main()