Spaces:
Sleeping
Sleeping
File size: 5,249 Bytes
2425d9c c70195a 2425d9c f04f8de c70195a 150a22f 5e96ac1 f04f8de 2425d9c c70195a d19feb1 2425d9c c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a 0cf18b0 c70195a f04f8de c70195a f04f8de c70195a f04f8de 2425d9c f04f8de 2425d9c c70195a f04f8de 2425d9c f04f8de 2425d9c c70195a f04f8de 2425d9c c70195a 2425d9c c70195a 2425d9c f04f8de 2425d9c 0cf18b0 0392b83 2425d9c 8d55cbb 2425d9c 0cf18b0 0392b83 8d55cbb 2425d9c f04f8de 8d55cbb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import chromadb
from sentence_transformers import SentenceTransformer
from google import genai
import gradio as gr
# === Environment configuration ===
# Chroma persistence directory: override via CHROMA_DB_DIR, otherwise use
# ./chromadb_KH_media relative to the working directory.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
os.environ["CHROMA_DB_DIR"] = DB_DIR  # re-export so libraries reading the env see the same path
# SECURITY FIX: the original hard-coded a live Google API key as the getenv
# fallback. A key committed to source control must be treated as leaked —
# rotate it and provide the replacement ONLY via the GOOGLE_API_KEY env var.
API_KEY = os.getenv("GOOGLE_API_KEY", "")
# === Simple RAG system ===
class SimpleRAGSystem:
    """Minimal retrieval layer over a persistent Chroma collection.

    Queries are embedded with a Korean-specialised SBERT model and the raw
    document texts of the nearest neighbours are returned.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        # db_path falls back to the module-level DB_DIR.
        path = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        try:
            # ROBUSTNESS FIX: get_collection raises when the collection does
            # not exist, which crashed the whole app at import time in the
            # original. Degrade gracefully instead — search() then returns [].
            self.collection = self.client.get_collection(name=collection_name)
            self.available = self.collection.count() > 0
        except Exception:
            self.collection = None
            self.available = False

    def search(self, query, top_k=10):
        """Return up to `top_k` document strings most similar to `query`.

        Returns an empty list when the collection is missing or empty.
        """
        if not self.available:
            return []
        emb = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[emb],
            n_results=top_k,
            include=["documents"],
        )
        # Chroma returns one list of documents per query embedding; guard
        # against the key being absent or explicitly None.
        docs = result.get("documents") or [[]]
        return docs[0]
# Module-level RAG singleton: loads the embedding model and opens the Chroma
# store at import time.
rag = SimpleRAGSystem()
# === Google GenAI client ===
client = genai.Client(api_key=API_KEY)
# === System prompt (Korean persona/instructions for the counselling bot) ===
SYSTEM_MSG = """
๋น์ ์ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ์ ๋ฌธ ์๋ด AI์
๋๋ค.
# ์ฃผ์ ์ญํ :
- ์ ๊ณต๋ ๋ฌธ์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๋ต๋ณ ์ ๊ณต
- ๋ฏธ๋์ดํ๊ณผ ๊ด๋ จ ์ง๋ฌธ์ ์น์ ํ๊ณ ๊ตฌ์ฒด์ ์ผ๋ก ์๋ต
- ๋ฌธ์์ ์๋ ๋ด์ฉ์ ์ผ๋ฐ ์ง์์ผ๋ก ๋ณด์ (๋จ, ๋ช
์)
# ๋ต๋ณ ์คํ์ผ:
- ์์ธํ๊ณ ํ๋ถํ ์ค๋ช
์ ํฌํจํ์ฌ ์์ธํ๊ณ ๊ธธ๊ฒ ๋ต๋ณ ์ ๊ณต
- ์น๊ทผํ๊ณ ๋์์ด ๋๋ ์๋ด์ฌ ํค
- ํต์ฌ ์ ๋ณด๋ฅผ ๋ช
ํํ๊ฒ ์ ๋ฌ
- ์ถ๊ฐ ๊ถ๊ธํ ์ ์ด ์์ผ๋ฉด ์ธ์ ๋ ๋ฌผ์ด๋ณด๋ผ๊ณ ์๋ด
# ์ฐธ๊ณ ๋ฌธ์ ํ์ฉ:
- ๋ฌธ์ ๋ด์ฉ์ด ์์ผ๋ฉด ๊ตฌ์ฒด์ ์ผ๋ก ์ธ์ฉ
- ์ฌ๋ฌ ๋ฌธ์์ ์ ๋ณด๋ฅผ ์ข
ํฉํ์ฌ ๋ต๋ณ ์์ฑ
- ์ ํํ์ง ์์ ์ ๋ณด๋ ์ถ์ธกํ์ง ๋ง๊ณ ์์งํ๊ฒ ๋ชจ๋ฅธ๋ค๊ณ ๋ต๋ณ
# ํ์ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ๊ต์์ง:
์ด์ธํฌ, ๊นํ์ฉ, ๋ฐ์ข
๋ฏผ, ํ์ง์, ์ด์ ๊ต, ์ด๊ธฐํ, ์ด์ ์, ์กฐ์์, ์ด์ข
ํ, ์ด๋ํฉ, ์ด์์, ์ดํ, ์ต์์ง, ์ต๋ฏผ์, ๊น๊ดํธ
"""
# === Response function ===
def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Generate one chat reply.

    Retrieves RAG context for `message`, flattens the chat history into a
    plain-text transcript, and asks the selected Gemini/Gemma model for a
    completion. Returns the reply text, or a human-readable error string on
    quota/auth/other API failures.
    """
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"์ฐธ๊ณ ๋ฌธ์{i+1}: {d}" for i, d in enumerate(docs))
    sys_msg = system_message + ("\n# ์ฐธ๊ณ ๋ฌธ์:\n" + ctx if ctx else "")
    # BUG FIX: with gr.ChatInterface(type="messages") each history entry is a
    # {"role": ..., "content": ...} dict, so the original `for u, a in history`
    # unpacked the dict *keys* rather than the turn text. Support the messages
    # format and fall back to the legacy (user, assistant) tuple format.
    turns = []
    for entry in history:
        if isinstance(entry, dict):
            speaker = "์ฌ์ฉ์" if entry.get("role") == "user" else "AI"
            turns.append(f"{speaker}: {entry.get('content', '')}\n")
        else:
            u, a = entry
            turns.append(f"์ฌ์ฉ์: {u}\nAI: {a}\n")
    convo = "".join(turns)
    prompt = f"{sys_msg}\n{convo}์ฌ์ฉ์: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p
            }
        )
        # response.text can be None (e.g. safety block) — fall back to a notice.
        return response.text or "์๋ต์ด ์์ต๋๋ค."
    except Exception as e:
        # Map common API failures to friendly messages; everything else is
        # surfaced verbatim (best-effort UI, so we never re-raise).
        err = str(e).lower()
        if "quota" in err:
            return "API ํ ๋น๋์ ์ด๊ณผํ์ต๋๋ค. ๋์ค์ ์๋ํด์ฃผ์ธ์."
        if "authentication" in err:
            return "์ธ์ฆ ์ค๋ฅ: API ํค๋ฅผ ํ์ธํ์ธ์."
        return f"์ค๋ฅ ๋ฐ์: {e}"
