import os
import asyncio
from typing import List, Dict, Optional

# Use the pure-Python protobuf implementation instead of the C extension
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Load the .env file and Space Secrets
from dotenv import load_dotenv
load_dotenv()
# Work around a gradio_client bug in OpenAPI schema parsing
import gradio_client.utils as client_utils

orig = client_utils.json_schema_to_python_type

def safe_json_type(schema, defs=None):
    try:
        return orig(schema, defs)
    except Exception:
        return "Any"

client_utils.json_schema_to_python_type = safe_json_type
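# With this patch applied, any schema the client cannot parse falls back to the
# type "Any" instead of raising during app startup.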
# Google API Key ๊ฒ์ฆ | |
api_key = os.getenv("GOOGLE_API_KEY") | |
if not api_key: | |
raise EnvironmentError("GOOGLE_API_KEY๋ฅผ SettingsโSecrets์ ์ถ๊ฐํด์ฃผ์ธ์.") | |
os.environ["GOOGLE_API_KEY"] = api_key | |
# ChromaDB ๊ฒฝ๋ก ์ค์ | |
db_dir = os.path.join(os.getcwd(), "chromadb_KH_media") | |
os.environ["CHROMA_DB_DIR"] = db_dir | |
# === Library imports ===
import chromadb
import gradio as gr
from sentence_transformers import SentenceTransformer
from google.adk.agents import Agent
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.genai import types
# === Current faculty list ===
PROFESSORS = [
    "์ด์ธํฌ", "๊นํ์ฉ", "๋ฐ์ข ๋ฏผ", "ํ์ง์", "์ด์ ๊ต",
    "์ด๊ธฐํ", "์ด์ ์", "์กฐ์์", "์ด์ข ํ", "์ด๋ํฉ",
    "์ด์์", "์ดํ", "์ต์์ง", "์ต๋ฏผ์", "๊น๊ดํธ"
]
# === Simple RAG system ===
class SimpleRAGSystem:
    def __init__(self, db_path: Optional[str] = None, collection: str = "KH_media_docs"):
        db_path = db_path or os.getenv("CHROMA_DB_DIR")
        # Korean-specialized sentence embedding model
        self.embedding_model = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_collection(name=collection)
        count = self.collection.count()
        if count == 0:
            raise RuntimeError("ChromaDB is empty.")
    def search_similar_docs(self, query: str, top_k: int = 20) -> List[Dict]:
        emb = self.embedding_model.encode(query).tolist()
        res = self.collection.query(
            query_embeddings=[emb], n_results=top_k,
            include=["documents", "metadatas"]
        )
        docs = []
        for doc, meta in zip(res["documents"][0], res["metadatas"][0]):
            docs.append({"role": "system", "content": doc})
        return docs
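# Build the RAG handle once at import time; this fails fast if the persisted
# collection is missing or empty.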
rag_system = SimpleRAGSystem()
# === Google ADK setup ===
session_svc = InMemorySessionService()

agent = Agent(
    model="gemini-2.0-flash-lite",  # or "gemini-2.0-flash"
    name="khu_media_advisor",
    instruction="""You are a specialized counseling AI for the Kyung Hee University Media Department.

# Main roles:
- Answer based on the provided document information
- Respond to Media Department questions kindly and concretely
- Supplement content not found in the documents with general knowledge (and say so explicitly)

# Answer style:
- Give detailed, lengthy answers with rich explanations
- Use a friendly, helpful counselor tone
- Convey the key information clearly
- Invite the user to ask follow-up questions at any time

# Using the reference documents:
- Quote document content specifically when it is available
- Synthesize information from multiple documents when answering
- Do not guess about uncertain information; honestly say you do not know

# The current Kyung Hee University Media Department faculty members are:
์ด์ธํฌ, ๊นํ์ฉ, ๋ฐ์ข ๋ฏผ, ํ์ง์, ์ด์ ๊ต, ์ด๊ธฐํ, ์ด์ ์, ์กฐ์์, ์ด์ข ํ, ์ด๋ํฉ, ์ด์์, ์ดํ, ์ต์์ง, ์ต๋ฏผ์, ๊น๊ดํธ"""
)

runner = Runner(agent=agent, app_name="khu_media_chatbot", session_service=session_svc)
session_id = None
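# A single ADK session is created lazily on the first request and shared by every
# visitor of this Space; per-user sessions would require keying on a user id.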
async def get_response(prompt: str) -> str:
    global session_id
    if session_id is None:
        sess = await session_svc.create_session(app_name="khu_media_chatbot", user_id="user")
        session_id = sess.id
    content = types.Content(role="user", parts=[types.Part(text=prompt)])
    response = ""
    for ev in runner.run(user_id="user", session_id=session_id, new_message=content):
        if ev.is_final_response():
            response = ev.content.parts[0].text
    return response
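# chat_fn below flattens the whole chat history (plus the retrieved documents) into
# a single prompt string before calling get_response.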
# === Gradio UI ===
with gr.Blocks(title="Kyung Hee University Media Department AI Advisor", theme="soft") as app:
    gr.Markdown("# 🎬 Kyung Hee University Media Department AI Advisor")
    chatbot = gr.Chatbot(type="messages", height=400)
    msg = gr.Textbox(show_label=False, placeholder="Type your question here...")
    send = gr.Button("Send")
    def chat_fn(user_input, history):
        history = history or []
        # Input preprocessing: prevent term misuse by normalizing wording
        # ("전공"/major -> "분야"/field, "교수"/professor -> "교수진"/faculty)
        user_input = user_input.replace("전공", "분야").replace("교수", "교수진")
        # Retrieve RAG context for the (normalized) question
        docs = rag_system.search_similar_docs(user_input)
        # Extend the existing history (message dicts) with the new user message
        new_history = history + [{"role": "user", "content": user_input}]
        # Add the retrieved documents as system messages for the prompt only,
        # so raw document chunks do not appear in the visible chat
        prompt_messages = new_history + docs
        # Build the prompt text from the message list
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in prompt_messages])
        resp = asyncio.run(get_response(prompt))
        new_history.append({"role": "assistant", "content": resp})
        return new_history, ""
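    # Wire both the Send button and pressing Enter in the textbox to the same
    # handler; returning "" as the second output clears the textbox each turn.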
    send.click(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
    msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
    gr.Markdown(f"""
---
### ⚙️ System info

**ChromaDB documents**: {rag_system.collection.count()}

**Embedding model**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (Korean-specialized)

**Language model**: Google Gemini 2.0 Flash (free tier)
""")
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), share=False, show_api=False)