jonghhhh committed on
Commit
c70195a
·
verified ·
1 Parent(s): 5e96ac1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -199
app.py CHANGED
@@ -1,222 +1,88 @@
1
  import os
2
- import asyncio
3
- from typing import List, Dict
4
-
5
- # ChromaDB 경로 설정
6
- db_dir = os.path.join(os.getcwd(), "chromadb_KH_media")
7
- os.environ["CHROMA_DB_DIR"] = db_dir
8
-
9
- # === 라이브러리 임포트 ===
10
  import chromadb
11
- import gradio as gr
12
  from sentence_transformers import SentenceTransformer
13
  from google import genai
 
14
 
15
- # === 현재 교수진 목록 ===
16
- PROFESSORS = [
17
- "์ด์ธํฌ", "๊น€ํƒœ์šฉ", "๋ฐ•์ข…๋ฏผ", "ํ™์ง€์•„", "์ด์ •๊ต",
18
- "์ด๊ธฐํ˜•", "์ด์„ ์˜", "์กฐ์ˆ˜์˜", "์ด์ข…ํ˜", "์ด๋‘ํ™ฉ",
19
- "์ด์ƒ์›", "์ดํ›ˆ", "์ตœ์ˆ˜์ง„", "์ตœ๋ฏผ์•„", "๊น€๊ด€ํ˜ธ"
20
- ]
21
 
22
- # === Simple RAG 시스템 ===
23
  class SimpleRAGSystem:
24
- def __init__(self, db_path: str = None, collection: str = "KH_media_docs"):
25
- db_path = db_path or os.getenv("CHROMA_DB_DIR")
26
- try:
27
- self.embedding_model = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
28
- self.client = chromadb.PersistentClient(path=db_path)
29
- self.collection = self.client.get_collection(name=collection)
30
- count = self.collection.count()
31
- self.available = count > 0
32
- except Exception:
33
- self.available = False
34
-
35
- def search_similar_docs(self, query: str, top_k: int = 10) -> List[str]:
36
  if not self.available:
37
  return []
