Copain22 commited on
Commit
4c5c529
·
verified ·
1 Parent(s): 51f7fc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -59
app.py CHANGED
@@ -1,63 +1,95 @@
1
  import os
2
- import faiss
3
- import torch
 
4
  import threading
5
- import gradio as gr
 
 
6
 
 
 
 
 
 
 
 
7
  from docx import Document
8
- from sentence_transformers import SentenceTransformer
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
- from spaces import GPU
11
 
12
- # === Configuration ===
13
- MODEL_ID = "microsoft/phi-2"
14
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
15
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
- SYSTEM_PROMPT = """You are a friendly café assistant. Help customers place orders, check ingredients, and provide warm service."""
17
 
18
- # === Load LLM ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)
 
 
 
 
21
 
22
- # === Load Embedder ===
23
- embedder = SentenceTransformer(EMBED_MODEL)
24
 
25
- # === Load Menu Text ===
26
- def load_menu(docx_path):
27
- doc = Document(docx_path)
28
- return [p.text.strip() for p in doc.paragraphs if p.text.strip()]
29
 
30
- menu_chunks = load_menu("menu.docx")
31
- chunk_embeddings = embedder.encode(menu_chunks, convert_to_tensor=True).cpu().numpy()
 
 
 
 
 
32
 
33
- # === Build FAISS Index ===
34
- dimension = chunk_embeddings.shape[1]
35
- index = faiss.IndexFlatL2(dimension)
36
- index.add(chunk_embeddings)
37
 
38
- # === Retrieval ===
39
- def retrieve_context_faiss(query, top_k=3):
40
- query_vec = embedder.encode([query]).astype("float32")
41
- distances, indices = index.search(query_vec, top_k)
42
- return "\n".join([menu_chunks[i] for i in indices[0]])
43
 
 
 
 
 
 
 
 
44
 
45
- # === Generate LLM Response ===
46
- @spaces.GPU # Only if you're using ZeroGPU
47
- def generate_response(message, history, system_message, max_tokens, temperature, top_p):
48
- context = retrieve_context_faiss(message)
49
 
 
 
 
 
 
 
 
 
 
 
 
50
  messages = [{"role": "system", "content": system_message}]
51
  for user_msg, bot_msg in history:
52
- messages.append({"role": "user", "content": user_msg})
53
- messages.append({"role": "assistant", "content": bot_msg})
54
-
55
- messages.append({"role": "user", "content": f"{message}\n\nRelevant info:\n{context}"})
56
-
57
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
58
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
59
 
