Hugging Face Spaces — commit view: "Update app.py" (file app.py, CHANGED).
@@ -4,12 +4,12 @@ from sentence_transformers import SentenceTransformer
|
|
4 |
from google import genai
|
5 |
import gradio as gr
|
6 |
|
7 |
-
# ํ๊ฒฝ ์ค์
|
8 |
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
|
9 |
os.environ["CHROMA_DB_DIR"] = DB_DIR
|
10 |
-
API_KEY = os.getenv("GOOGLE_API_KEY", "
|
11 |
|
12 |
-
# RAG ์์คํ
|
13 |
class SimpleRAGSystem:
|
14 |
def __init__(self, db_path=None, collection_name="KH_media_docs"):
|
15 |
path = db_path or DB_DIR
|
@@ -22,38 +22,41 @@ class SimpleRAGSystem:
|
|
22 |
if not self.available:
|
23 |
return []
|
24 |
emb = self.encoder.encode(query).tolist()
|
25 |
-
|
26 |
-
query_embeddings=[emb],
|
|
|
27 |
include=["documents"]
|
28 |
)
|
29 |
-
return
|
30 |
|
31 |
rag = SimpleRAGSystem()
|
32 |
|
33 |
-
# Google GenAI ํด๋ผ์ด์ธํธ
|
34 |
client = genai.Client(api_key=API_KEY)
|
35 |
|
36 |
-
#
|
37 |
-
SYSTEM_MSG =
|
38 |
๋น์ ์ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ์ ๋ฌธ ์๋ด AI์
๋๋ค.
|
39 |
"""
|
40 |
|
|
|
41 |
def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
|
42 |
-
# RAG ์ปจํ
์คํธ
|
43 |
docs = rag.search(message) if rag.available else []
|
44 |
ctx = "\n".join(f"์ฐธ๊ณ ๋ฌธ์{i+1}: {d}" for i, d in enumerate(docs))
|
45 |
-
|
46 |
-
# ๋ํ ์ปจํ
์คํธ
|
47 |
convo = "".join(f"์ฌ์ฉ์: {u}\nAI: {a}\n" for u, a in history)
|
48 |
-
prompt = f"{
|
49 |
-
# API ํธ์ถ
|
50 |
try:
|
51 |
-
|
52 |
model=model_name,
|
53 |
contents=prompt,
|
54 |
-
config={
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
-
return
|
57 |
except Exception as e:
|
58 |
err = str(e).lower()
|
59 |
if "quota" in err:
|
@@ -62,27 +65,29 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, mo
|
|
62 |
return "์ธ์ฆ ์ค๋ฅ: API ํค๋ฅผ ํ์ธํ์ธ์."
|
63 |
return f"์ค๋ฅ ๋ฐ์: {e}"
|
64 |
|
65 |
-
# Gradio ์ธํฐํ์ด์ค
|
66 |
demo = gr.ChatInterface(
|
67 |
fn=respond,
|
68 |
title="๐ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ AI ์๋ด์ฌ",
|
69 |
description="๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ์ ๋ํด ๋ฌผ์ด๋ณด์ธ์!",
|
70 |
additional_inputs=[
|
71 |
-
gr.
|
72 |
-
gr.Slider(
|
73 |
-
gr.Slider(0.1, 1.0, 0.
|
74 |
-
gr.
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
79 |
],
|
80 |
theme="soft",
|
81 |
analytics_enabled=False,
|
82 |
)
|
83 |
|
84 |
-
def main():
|
85 |
-
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), share=False)
|
86 |
-
|
87 |
if __name__ == "__main__":
|
88 |
-
|
|
|
4 |
from google import genai
|
5 |
import gradio as gr
|
6 |
|
7 |
# === Environment configuration ===
# Resolve the ChromaDB directory once and re-export it through the
# environment so any library code that reads CHROMA_DB_DIR sees the
# same path this process uses.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
os.environ["CHROMA_DB_DIR"] = DB_DIR

# Google API key comes from the environment only. Default to an empty
# string rather than a committed placeholder ("YOUR_API_KEY_HERE"):
# an empty key fails the API call with a clear auth error — which the
# chat handler already translates to a user-facing message — whereas a
# placeholder can be mistaken for a real credential.
API_KEY = os.getenv("GOOGLE_API_KEY", "")
11 |
|
12 |
+
# === Simple RAG ์์คํ
===
|
13 |
class SimpleRAGSystem:
|
14 |
def __init__(self, db_path=None, collection_name="KH_media_docs"):
|
15 |
path = db_path or DB_DIR
|
|
|
22 |
if not self.available:
|
23 |
return []
|
24 |
emb = self.encoder.encode(query).tolist()
|
25 |
+
result = self.collection.query(
|
26 |
+
query_embeddings=[emb],
|
27 |
+
n_results=top_k,
|
28 |
include=["documents"]
|
29 |
)
|
30 |
+
return result.get("documents", [[]])[0]
|
31 |
|
32 |
# --- Module-level singletons -------------------------------------------
# One shared retriever for every chat request; if its Chroma collection
# failed to open, rag.available is False and search is skipped.
rag = SimpleRAGSystem()

# Google GenAI client, authenticated with the key resolved above.
client = genai.Client(api_key=API_KEY)

# Default system prompt pre-filled in the UI textbox.
# (Korean: "You are a counseling AI specialized in the Kyung Hee
# University Media department.")
SYSTEM_MSG = """
당신은 경희대학교 미디어학과 전문 상담 AI입니다.
"""
|
41 |
|
42 |
+
# === ์๋ต ํจ์ ===
|
43 |
def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
|
|
|
44 |
docs = rag.search(message) if rag.available else []
|
45 |
ctx = "\n".join(f"์ฐธ๊ณ ๋ฌธ์{i+1}: {d}" for i, d in enumerate(docs))
|
46 |
+
sys_msg = system_message + ("\n# ์ฐธ๊ณ ๋ฌธ์:\n" + ctx if ctx else "")
|
|
|
47 |
convo = "".join(f"์ฌ์ฉ์: {u}\nAI: {a}\n" for u, a in history)
|
48 |
+
prompt = f"{sys_msg}\n{convo}์ฌ์ฉ์: {message}\nAI:"
|
|
|
49 |
try:
|
50 |
+
response = client.models.generate_content(
|
51 |
model=model_name,
|
52 |
contents=prompt,
|
53 |
+
config={
|
54 |
+
"max_output_tokens": max_tokens,
|
55 |
+
"temperature": temperature,
|
56 |
+
"top_p": top_p
|
57 |
+
}
|
58 |
)
|
59 |
+
return response.text or "์๋ต์ด ์์ต๋๋ค."
|
60 |
except Exception as e:
|
61 |
err = str(e).lower()
|
62 |
if "quota" in err:
|
|
|
65 |
return "์ธ์ฆ ์ค๋ฅ: API ํค๋ฅผ ํ์ธํ์ธ์."
|
66 |
return f"์ค๋ฅ ๋ฐ์: {e}"
|
67 |
|
68 |
# === Gradio chat interface ===
# NOTE: the order of additional_inputs must match the extra parameters of
# respond(): system_message, max_tokens, temperature, top_p, model_name.
_model_picker = gr.Dropdown(
    choices=[
        "gemini-2.0-flash", "gemini-2.0-flash-lite",
        "gemini-1.5-flash", "gemini-1.5-pro",
        "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it",
    ],
    value="gemini-2.0-flash",
    label="모델 선택",
)

demo = gr.ChatInterface(
    fn=respond,
    title="💬 경희대학교 미디어학과 AI 상담사",
    description="경희대학교 미디어학과에 대해 물어보세요!",
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=2),
        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        _model_picker,
    ],
    theme="soft",
    analytics_enabled=False,
)
|
91 |
|
|
|
|
|
|
|
92 |
if __name__ == "__main__":
    # Launch the chat UI; share=False keeps it off Gradio's public tunnel
    # (the hosting platform provides the external URL).
    demo.launch(share=False)