# === Gradio interface ===
# BUG FIX: `respond` expects five extra parameters after (message, history) —
# system_message, max_tokens, temperature, top_p, model_name — but the
# original wired up only four additional inputs, so every chat turn raised a
# TypeError (missing `model_name`) and the sliders were shifted by one slot.
# A system-message textbox is added as the FIRST additional input so the
# positional arguments line up with the function signature.
demo = gr.ChatInterface(
    fn=respond,
    title="๐ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ AI ์๋ด์ฌ",
    description="๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ์ ๋ํด ๋ฌผ์ด๋ณด์ธ์!",
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, lines=4, label="System message"),
        gr.Slider(128, 2048, value=1024, step=64, label="์ต๋ ํ ํฐ"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
            ],
            value="gemini-2.0-flash",
            label="๋ชจ๋ธ ์ ํ"
        )
    ],
    additional_inputs_accordion="๐ง ๊ณ ๊ธ ์ค์ ",
    examples=[
        ["๋ฏธ๋์ดํ๊ณผ์์ ๋ฐฐ์ฐ๋ ์ฃผ์ ๊ณผ๋ชฉ๋ค์ ๋ฌด์์ธ๊ฐ์?"],
        ["๋ฏธ๋์ดํ๊ณผ ๊ต์์ง์ ์๊ฐํด์ฃผ์ธ์."],
        ["๋ฏธ๋์ดํ๊ณผ ์กธ์ํ ์ง๋ก๋ ์ด๋ป๊ฒ ๋๋์?"],
        ["๋ฏธ๋์ดํ๊ณผ ์ํ ์ ํ์ ๋ํด ์๋ ค์ฃผ์ธ์."],
        ["๋ฏธ๋์ดํ๊ณผ ๋์๋ฆฌ๋ ํ์ ํ๋์ ์ด๋ค ๊ฒ๋ค์ด ์๋์?"]
    ],
    type="messages",
    theme="soft",
    analytics_enabled=False
)
# === System info panel appended below the chat interface ===
# SYNTAX FIX: in the original, the "RAG status" f-string expression contained
# a raw line break inside a single-line quoted literal (scrape/encoding
# artifact), which is a SyntaxError — the literal is rejoined onto one line.
with demo:
    with gr.Row():
        with gr.Column():
            # Static markdown summarising the configured models and RAG status.
            gr.Markdown(f"""
---
### โ๏ธ ์์คํ
์ ๋ณด
**์ธ์ด ๋ชจ๋ธ**: Google Gemini 2.0 Flash, Gemma 3 (4B/12B/27B) ์ ํ ๊ฐ๋ฅ
**์๋ฒ ๋ฉ ๋ชจ๋ธ**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (ํ๊ตญ์ด ํนํ)
**RAG ์ํ**: {"โ ํ์ฑํ" if rag.available else "โ ๋นํ์ฑํ"}
**๋ฌธ์ ์**: {rag.collection.count() if rag.available else "0"}๊ฐ
""")
# Script entry point: launch the Gradio app locally (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)
|