File size: 3,392 Bytes
2425d9c
 
 
 
c70195a
2425d9c
f04f8de
c70195a
 
f04f8de
5e96ac1
f04f8de
2425d9c
c70195a
 
 
 
 
 
 
 
2425d9c
 
c70195a
f04f8de
 
 
c70195a
 
f04f8de
c70195a
 
 
f04f8de
c70195a
 
f04f8de
 
c70195a
 
 
f04f8de
c70195a
 
 
f04f8de
c70195a
f04f8de
2425d9c
f04f8de
2425d9c
c70195a
f04f8de
 
 
 
 
2425d9c
f04f8de
2425d9c
c70195a
 
 
 
 
 
 
f04f8de
2425d9c
c70195a
2425d9c
c70195a
2425d9c
f04f8de
 
 
 
 
 
 
 
 
 
 
 
 
2425d9c
 
 
 
 
 
f04f8de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import chromadb
from sentence_transformers import SentenceTransformer
from google import genai
import gradio as gr

# === Environment configuration ===
# ChromaDB store location: $CHROMA_DB_DIR when it is set, otherwise
# ./chromadb_KH_media under the current working directory. setdefault also
# writes the resolved path back into the environment (presumably for other
# code that reads CHROMA_DB_DIR).
DB_DIR = os.environ.setdefault(
    "CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media")
)
# Gemini API key. The fallback is a literal placeholder, not a usable key;
# set GOOGLE_API_KEY in the environment.
API_KEY = os.environ.get("GOOGLE_API_KEY", "YOUR_API_KEY_HERE")

# === Simple RAG system ===
class SimpleRAGSystem:
    """Minimal retriever over an existing, persisted ChromaDB collection.

    Queries are embedded with a Korean Sentence-BERT model and matched
    against the collection; only the stored document strings are returned.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Load the encoder and open the named collection.

        Args:
            db_path: Directory of the persistent ChromaDB store. Falls back to
                the module-level ``DB_DIR`` when falsy.
            collection_name: Name of an already-populated collection.

        Raises:
            Whatever ``chromadb`` raises from ``get_collection`` when the
            collection does not exist (the exact type varies by version).
        """
        path = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        # get_collection (not get_or_create): fail fast if the DB was never built.
        self.collection = self.client.get_collection(name=collection_name)
        # Snapshot taken once at construction; callers read this flag directly.
        self.available = self.collection.count() > 0

    def search(self, query, top_k=8):
        """Return up to ``top_k`` stored documents most similar to ``query``.

        Args:
            query: Free-text query string.
            top_k: Maximum number of documents to return.

        Returns:
            A list of document strings, best match first. It is empty when
            retrieval is unavailable, the collection is empty, or ``top_k``
            is not positive.
        """
        if not self.available or top_k <= 0:
            return []
        # Never request more neighbours than the collection holds: some
        # chromadb versions raise on that, others only warn and clamp.
        n_results = min(top_k, self.collection.count())
        if n_results <= 0:
            return []
        embedding = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[embedding],
            n_results=n_results,
            include=["documents"],
        )
        # One query was sent, so the hits are in the first inner list.
        # Defensive against a missing/None "documents" entry.
        documents = result.get("documents") or [[]]
        return documents[0] or []

rag = SimpleRAGSystem()

# === Google GenAI ํด๋ผ์ด์–ธํŠธ ===
client = genai.Client(api_key=API_KEY)

# === ์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€ ===
SYSTEM_MSG = """
๋‹น์‹ ์€ ๊ฒฝํฌ๋Œ€ํ•™๊ต ๋ฏธ๋””์–ดํ•™๊ณผ ์ „๋ฌธ ์ƒ๋‹ด AI์ž…๋‹ˆ๋‹ค.
"""

# === Response function ===
def _format_history(history):
    """Render earlier chat turns as user/AI transcript lines for the prompt.

    Accepts both Gradio history formats: ``[user, bot]`` pairs ("tuples")
    and ``{"role": ..., "content": ...}`` dicts ("messages"). Dict entries
    whose role is neither "user" nor "assistant" are skipped.
    """
    # Speaker labels used in the prompt ("사용자" means "user").
    speaker = {"user": "사용자", "assistant": "AI"}
    lines = []
    for turn in history or ():
        if isinstance(turn, dict):
            label = speaker.get(turn.get("role"))
            if label is not None:
                lines.append(f"{label}: {turn.get('content', '')}\n")
        else:
            user_text, bot_text = turn
            # Render a missing (None) bot reply as empty rather than "None".
            bot_text = "" if bot_text is None else bot_text
            lines.append(f"사용자: {user_text}\nAI: {bot_text}\n")
    return "".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Answer one chat turn with a Gemini model, grounded on retrieved documents.

    Args:
        message: The user's new message; also used as the retrieval query.
        history: Prior turns in Gradio "tuples" or "messages" format.
        system_message: System prompt; retrieved documents are appended to it.
        max_tokens: Output token cap (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        model_name: Model ID passed to ``client.models.generate_content``.

    Returns:
        The model's reply text. Errors raised inside the generation call are
        returned as Korean user-facing messages instead of being raised.
    """
    docs = rag.search(message) if rag.available else []
    # "참고문서N" = "reference document N".
    ctx = "\n".join(f"참고문서{i}: {d}" for i, d in enumerate(docs, start=1))
    sys_msg = (system_message or "") + ("\n# 참고문서:\n" + ctx if ctx else "")
    prompt = f"{sys_msg}\n{_format_history(history)}사용자: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                # Sliders may deliver floats; the API expects an integer.
                "max_output_tokens": int(max_tokens),
                "temperature": temperature,
                "top_p": top_p,
            },
        )
        # "No response." when the model returned no text.
        return response.text or "응답이 없습니다."
    except Exception as e:  # UI boundary: show any API failure in the chat.
        err = str(e).lower()
        if "quota" in err or "resource_exhausted" in err:
            # "API quota exceeded. Please try again later."
            return "API 할당량을 초과했습니다. 나중에 시도해주세요."
        # Gemini reports bad keys as e.g. "API key not valid" / API_KEY_INVALID,
        # which never contain the word "authentication".
        auth_markers = ("authentication", "api key", "api_key", "permission_denied", "unauthenticated")
        if any(marker in err for marker in auth_markers):
            # "Authentication error: check your API key."
            return "인증 오류: API 키를 확인하세요."
        # "An error occurred: ..."
        return f"오류 발생: {e}"

# === Gradio interface ===
# Title: "Kyung Hee University Department of Media AI counsellor";
# description: "Ask anything about the KHU Department of Media!".
demo = gr.ChatInterface(
    fn=respond,
    title="🎬 경희대학교 미디어학과 AI 상담사",
    description="경희대학교 미디어학과에 대해 물어보세요!",
    # Order must match respond(): system_message, max_tokens, temperature,
    # top_p, model_name.
    additional_inputs=[
        # "System message"
        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=2),
        # "Max tokens"
        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            # NOTE(review): confirm these model IDs are still served by the API.
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it",
            ],
            value="gemini-2.0-flash",
            # "Model selection"
            label="모델 선택",
        ),
    ],
    theme="soft",
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(share=False)