jonghhhh commited on
Commit
3708980
ยท
verified ยท
1 Parent(s): a6d830f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -64
app.py CHANGED
@@ -2,27 +2,27 @@ import os
2
  import asyncio
3
  from typing import List, Dict
4
 
5
- # Protobuf C-extension ๋Œ€์‹  pure-Python ๊ตฌํ˜„
6
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
7
 
8
- # .env ๋ฐ Space Secrets ๋กœ๋“œ
9
  from dotenv import load_dotenv
10
  load_dotenv()
11
 
12
- # Gradio client ๋ฒ„๊ทธ ์šฐํšŒ์šฉ ํŒจ์น˜
13
  import gradio_client.utils as client_utils
14
- orig_json_to_python = client_utils.json_schema_to_python_type
15
- def safe_json_to_python(schema):
16
  try:
17
- return orig_json_to_python(schema)
18
  except Exception:
19
  return "Any"
20
- client_utils.json_schema_to_python_type = safe_json_to_python
21
 
22
  # Google API Key ๊ฒ€์ฆ
23
  api_key = os.getenv("GOOGLE_API_KEY")
24
  if not api_key:
25
- raise EnvironmentError("GOOGLE_API_KEY๋ฅผ Secrets์— ์ถ”๊ฐ€ํ•ด์ฃผ์„ธ์š”.")
26
  os.environ["GOOGLE_API_KEY"] = api_key
27
 
28
  # ChromaDB ๊ฒฝ๋กœ ์„ค์ •
@@ -42,46 +42,31 @@ from google.genai import types
42
  class SimpleRAGSystem:
43
  def __init__(self, db_path: str = None, collection: str = "KH_media_docs"):
44
  db_path = db_path or os.getenv("CHROMA_DB_DIR")
45
- print(f"๐Ÿ”„ RAG ์ดˆ๊ธฐํ™”: DB ๊ฒฝ๋กœ={db_path}, ์ปฌ๋ ‰์…˜={collection}")
46
- # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ
47
- print("๐Ÿ“ KR-SBERT ๋กœ๋“œ ์ค‘...")
48
  self.embedding_model = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
49
- print("โœ… ์ž„๋ฒ ๋”ฉ ์ค€๋น„ ์™„๋ฃŒ")
50
- # ChromaDB ์—ฐ๊ฒฐ
51
- print(f"๐Ÿ—„๏ธ ChromaDB ์—ฐ๊ฒฐ: {db_path}")
52
  self.client = chromadb.PersistentClient(path=db_path)
53
  self.collection = self.client.get_collection(name=collection)
54
  count = self.collection.count()
55
- print(f"โœ… DB ์ค€๋น„๋จ: {collection} ({count}๋ฌธ์„œ)")
56
  if count == 0:
57
- raise RuntimeError("ChromaDB๊ฐ€ ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค. ๋จผ์ € chromadb_builder.py๋ฅผ ์‹คํ–‰ํ•˜์„ธ์š”.")
58
 
59
  def search_similar_docs(self, query: str, top_k: int = 5) -> List[Dict]:
60
  emb = self.embedding_model.encode(query).tolist()
61
  res = self.collection.query(
62
- query_embeddings=[emb],
63
- n_results=top_k,
64
- include=["documents","metadatas","distances"]
65
  )
66
  docs = []
67
- for doc, meta, dist in zip(res["documents"][0], res["metadatas"][0], res.get("distances", [[]])[0]):
68
- score = (2 - dist)/2 if dist is not None else None
69
- docs.append({"content":doc, "metadata":meta, "similarity":score})
70
  return docs
71
 
72
- def get_status(self) -> str:
73
- return f"๐Ÿ“Š '{self.collection.name}'์— {self.collection.count()}๋ฌธ์„œ ์ค€๋น„"
74
-
75
- # RAG ์‹œ์Šคํ…œ ์ธ์Šคํ„ด์Šค
76
  rag_system = SimpleRAGSystem()
77
 
78
  # === Google ADK ์„ค์ • ===
79
  session_svc = InMemorySessionService()
80
- agent = Agent(
81
- model="gemini-2.0-flash-lite",
82
- name="khu_media_advisor",
83
- instruction="๋‹น์‹ ์€ ๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ ์ƒ๋‹ด AI์ž…๋‹ˆ๋‹ค. ์ฐธ๊ณ  ๋ฌธ์„œ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”."
84
- )
85
  runner = Runner(agent=agent, app_name="khu_media_chatbot", session_service=session_svc)