60
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
61
  generate_kwargs = dict(
62
  **inputs,
63
  streamer=streamer,
@@ -66,32 +98,64 @@ def generate_response(message, history, system_message, max_tokens, temperature,
66
  top_p=top_p,
67
  do_sample=True,
68
  )
69
-
70
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
71
  thread.start()
72
 
73
- output = ""
74
- for token in streamer:
75
- output += token
76
- yield output
77
-
78
- print("Inputs received:", message, history, system_message, max_tokens, temperature, top_p)
79
- # === UI ===
80
  demo = gr.ChatInterface(
81
- fn=generate_response,
82
- title="Café Eleven RAG Assistant",
83
- description="LLM + FAISS powered café chatbot with real-time Word document lookup.",
84
  examples=[
85
- ["Do you have vegetarian options?", SYSTEM_PROMPT, 512, 0.7, 0.9],
86
- ["What's in the turkey sandwich?", SYSTEM_PROMPT, 512, 0.7, 0.9],
 
 
 
 
 
 
 
 
 
 
 
 
87
  ],
88
  additional_inputs=[
89
- gr.Textbox(value=SYSTEM_PROMPT, label="System Prompt"),
90
- gr.Slider(1, 1024, 512, label="Max Tokens"),
91
- gr.Slider(0.1, 2.0, 0.7, label="Temperature"),
92
- gr.Slider(0.1, 1.0, 0.9, label="Top-p"),
93
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  )
95
 
 
96
  if __name__ == "__main__":
97
  demo.launch(share=True)
 
import os

# Install a patched transformers build (BitNet support) and python-docx at
# process start. NOTE(review): runtime `pip install` via os.system is fragile
# and slow on every boot — prefer pinning these in requirements.txt.
os.system("pip install git+https://github.com/shumingma/transformers.git")
os.system("pip install python-docx")

import threading
import torch
import torch._dynamo
# Fall back to eager execution instead of raising when torch.compile /
# dynamo tracing fails on this model.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces
from docx import Document
 
 
 
18
 
 
 
 
 
 
19
 
20
+ SYSTEM_PROMPT = """
21
+ You are a friendly café assistant for Café Eleven. Your job is to:
22
+ 1. Greet the customer warmly.
23
+ 2. Help them order food and drinks from our menu.
24
+ 3. Ask the customer for their desired pickup time.
25
+ 4. Confirm the pickup time before ending the conversation.
26
+ 5. Answer questions about ingredients, preparation, etc.
27
+ 6. Handle special requests (allergies, modifications) politely.
28
+ 7. Provide calorie information if asked.
29
+ Always be polite, helpful, and ensure the customer feels welcomed and cared for!
30
+ """
31
+
32
+ MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
33
+
34
+
35
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ MODEL_ID,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map="auto"
40
+ )
41
 
42
+ print(f"Model loaded on device: {model.device}")
 
43
 
 
 
 
 
44
 
45
def load_menu_text(docx_path):
    """Read a Word document and return its non-empty paragraphs.

    Each paragraph is stripped of surrounding whitespace; blank paragraphs
    are dropped. The surviving lines are joined with newlines into one
    plain-text string.
    """
    document = Document(docx_path)
    stripped = (paragraph.text.strip() for paragraph in document.paragraphs)
    return "\n".join(text for text in stripped if text)
52
 
53
# Load the café menu once at startup; retrieve_context() searches this text.
# Assumes menu.docx sits next to app.py — TODO confirm in the Space repo.
MENU_TEXT = load_menu_text("menu.docx")
# Fix: was an f-string with no placeholders — plain literal is equivalent.
print("Loaded menu text from Word document.")
 
 
55
 
 
 
 
 
 
56
 
57
def retrieve_context(question, top_k=3, menu_text=None):
    """Return up to *top_k* menu lines that share a word with *question*.

    A line matches when any punctuation-stripped, lower-cased word of the
    question appears as a substring of the lower-cased line. Matching lines
    are returned in menu order, joined with blank lines; a fixed apology
    string is returned when nothing matches.

    Args:
        question: Free-form customer query.
        top_k: Maximum number of matching lines to return.
        menu_text: Text to search (one item per line). Defaults to the
            module-level MENU_TEXT loaded from menu.docx.

    Fixes vs. previous version: query tokens like "fries?" kept their
    punctuation and could never substring-match "Fries"; tokens are now
    stripped with string.punctuation. The menu source is also injectable
    for testing, defaulting to the old global behavior.
    """
    import string

    source = MENU_TEXT if menu_text is None else menu_text
    # Normalize query tokens: lower-case, strip surrounding punctuation,
    # drop tokens that were pure punctuation.
    words = {w.strip(string.punctuation) for w in question.lower().split()}
    words.discard("")
    lines = source.split("\n")
    matches = [line for line in lines
               if any(word in line.lower() for word in words)]
    if not matches:
        return "Sorry, I couldn't find relevant menu information."
    return "\n\n".join(matches[:top_k])
64
 
 
 
 
 
65
 
66
+ @spaces.GPU
67
+ def respond(
68
+ message: str,
69
+ history: list[tuple[str, str]],
70
+ system_message: str,
71
+ max_tokens: int,
72
+ temperature: float,
73
+ top_p: float,
74
+ ):
75
+ context = retrieve_context(message)
76
+
77
  messages = [{"role": "system", "content": system_message}]
78
  for user_msg, bot_msg in history:
79
+ if user_msg:
80
+ messages.append({"role": "user", "content": user_msg})
81
+ if bot_msg:
82
+ messages.append({"role": "assistant", "content": bot_msg})
83
+ messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})
84
+
85
+ prompt = tokenizer.apply_chat_template(
86
+ messages, tokenize=False, add_generation_prompt=True
87
+ )
88
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
89
 
90
+ streamer = TextIteratorStreamer(
91
+ tokenizer, skip_prompt=True, skip_special_tokens=True
92
+ )
93
  generate_kwargs = dict(
94
  **inputs,
95
  streamer=streamer,
 
98
  top_p=top_p,
99
  do_sample=True,
100
  )
 
101
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
102
  thread.start()
103
 
104
+ response = ""
105
+ for new_text in streamer:
106
+ response += new_text
107
+ yield response
108
+
109
+
 
110
# --- Gradio UI -------------------------------------------------------------
# ChatInterface wires `respond` to a chat widget. Each additional input is
# passed to `respond` as an extra positional argument after (message,
# history), so the order here must match respond's signature; example rows
# carry values for those same extra inputs.
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant based on real menu loaded from Word document!",
    examples=[
        [
            "What kinds of burgers do you have?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
        [
            "Do you have gluten-free pastries?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_PROMPT.strip(),
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
)


if __name__ == "__main__":
    # share=True requests a public tunnel link. NOTE(review): on Hugging Face
    # Spaces this flag is unnecessary — confirm the intended deployment target.
    demo.launch(share=True)