38
- try:
39
- emb = self.embedding_model.encode(query).tolist()
40
- res = self.collection.query(
41
- query_embeddings=[emb], n_results=top_k,
42
- include=["documents", "metadatas"]
43
- )
44
- return res["documents"][0] if res["documents"] else []
45
- except Exception:
46
- return []
47
-
48
- # RAG 시스템 초기화
49
- try:
50
- rag_system = SimpleRAGSystem()
51
- except Exception:
52
- rag_system = None
53
-
54
- # === Google GenAI Client 설정 ===
55
- client = genai.Client(api_key="AIzaSyBuBsC5k9yw2JwUfVFn1Zu1qM_ifwGx6cM")
56
-
57
- # === 시스템 메시지 ===
58
- SYSTEM_MESSAGE = """๋‹น์‹ ์€ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ์ „๋ฌธ ์ƒ๋‹ด AI์ž…๋‹ˆ๋‹ค.
59
-
60
- # ์ฃผ์š” ์—ญํ• :
61
- - ์ œ๊ณต๋œ ๋ฌธ์„œ ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋‹ต๋ณ€ ์ œ๊ณต
62
- - ๋ฏธ๋””์–ดํ•™๊ณผ ๊ด€๋ จ ์งˆ๋ฌธ์— ์นœ์ ˆํ•˜๊ณ  ๊ตฌ์ฒด์ ์œผ๋กœ ์‘๋‹ต
63
- - ๋ฌธ์„œ์— ์—†๋Š” ๋‚ด์šฉ์€ ์ผ๋ฐ˜ ์ง€์‹์œผ๋กœ ๋ณด์™„ (๋‹จ, ๋ช…์‹œ)
64
-
65
- # ๋‹ต๋ณ€ ์Šคํƒ€์ผ:
66
- - ์ž์„ธํ•˜๊ณ  ํ’๋ถ€ํ•œ ์„ค๋ช…์„ ํฌํ•จํ•˜์—ฌ ์ƒ์„ธํ•˜๊ณ  ๊ธธ๊ฒŒ ๋‹ต๋ณ€ ์ œ๊ณต
67
- - ์นœ๊ทผํ•˜๊ณ  ๋„์›€์ด ๋˜๋Š” ์ƒ๋‹ด์‚ฌ ํ†ค
68
- - ํ•ต์‹ฌ ์ •๋ณด๋ฅผ ๋ช…ํ™•ํ•˜๊ฒŒ ์ „๋‹ฌ
69
- - ์ถ”๊ฐ€ ๊ถ๊ธˆํ•œ ์ ์ด ์žˆ์œผ๋ฉด ์–ธ์ œ๋“  ๋ฌผ์–ด๋ณด๋ผ๊ณ  ์•ˆ๋‚ด
70
-
71
- # ์ฐธ๊ณ  ๋ฌธ์„œ ํ™œ์šฉ:
72
- - ๋ฌธ์„œ ๋‚ด์šฉ์ด ์žˆ์œผ๋ฉด ๊ตฌ์ฒด์ ์œผ๋กœ ์ธ์šฉ
73
- - ์—ฌ๋Ÿฌ ๋ฌธ์„œ์˜ ์ •๋ณด๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€ ์ž‘์„ฑ
74
- - ์ •ํ™•ํ•˜์ง€ ์•Š์€ ์ •๋ณด๋Š” ์ถ”์ธกํ•˜์ง€ ๋ง๊ณ  ์†”์งํ•˜๊ฒŒ ๋ชจ๋ฅธ๋‹ค๊ณ  ๋‹ต๋ณ€
75
-
76
- # ํ˜„์žฌ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ๊ต์ˆ˜์ง„:
77
- ์ด์ธํฌ, ๊น€ํƒœ์šฉ, ๋ฐ•์ข…๋ฏผ, ํ™์ง€์•„, ์ด์ •๊ต, ์ด๊ธฐํ˜•, ์ด์„ ์˜, ์กฐ์ˆ˜์˜, ์ด์ข…ํ˜, ์ด๋‘ํ™ฉ, ์ด์ƒ์›, ์ดํ›ˆ, ์ตœ์ˆ˜์ง„, ์ตœ๋ฏผ์•„, ๊น€๊ด€ํ˜ธ
78
-
79
- ํ•œ๊ตญ์–ด๋กœ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”."""
80
-
81
- def respond(
82
- message,
83
- history: list[tuple[str, str]],
84
- system_message,
85
- max_tokens,
86
- temperature,
87
- top_p,
88
- model_name,
89
- ):
90
- # RAG 컨텍스트 추가
91
- context_docs = []
92
- if rag_system and rag_system.available:
93
- try:
94
- context_docs = rag_system.search_similar_docs(message, top_k=8)
95
- except Exception:
96
- pass
97
-
98
- # ์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€์— ์ปจํ…์ŠคํŠธ ์ถ”๊ฐ€
99
- enhanced_system_message = system_message
100
- if context_docs:
101
- context_text = "\n".join([f"์ฐธ๊ณ ๋ฌธ์„œ {i+1}: {doc}" for i, doc in enumerate(context_docs)])
102
- enhanced_system_message += f"\n\n# ์ฐธ๊ณ  ๋ฌธ์„œ:\n{context_text}"
103
-
104
- # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๋ฅผ ํฌํ•จํ•œ ์ „์ฒด ์ปจํ…์ŠคํŠธ ๊ตฌ์„ฑ
105
- full_context = enhanced_system_message + "\n\n"
106
-
107
- # ์ด์ „ ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ์ถ”๊ฐ€
108
- for val in history:
109
- if val[0]:
110
- full_context += f"์‚ฌ์šฉ์ž: {val[0]}\n"
111
- if val[1]:
112
- full_context += f"AI ์ƒ๋‹ด์‚ฌ: {val[1]}\n"
113
-
114
- # ํ˜„์žฌ ์‚ฌ์šฉ์ž ๋ฉ”์‹œ์ง€ ์ถ”๊ฐ€
115
- full_context += f"์‚ฌ์šฉ์ž: {message}\nAI ์ƒ๋‹ด์‚ฌ: "
116
-
117
  try:
118
- # Google GenAI API ํ˜ธ์ถœ
119
- api_response = client.models.generate_content(
120
  model=model_name,
121
- contents=full_context,
122
- config={
123
- "temperature": temperature,
124
- "top_p": top_p,
125
- "max_output_tokens": max_tokens,
126
- }
127
  )
128
-
129
- if hasattr(api_response, 'text') and api_response.text:
130
- yield api_response.text
131
- else:
132
- yield "์‘๋‹ต์ด ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ๋ชจ๋ธ์„ ์‹œ๋„ํ•ด๋ณด์„ธ์š”."
133
-
134
  except Exception as e:
135
- if "no api" in str(e).lower():
136
- error_msg = "API ์ ‘๊ทผ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. Google AI Studio์—์„œ API ํ‚ค๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”."
137
- elif "quota" in str(e).lower():
138
- error_msg = "API ํ• ๋‹น๋Ÿ‰์ด ์ดˆ๊ณผ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
139
- elif "authentication" in str(e).lower():
140
- error_msg = "์ธ์ฆ ์˜ค๋ฅ˜์ž…๋‹ˆ๋‹ค. API ํ‚ค๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š”."
141
- else:
142
- error_msg = f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
143
-
144
- yield error_msg
145
-
146
- # === Gradio 인터페이스 ===
147
  demo = gr.ChatInterface(
148
- respond,
149
  title="🎬 경희대학교 미디어학과 AI 상담사",
150
- description="""
151
- ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ์— ๋Œ€ํ•œ ๋ชจ๋“  ๊ถ๊ธˆํ•œ ์ ์„ ๋ฌผ์–ด๋ณด์„ธ์š”!
152
- """,
153
  additional_inputs=[
154
- gr.Slider(
155
- minimum=128,
156
- maximum=2048,
157
- value=1024,
158
- step=64,
159
- label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜"
160
- ),
161
- gr.Slider(
162
- minimum=0.1,
163
- maximum=1.0,
164
- value=0.7,
165
- step=0.1,
166
- label="Temperature (์ฐฝ์˜์„ฑ)"
167
- ),
168
- gr.Slider(
169
- minimum=0.1,
170
- maximum=1.0,
171
- value=0.9,
172
- step=0.05,
173
- label="Top-p (๋‹ค์–‘์„ฑ)"
174
- ),
175
- gr.Dropdown(
176
- choices=[
177
- "gemini-2.0-flash",
178
- "gemini-2.0-flash-lite",
179
- "gemini-1.5-flash",
180
- "gemini-1.5-pro",
181
- "gemma-3-27b-it",
182
- "gemma-3-12b-it",
183
- "gemma-3-4b-it"
184
- ],
185
- value="gemini-2.0-flash",
186
- label="๋ชจ๋ธ ์„ ํƒ"
187
- ),
188
- ],
189
- additional_inputs_accordion="๐Ÿ”ง ๊ณ ๊ธ‰ ์„ค์ •",
190
- examples=[
191
- ["๋ฏธ๋””์–ดํ•™๊ณผ์—์„œ ๋ฐฐ์šฐ๋Š” ์ฃผ์š” ๊ณผ๋ชฉ๋“ค์€ ๋ฌด์—‡์ธ๊ฐ€์š”?"],
192
- ["๋ฏธ๋””์–ดํ•™๊ณผ ๊ต์ˆ˜์ง„์„ ์†Œ๊ฐœํ•ด์ฃผ์„ธ์š”."],
193
- ["๋ฏธ๋””์–ดํ•™๊ณผ ์กธ์—… ํ›„ ์ง„๋กœ๋Š” ์–ด๋–ป๊ฒŒ ๋˜๋‚˜์š”?"],
194
- ["๋ฏธ๋””์–ดํ•™๊ณผ ์ž…ํ•™ ์ „ํ˜•์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”."],
195
- ["๋ฏธ๋””์–ดํ•™๊ณผ ๋™์•„๋ฆฌ๋‚˜ ํ•™์ƒ ํ™œ๋™์€ ์–ด๋–ค ๊ฒƒ๋“ค์ด ์žˆ๋‚˜์š”?"]
196
  ],
197
  theme="soft",
198
  analytics_enabled=False,
199
  )
200
 
201
- # === 시스템 정보 표시 ===
202
- with demo:
203
- with gr.Row():
204
- with gr.Column():
205
- gr.Markdown(f"""
206
- ---
207
- ### โš™๏ธ ์‹œ์Šคํ…œ ์ •๋ณด
208
- **์–ธ์–ด ๋ชจ๋ธ**: Google Gemini 2.0 Flash + Gemma 3 (4B/12B/27B) ์„ ํƒ ๊ฐ€๋Šฅ
209
- **์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (ํ•œ๊ตญ์–ด ํŠนํ™”)
210
- **RAG ์ƒ๏ฟฝ๏ฟฝ๏ฟฝ**: {"โœ… ํ™œ์„ฑํ™”" if rag_system and rag_system.available else "โŒ ๋น„ํ™œ์„ฑํ™”"}
211
- **๋ฌธ์„œ ์ˆ˜**: {rag_system.collection.count() if rag_system and rag_system.available else "0"}๊ฐœ
212
-
213
- ๐Ÿ’ก **์‚ฌ์šฉ ํŒ**: ๊ตฌ์ฒด์ ์ธ ์งˆ๋ฌธ์ผ์ˆ˜๋ก ๋” ์ •ํ™•ํ•œ ๋‹ต๋ณ€์„ ๋ฐ›์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!
214
- """)
215
 