86
  session_id = None
87
 
@@ -90,52 +75,42 @@ async def get_response(prompt: str) -> str:
90
  if session_id is None:
91
  sess = await session_svc.create_session(app_name="khu_media_chatbot", user_id="user")
92
  session_id = sess.id
93
- response = ""
94
  content = types.Content(role="user", parts=[types.Part(text=prompt)])
 
95
  for ev in runner.run(user_id="user", session_id=session_id, new_message=content):
96
  if ev.is_final_response():
97
  response = ev.content.parts[0].text
98
  return response
99
 
100
- # === Gradio Blocks UI ===
101
  with gr.Blocks(title="๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ AI ์ƒ๋‹ด์‚ฌ", theme="soft") as app:
102
  gr.Markdown("# ๐ŸŽฌ ๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ AI ์ƒ๋‹ด์‚ฌ")
103
- with gr.Row():
104
- with gr.Column(scale=3):
105
- chatbot = gr.Chatbot(type="messages", height=400, label="๋Œ€ํ™”")
106
- msg = gr.Textbox(show_label=False, placeholder="์งˆ๋ฌธํ•˜์„ธ์š”...", lines=1)
107
- send = gr.Button("์ „์†ก")
108
- gr.Markdown("### ๐Ÿ’ก FAQ")
109
- for q in ["์ „๊ณต ์ •๋ณด","์ฃผ์š” ๊ณผ๋ชฉ","์ทจ์—… ์ „๋ง","ํ•™์ƒ ํ™œ๋™"]:
110
- btn = gr.Button(q, size="sm")
111
- btn.click(lambda x=q: x, outputs=[msg])
112
- with gr.Column(scale=1):
113
- status = gr.Textbox(label="์‹œ์Šคํ…œ ์ƒํƒœ", value=rag_system.get_status(), interactive=False)
114
- gr.Button("๐Ÿ”„ ์ƒˆ๋กœ๊ณ ์นจ").click(lambda: rag_system.get_status(), outputs=[status])
115
- gr.Button("๐Ÿ—‘๏ธ ์ดˆ๊ธฐํ™”").click(lambda: [], outputs=[chatbot])
116
  def chat_fn(user_input, history):
 
 
117
  docs = rag_system.search_similar_docs(user_input)
118
- context = "\n".join(f"โ€ข {d['content']}" for d in docs)
119
- prompt = f"์งˆ๋ฌธ: {user_input}\n\n===์ฐธ๊ณ ===\n{context}\n\n๋‹ต๋ณ€:"
 
 
 
 
120
  resp = asyncio.run(get_response(prompt))
