jonghhhh committed
Commit f04f8de · verified · 1 Parent(s): c70195a

Update app.py

Files changed (1)
  1. app.py +35 -30
app.py CHANGED
@@ -4,12 +4,12 @@ from sentence_transformers import SentenceTransformer
 from google import genai
 import gradio as gr
 
-# Environment setup
+# === Environment setup ===
 DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
 os.environ["CHROMA_DB_DIR"] = DB_DIR
-API_KEY = os.getenv("GOOGLE_API_KEY", "AIzaSyBuBsC5k9yw2JwUfVFn1Zu1qM_ifwGx6cM")
+API_KEY = os.getenv("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")
 
-# RAG system
+# === Simple RAG system ===
 class SimpleRAGSystem:
     def __init__(self, db_path=None, collection_name="KH_media_docs"):
         path = db_path or DB_DIR
@@ -22,38 +22,41 @@ class SimpleRAGSystem:
         if not self.available:
             return []
         emb = self.encoder.encode(query).tolist()
-        docs = self.collection.query(
-            query_embeddings=[emb], n_results=top_k,
+        result = self.collection.query(
+            query_embeddings=[emb],
+            n_results=top_k,
             include=["documents"]
         )
-        return docs["documents"][0] if docs.get("documents") else []
+        return result.get("documents", [[]])[0]
 
 rag = SimpleRAGSystem()
 
-# Google GenAI client
+# === Google GenAI client ===
 client = genai.Client(api_key=API_KEY)
 
-# Gradio response function
-SYSTEM_MSG = f"""
+# === System message ===
+SYSTEM_MSG = """
 당신은 경희대학교 미디어학과 전문 상담 AI입니다.
 """
 
+# === Response function ===
 def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
-    # RAG context
     docs = rag.search(message) if rag.available else []
     ctx = "\n".join(f"참고문서{i+1}: {d}" for i, d in enumerate(docs))
-    sys = system_message + ("\n# 참고문서:\n" + ctx if ctx else "")
-    # Conversation context
+    sys_msg = system_message + ("\n# 참고문서:\n" + ctx if ctx else "")
     convo = "".join(f"사용자: {u}\nAI: {a}\n" for u, a in history)
-    prompt = f"{sys}\n{convo}사용자: {message}\nAI:"
-    # API call
+    prompt = f"{sys_msg}\n{convo}사용자: {message}\nAI:"
     try:
-        res = client.models.generate_content(
+        response = client.models.generate_content(
             model=model_name,
             contents=prompt,
-            config={"max_output_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
+            config={
+                "max_output_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p
+            }
         )
-        return res.text or "응답이 없습니다."
+        return response.text or "응답이 없습니다."
     except Exception as e:
         err = str(e).lower()
         if "quota" in err:
@@ -62,27 +65,29 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
             return "인증 오류: API 키를 확인하세요."
         return f"오류 발생: {e}"
 
-# Gradio interface
+# === Gradio interface ===
 demo = gr.ChatInterface(
     fn=respond,
     title="🎬 경희대학교 미디어학과 AI 상담사",
     description="경희대학교 미디어학과에 대해 물어보세요!",
     additional_inputs=[
-        gr.Slider(128, 2048, 1024, step=64, label="최대 토큰"),
-        gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature"),
-        gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p"),
-        gr.Dropdown([
-            "gemini-2.0-flash", "gemini-2.0-flash-lite",
-            "gemini-1.5-flash", "gemini-1.5-pro",
-            "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
-        ], value="gemini-2.0-flash", label="모델 선택")
+        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=2),
+        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
+        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
+        gr.Dropdown(
+            choices=[
+                "gemini-2.0-flash", "gemini-2.0-flash-lite",
+                "gemini-1.5-flash", "gemini-1.5-pro",
+                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"
+            ],
+            value="gemini-2.0-flash",
+            label="모델 선택"
+        )
     ],
     theme="soft",
     analytics_enabled=False,
 )
 
-def main():
-    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), share=False)
-
 if __name__ == "__main__":
-    main()
+    demo.launch(share=False)
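
gr.ChatInterface passes the additional_inputs values to the callback positionally after (message, history), so the new Textbox supplies system_message, followed by max_tokens, temperature, top_p and model_name. Below is a minimal smoke-test sketch of that call path; the question and settings are hypothetical, it assumes the file is saved as app.py, that a real key is exported as GOOGLE_API_KEY before the import, and that the Chroma collection may be absent (in which case retrieval is skipped).

# Hypothetical smoke test for respond(); values mirror the Gradio defaults above.
import os
os.environ.setdefault("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")  # must be a real key for an actual reply

from app import respond, SYSTEM_MSG  # importing app.py also builds the RAG system and the GenAI client

reply = respond(
    message="Tell me about the Department of Media curriculum",  # hypothetical user question
    history=[],                         # no previous (user, assistant) turns
    system_message=SYSTEM_MSG,          # value normally supplied by the system-message Textbox
    max_tokens=1024,
    temperature=0.7,
    top_p=0.9,
    model_name="gemini-2.0-flash",
)
print(reply)

With a missing collection or an invalid key the call still returns a string: rag.available short-circuits retrieval, and the except block in respond() maps API failures to the error messages shown in the diff.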
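
The retrieval path assumes a persistent Chroma collection named "KH_media_docs" already exists under DB_DIR. A sketch of how such a collection could be populated offline follows; the embedding model name and the sample documents are assumptions, not part of this commit (the real encoder is whatever SimpleRAGSystem.__init__ loads).

# Hypothetical one-off indexing script; run before starting the app.
import os
import chromadb
from sentence_transformers import SentenceTransformer

db_dir = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # assumed model, see note above

docs = [
    "Hypothetical document about the Department of Media curriculum.",
    "Hypothetical document about admissions and contact information.",
]

client = chromadb.PersistentClient(path=db_dir)
collection = client.get_or_create_collection("KH_media_docs")
collection.add(
    ids=[f"doc-{i}" for i in range(len(docs))],
    documents=docs,
    embeddings=encoder.encode(docs).tolist(),  # same encode-then-query flow used by SimpleRAGSystem.search
)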