216
  if __name__ == "__main__":
217
- demo.launch(
218
- server_name="0.0.0.0",
219
- server_port=int(os.environ.get("PORT", 7860)),
220
- share=False,
221
- show_api=False
222
- )
 
1
  import os
 
 
 
 
 
 
 
 
2
  import chromadb
 
3
  from sentence_transformers import SentenceTransformer
4
  from google import genai
5
+ import gradio as gr
6
 
7
+ # ํ™˜๊ฒฝ ์„ค์ •
8
+ DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
9
+ os.environ["CHROMA_DB_DIR"] = DB_DIR
10
+ API_KEY = os.getenv("GOOGLE_API_KEY", "AIzaSyBuBsC5k9yw2JwUfVFn1Zu1qM_ifwGx6cM")
 
 
11
 
12
+ # RAG 시스템
13
  class SimpleRAGSystem:
14
+ def __init__(self, db_path=None, collection_name="KH_media_docs"):
15
+ path = db_path or DB_DIR
16
+ self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
17
+ self.client = chromadb.PersistentClient(path=path)
18
+ self.collection = self.client.get_collection(name=collection_name)
19
+ self.available = self.collection.count() > 0
20
+
21
+ def search(self, query, top_k=8):
 
 
 
 
22
  if not self.available:
23
  return []
24
+ emb = self.encoder.encode(query).tolist()
25
+ docs = self.collection.query(
26
+ query_embeddings=[emb], n_results=top_k,
27
+ include=["documents"]
28
+ )
29
+ return docs["documents"][0] if docs.get("documents") else []
30
+
31
+ rag = SimpleRAGSystem()
32
+
33
+ # Google GenAI 클라이언트
34
+ client = genai.Client(api_key=API_KEY)
35
+
36
+ # Gradio ์‘๋‹ต ํ•จ์ˆ˜
37
+ SYSTEM_MSG = f"""
38
+ ๋‹น์‹ ์€ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ์ „๋ฌธ ์ƒ๋‹ด AI์ž…๋‹ˆ๋‹ค.
39
+ """
40
+
41
+ def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
42
+ # RAG 컨텍스트
43
+ docs = rag.search(message) if rag.available else []
44
+ ctx = "\n".join(f"참고문서{i+1}: {d}" for i, d in enumerate(docs))
45
+ sys = system_message + ("\n# 참고문서:\n" + ctx if ctx else "")
46
+ # 대화 컨텍스트
47
+ convo = "".join(f"사용자: {u}\nAI: {a}\n" for u, a in history)
48
+ prompt = f"{sys}\n{convo}사용자: {message}\nAI:"
49
+ # API 호출
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
+ res = client.models.generate_content(
 
52
  model=model_name,
53
+ contents=prompt,
54
+ config={"max_output_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
 
 
 
 
55
  )
56
+ return res.text or "응답이 없습니다."
 
 
 
 
 
57
  except Exception as e:
58
+ err = str(e).lower()
59
+ if "quota" in err:
60
+ return "API 할당량을 초과했습니다. 나중에 시도해주세요."
61
+ if "authentication" in err:
62
+ return "์ธ์ฆ ์˜ค๋ฅ˜: API ํ‚ค๋ฅผ ํ™•์ธํ•˜์„ธ์š”."
63
+ return f"오류 발생: {e}"
64
+
65
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค
 
 
 
 
66
  demo = gr.ChatInterface(
67
+ fn=respond,
68
  title="🎬 경희대학교 미디어학과 AI 상담사",
69
+ description="경희대학교 미디어학과에 대해 물어보세요!",
 
 
70
  additional_inputs=[
71
+ gr.Slider(128, 2048, 1024, step=64, label="최대 토큰"),
72
+ gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature"),
73
+ gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p"),
74
+ gr.Dropdown([
75
+ "gemini-2.0-flash", "gemini-2.0-flash-lite",
76
+ "gemini-1.5-flash", "gemini-1.5-pro",
77
+ "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
78
+ ], value="gemini-2.0-flash", label="모델 선택")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  ],
80
  theme="soft",
81
  analytics_enabled=False,
82
  )
83
 
84
+ def main():
85
+ demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), share=False)
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  if __name__ == "__main__":
88
+ main()