121
- history = history or []
122
- history.append([user_input, resp])
123
- return history, []
124
  send.click(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
125
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
 
126
  gr.Markdown(f"""
127
- ---
128
- **์ž„๋ฒ ๋”ฉ**: KR-SBERT
129
- **DB**: ChromaDB
130
- **LLM**: Gemini Flash-Lite
131
- **์ƒํƒœ**: {rag_system.get_status()}
132
- """
133
  )
134
 
135
  if __name__ == "__main__":
136
- app.launch(
137
- server_name="0.0.0.0",
138
- server_port=int(os.environ.get("PORT", 7860)),
139
- share=False,
140
- show_api=False # API ํƒญ ์ˆจ๊ธฐ๊ธฐ
141
- )
 
2
  import asyncio
3
  from typing import List, Dict
4
 
5
+ # Protobuf C-extension ๋Œ€์‹  pure-Python ๊ตฌํ˜„ ์‚ฌ์šฉ
6
+ ios.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
7
 
8
+ # .env ํŒŒ์ผ ๋ฐ Space Secrets ๋กœ๋“œ
9
  from dotenv import load_dotenv
10
  load_dotenv()
11
 
12
+ # Gradio client ๋ฒ„๊ทธ ์šฐํšŒ (OpenAPI ํŒŒ์‹ฑ)
13
  import gradio_client.utils as client_utils
14
+ orig = client_utils.json_schema_to_python_type
15
+ def safe_json_type(schema, defs=None):
16
  try:
17
+ return orig(schema, defs)
18
  except Exception:
19
  return "Any"
20
+ client_utils.json_schema_to_python_type = safe_json_type
21
 
22
  # Google API Key ๊ฒ€์ฆ
23
  api_key = os.getenv("GOOGLE_API_KEY")
24
  if not api_key:
25
+ raise EnvironmentError("GOOGLE_API_KEY๋ฅผ Settingsโ†’Secrets์— ์ถ”๊ฐ€ํ•ด์ฃผ์„ธ์š”.")
26
  os.environ["GOOGLE_API_KEY"] = api_key
27
 
28
  # ChromaDB ๊ฒฝ๋กœ ์„ค์ •
 
42
  class SimpleRAGSystem:
43
  def __init__(self, db_path: str = None, collection: str = "KH_media_docs"):
44
  db_path = db_path or os.getenv("CHROMA_DB_DIR")
 
 
 
45
  self.embedding_model = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
 
 
 
46
  self.client = chromadb.PersistentClient(path=db_path)
47
  self.collection = self.client.get_collection(name=collection)
48
  count = self.collection.count()
 
49
  if count == 0:
50
+ raise RuntimeError("ChromaDB๊ฐ€ ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค.")
51
 
52
  def search_similar_docs(self, query: str, top_k: int = 5) -> List[Dict]:
53
  emb = self.embedding_model.encode(query).tolist()
54
  res = self.collection.query(
55
+ query_embeddings=[emb], n_results=top_k,
56
+ include=["documents", "metadatas"]
 
57
  )
58
  docs = []
59
+ for doc, meta in zip(res["documents"][0], res["metadatas"][0]):
60
+ docs.append({"role": "system", "content": doc})
 
61
  return docs
62
 
 
 
 
 
63
  rag_system = SimpleRAGSystem()
64
 
65
  # === Google ADK ์„ค์ • ===
66
  session_svc = InMemorySessionService()
67
+ agent = Agent(model="gemini-2.0-flash-lite",
68
+ name="khu_media_advisor",
69
+ instruction="๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ ์ƒ๋‹ด์‚ฌ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”.")
 
 
70
  runner = Runner(agent=agent, app_name="khu_media_chatbot", session_service=session_svc)
71
  session_id = None
72
 
 
75
  if session_id is None:
76
  sess = await session_svc.create_session(app_name="khu_media_chatbot", user_id="user")
77
  session_id = sess.id
 
78
  content = types.Content(role="user", parts=[types.Part(text=prompt)])
79
+ response = ""
80
  for ev in runner.run(user_id="user", session_id=session_id, new_message=content):
81
  if ev.is_final_response():
82
  response = ev.content.parts[0].text
83
  return response
84
 
85
+ # === Gradio UI ===
86
  with gr.Blocks(title="๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ AI ์ƒ๋‹ด์‚ฌ", theme="soft") as app:
87
  gr.Markdown("# ๐ŸŽฌ ๊ฒฝํฌ๋Œ€ ๋ฏธ๋””์–ดํ•™๊ณผ AI ์ƒ๋‹ด์‚ฌ")
88
+ chatbot = gr.Chatbot(type="messages", height=400)
89
+ msg = gr.Textbox(show_label=False, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”...")
90
+ send = gr.Button("์ „์†ก")
91
+
 
 
 
 
 
 
 
 
 
92
  def chat_fn(user_input, history):
93
+ history = history or []
94
+ # RAG ์ปจํ…์ŠคํŠธ
95
  docs = rag_system.search_similar_docs(user_input)
96
+ # Combine existing history (dicts) with new user message
97
+ new_history = history + [{"role": "user", "content": user_input}]
98
+ # Insert docs as system messages
99
+ new_history += docs
100
+ # Build prompt text from history
101
+ prompt = "\n".join([f"{m['role']}: {m['content']}" for m in new_history])
102
  resp = asyncio.run(get_response(prompt))
103
+ new_history.append({"role": "assistant", "content": resp})
104
+ return new_history, ""
105
+
106
  send.click(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
107
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[chatbot, msg])
108
+
109
  gr.Markdown(f"""
110
+ ---
111
+ **์ƒํƒœ**: ChromaDB์— ์ค€๋น„๋œ ๋ฌธ์„œ ์ˆ˜ = {rag_system.collection.count()}
112
+ """
 
 
 
113
  )
114
 
115
  if __name__ == "__main__":
116
+ app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT",7860)), share=False, show_api=False)