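# NOTE: the following lines are Hugging Face file-page residue (author,
# commit message, commit hash, page links, file size), kept as comments so
# the file parses as Python:
# jonghhhh's picture
# Update app.py
# f04f8de verified
# raw
# history blame
# 3.39 kB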
import logging
import os

import chromadb
import gradio as gr
from google import genai
from sentence_transformers import SentenceTransformer
# === 환경 설정 (configuration) ===
# Chroma persistence directory: $CHROMA_DB_DIR when it is set, otherwise
# "chromadb_KH_media" under the current working directory. setdefault() both
# resolves the value and exports it back into the process environment.
DB_DIR = os.environ.setdefault(
    "CHROMA_DB_DIR",
    os.path.join(os.getcwd(), "chromadb_KH_media"),
)
# Gemini API key from $GOOGLE_API_KEY.
# NOTE(review): the fallback is a placeholder string, not a usable key;
# API calls made with it will fail.
API_KEY = os.environ.get("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")
# === Simple RAG 시스템 (retrieval system) ===
class SimpleRAGSystem:
    """Minimal retrieval helper over a persisted Chroma collection.

    Query text is embedded with a SentenceTransformer encoder, and the nearest
    stored documents are returned. ``available`` is False when the collection
    is missing or empty. In that case ``search`` returns an empty list, so
    callers can answer without retrieved context.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Load the encoder and open ``collection_name`` under ``db_path``.

        Args:
            db_path: Chroma persistence directory; falls back to ``DB_DIR``.
            collection_name: Name of the pre-built collection to query.
        """
        path = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        try:
            self.collection = self.client.get_collection(name=collection_name)
        except Exception as exc:  # exception type differs across chromadb versions
            # Degrade to "no retrieval" instead of crashing the app at import time.
            logging.getLogger(__name__).warning(
                "Chroma collection %r unavailable at %s: %s",
                collection_name,
                path,
                exc,
            )
            self.collection = None
            self.available = False
        else:
            self.available = self.collection.count() > 0

    def search(self, query, top_k=8):
        """Return up to ``top_k`` stored document strings most similar to ``query``.

        Returns an empty list when the collection is unavailable or when
        ``top_k`` is not positive (Chroma rejects ``n_results <= 0``).
        """
        if not self.available or top_k < 1:
            return []
        emb = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[emb],
            n_results=top_k,
            include=["documents"],
        )
        # Exactly one query embedding was sent, so row 0 holds its hits; guard
        # against a missing/None/empty "documents" payload.
        documents = result.get("documents") or [[]]
        return documents[0]
# Shared retrieval helper, built once at import time and used by respond().
rag = SimpleRAGSystem()
# === Google GenAI client ===
# Single client shared by every respond() call.
client = genai.Client(api_key=API_KEY)
# === System message ===
# Default text of the UI's system-message textbox. respond() puts whatever that
# textbox holds at the top of every prompt. Keep this file saved as UTF-8 so
# the Korean text is not mis-decoded.
SYSTEM_MSG = """
당신은 경희대학교 미디어학과 전문 상담 AI입니다.
"""
# === 응답 함수 (chat response function) ===
# Lower-cased substrings used by respond() to classify SDK error text.
# Quota / rate-limit failures (HTTP 429 RESOURCE_EXHAUSTED) typically mention "quota".
_QUOTA_ERROR_MARKERS = ("quota", "resource_exhausted")
# An invalid Gemini key is typically reported as "API key not valid" /
# API_KEY_INVALID rather than with the word "authentication".
_AUTH_ERROR_MARKERS = ("authentication", "unauthenticated", "api key", "api_key")


def _format_history(history):
    """Render earlier chat turns as alternating user / "AI:" transcript lines.

    Accepts both history shapes Gradio's ChatInterface can pass:

    * tuples: ``[[user, assistant], ...]``; an assistant value of ``None``
      (no reply yet) is skipped instead of being rendered as "None".
    * messages: ``[{"role": "user" | "assistant", "content": ...}, ...]``;
      entries whose content is not plain text are skipped.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if not isinstance(content, str):
                continue
            role = turn.get("role")
            if role == "user":
                lines.append(f"사용자: {content}\n")
            elif role == "assistant":
                lines.append(f"AI: {content}\n")
        else:
            user, assistant = turn
            lines.append(f"사용자: {user}\n")
            if assistant is not None:
                lines.append(f"AI: {assistant}\n")
    return "".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Produce one chatbot reply, grounded on documents retrieved for ``message``.

    Args:
        message: The user's new message.
        history: Earlier turns as supplied by ``gr.ChatInterface`` (pair
            "tuples" or role/content "messages"; both are accepted).
        system_message: System prompt text from the UI textbox.
        max_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        model_name: Model id passed to ``generate_content``.

    Returns:
        The model's reply text, a fixed "no response" notice when the reply is
        empty, or a Korean user-facing error message when the API call raises.
    """
    docs = rag.search(message) if rag.available else []
    ctx = "\n".join(f"참고문서{i + 1}: {d}" for i, d in enumerate(docs))
    sys_msg = system_message or ""
    if ctx:
        sys_msg += "\n# 참고문서:\n" + ctx
    # The system text is inlined into one plain-text prompt instead of being sent
    # as a separate system instruction. NOTE(review): presumably so the same
    # prompt works for the Gemma models offered in the UI; confirm.
    prompt = f"{sys_msg}\n{_format_history(history)}사용자: {message}\nAI:"
    config = {
        # Gradio sliders deliver floats; the output-token limit is an integer.
        "max_output_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
    }
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=config,
        )
        return response.text or "응답이 없습니다."
    except Exception as e:  # UI boundary: report any API failure as a chat reply
        err = str(e).lower()
        if any(marker in err for marker in _QUOTA_ERROR_MARKERS):
            return "API 할당량을 초과했습니다. 나중에 시도해주세요."
        if any(marker in err for marker in _AUTH_ERROR_MARKERS):
            return "인증 오류: API 키를 확인하세요."
        return f"오류 발생: {e}"
# === Gradio 인터페이스 (UI) ===
# Chat UI. ChatInterface calls fn(message, history, *additional_inputs), so the
# order of `additional_inputs` must match respond()'s parameters after
# `history`: system_message, max_tokens, temperature, top_p, model_name.
demo = gr.ChatInterface(
    fn=respond,
    title="🎬 경희대학교 미디어학과 AI 상담사",
    description="경희대학교 미디어학과에 대해 물어보세요!",
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=2),
        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            # NOTE(review): older model ids (e.g. gemini-1.5-*) may no longer be
            # served by the API; verify this list against the current model catalog.
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it",
            ],
            value="gemini-2.0-flash",
            label="모델 선택",
        ),
    ],
    theme="soft",
    analytics_enabled=False,
)

if __name__ == "__main__":
    # share=False: serve locally only, without a public Gradio share link.
    demo.launch(share=False)