Spaces:
Sleeping
Sleeping
File size: 3,392 Bytes
2425d9c c70195a 2425d9c f04f8de c70195a f04f8de 5e96ac1 f04f8de 2425d9c c70195a 2425d9c c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a f04f8de c70195a f04f8de 2425d9c f04f8de 2425d9c c70195a f04f8de 2425d9c f04f8de 2425d9c c70195a f04f8de 2425d9c c70195a 2425d9c c70195a 2425d9c f04f8de 2425d9c f04f8de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import chromadb
from sentence_transformers import SentenceTransformer
from google import genai
import gradio as gr
# === Environment setup ===
# Resolve the Chroma persistence directory (env var override wins, else a
# folder next to the CWD) and re-export it so any library reading
# CHROMA_DB_DIR sees the same path.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
os.environ["CHROMA_DB_DIR"] = DB_DIR

# Google GenAI API key; the placeholder default makes a missing key obvious
# instead of failing with an empty string.
API_KEY = os.getenv("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")

# === Simple RAG system ===
class SimpleRAGSystem:
    """Thin retrieval wrapper: a KR-SBERT encoder over a persistent Chroma collection."""

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Open the collection at `db_path` (falls back to DB_DIR).

        `self.available` is True only when the collection holds documents.
        """
        storage_dir = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=storage_dir)
        self.collection = self.client.get_collection(name=collection_name)
        self.available = self.collection.count() > 0

    def search(self, query, top_k=8):
        """Return up to `top_k` document strings most similar to `query`.

        Returns an empty list when the collection has no documents.
        """
        if not self.available:
            return []
        query_vector = self.encoder.encode(query).tolist()
        hits = self.collection.query(
            query_embeddings=[query_vector],
            n_results=top_k,
            include=["documents"],
        )
        # Chroma returns one document list per query embedding; we sent one.
        return hits.get("documents", [[]])[0]
# Module-level RAG singleton: built once at import time so the embedding
# model and the Chroma collection are loaded only once per process.
rag = SimpleRAGSystem()
# === Google GenAI client ===
client = genai.Client(api_key=API_KEY)
# === System message ===
# Default system prompt; shown as an editable textbox in the Gradio UI.
SYSTEM_MSG = """
당신은 경희대학교 미디어학과 전문 상담 AI입니다.
"""
# === Response function ===
def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Generate a chat reply with Google GenAI, grounded on RAG context.

    Args:
        message: Current user message.
        history: List of (user, assistant) turn pairs from gr.ChatInterface.
        system_message: System prompt text (editable in the UI).
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        model_name: GenAI model identifier chosen in the UI dropdown.

    Returns:
        The model's reply text, or a Korean error message on failure.
    """
    # Retrieve reference documents only when the collection has content.
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"참고문서{i+1}: {d}" for i, d in enumerate(docs))
    sys_msg = system_message + ("\n# 참고문서:\n" + ctx if ctx else "")
    # Flatten prior turns into a plain-text transcript for the prompt.
    convo = "".join(f"사용자: {u}\nAI: {a}\n" for u, a in history)
    prompt = f"{sys_msg}\n{convo}사용자: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p
            }
        )
        # response.text can be None when the model returns no candidates.
        return response.text or "응답이 없습니다."
    except Exception as e:
        # Map common API failures to user-friendly Korean messages.
        err = str(e).lower()
        if "quota" in err:
            return "API 할당량을 초과했습니다. 나중에 시도해주세요."
        if "authentication" in err:
            return "인증 오류: API 키를 확인하세요."
        return f"오류 발생: {e}"
# === Gradio interface ===
# ChatInterface passes (message, history) plus each additional_inputs value,
# in order, to `respond` — the widget order below must match its signature.
demo = gr.ChatInterface(
    fn=respond,
    title="💬 경희대학교 미디어학과 AI 상담사",
    description="경희대학교 미디어학과에 대해 물어보세요!",
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=2),
        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
            ],
            value="gemini-2.0-flash",
            label="모델 선택"
        )
    ],
    theme="soft",
    analytics_enabled=False,
)
# Run the app locally when executed as a script (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)
|