import os
import chromadb
from sentence_transformers import SentenceTransformer
from google import genai
import gradio as gr
# === Environment setup ===
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
os.environ["CHROMA_DB_DIR"] = DB_DIR
API_KEY = os.getenv("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")
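# Note: on Hugging Face Spaces the key would typically be supplied as a Space
# secret named GOOGLE_API_KEY (assumption); the placeholder above is only a
# local fallback, and calls will fail with authentication errors without a real key.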
# === Simple RAG system ===
class SimpleRAGSystem:
    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        path = db_path or DB_DIR
        # Korean SBERT model used to embed both queries and stored documents
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        try:
            self.collection = self.client.get_collection(name=collection_name)
            self.available = self.collection.count() > 0
        except Exception:
            # Collection not found (e.g. the DB has not been built yet): run without RAG
            self.collection = None
            self.available = False
    def search(self, query, top_k=8):
        """Return the top_k most similar documents for the query, or [] if the DB is unavailable."""
        if not self.available:
            return []
        emb = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[emb],
            n_results=top_k,
            include=["documents"]
        )
        return result.get("documents", [[]])[0]
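# The vector store is assumed to have been built offline. A minimal ingestion
# sketch (hypothetical helper, not called anywhere in this app) could look like:
def build_collection(docs, db_path=DB_DIR, collection_name="KH_media_docs"):
    """Embed a list of document strings and persist them to a ChromaDB collection."""
    encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(name=collection_name)
    collection.add(
        ids=[f"doc-{i}" for i in range(len(docs))],
        documents=docs,
        embeddings=encoder.encode(docs).tolist(),
    )
    return collection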
rag = SimpleRAGSystem()
# === Google GenAI client ===
client = genai.Client(api_key=API_KEY)
# === System message ===
SYSTEM_MSG = """
You are an AI advisor specializing in the Department of Media at Kyung Hee University.
"""
# === Response function ===
def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    # Retrieve relevant documents and append them to the system message as context
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"Reference document {i+1}: {d}" for i, d in enumerate(docs))
    sys_msg = system_message + ("\n# Reference documents:\n" + ctx if ctx else "")
    # Flatten the (user, assistant) tuple-style chat history into a plain-text transcript
    convo = "".join(f"User: {u}\nAI: {a}\n" for u, a in history)
    prompt = f"{sys_msg}\n{convo}User: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p
            }
        )
        return response.text or "No response was generated."
    except Exception as e:
        err = str(e).lower()
        if "quota" in err:
            return "API quota exceeded. Please try again later."
        if "authentication" in err:
            return "Authentication error: please check your API key."
        return f"Error: {e}"
# === Gradio interface ===
demo = gr.ChatInterface(
    fn=respond,
    title="💬 Kyung Hee University Department of Media AI Advisor",
    description="Ask anything about the Department of Media at Kyung Hee University!",
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, label="System message", lines=2),
        gr.Slider(128, 2048, value=1024, step=64, label="Max tokens"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
            ],
            value="gemini-2.0-flash",
            label="Model"
        )
    ],
    theme="soft",
    analytics_enabled=False,
)
if __name__ == "__main__":
    demo.launch(share=False)
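# To run locally (assuming these are the dependencies the Space uses):
#   pip install gradio chromadb sentence-transformers google-genai
#   export GOOGLE_API_KEY=...   # or edit the fallback above
#   python app.py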