Copain22 committed on
Commit
c485f88
·
verified ·
1 Parent(s): 4019a71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -126
app.py CHANGED
@@ -1,96 +1,58 @@
1
- # 0. Install custom transformers and imports
2
  import os
3
- os.system("pip install git+https://github.com/shumingma/transformers.git")
4
- os.system("pip install python-docx")
5
-
6
- import threading
7
  import torch
8
- import torch._dynamo
9
- torch._dynamo.config.suppress_errors = True
10
-
11
- from transformers import (
12
- AutoModelForCausalLM,
13
- AutoTokenizer,
14
- TextIteratorStreamer,
15
- )
16
  import gradio as gr
17
- import spaces
18
- from docx import Document
19
 
20
- # 1. System prompt
21
- SYSTEM_PROMPT = """
22
- You are a friendly café assistant for Café Eleven. Your job is to:
23
- 1. Greet the customer warmly.
24
- 2. Help them order food and drinks from our menu.
25
- 3. Ask the customer for their desired pickup time.
26
- 4. Confirm the pickup time before ending the conversation.
27
- 5. Answer questions about ingredients, preparation, etc.
28
- 6. Handle special requests (allergies, modifications) politely.
29
- 7. Provide calorie information if asked.
30
- Always be polite, helpful, and ensure the customer feels welcomed and cared for!
31
- """
32
 
33
- MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
 
 
 
 
34
 
35
- # 2. Load tokenizer and model
36
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- MODEL_ID,
39
- torch_dtype=torch.bfloat16,
40
- device_map="auto"
41
- )
42
 
43
- print(f"Model loaded on device: {model.device}")
 
44
 
45
- # 3. Load Menu Text from Word document
46
- def load_menu_text(docx_path):
47
  doc = Document(docx_path)
48
- full_text = []
49
- for para in doc.paragraphs:
50
- if para.text.strip():
51
- full_text.append(para.text.strip())
52
- return "\n".join(full_text)
53
 
54
- MENU_TEXT = load_menu_text("menu.docx")
55
- print(f"Loaded menu text from Word document.")
 
 
56
 
57
- # 4. Simple retrieval function (search inside MENU_TEXT)
58
- def retrieve_context(question, top_k=3):
59
- question = question.lower()
60
- sentences = MENU_TEXT.split("\n")
61
- matches = [s for s in sentences if any(word in s.lower() for word in question.split())]
62
- if not matches:
63
- return "Sorry, I couldn't find relevant menu information."
64
- return "\n\n".join(matches[:top_k])
65
 
66
- # 5. Chat respond function
67
- @spaces.GPU
68
- def respond(
69
- message: str,
70
- history: list[tuple[str, str]],
71
- system_message: str,
72
- max_tokens: int,
73
- temperature: float,
74
- top_p: float,
75
- ):
76
- context = retrieve_context(message)
77
-
78
  messages = [{"role": "system", "content": system_message}]
79
- for user_msg, bot_msg in history:
80
- if user_msg:
81
- messages.append({"role": "user", "content": user_msg})
82
- if bot_msg:
83
- messages.append({"role": "assistant", "content": bot_msg})
84
- messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})
85
 
86
- prompt = tokenizer.apply_chat_template(
87
- messages, tokenize=False, add_generation_prompt=True
88
- )
89
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
90
 
91
- streamer = TextIteratorStreamer(
92
- tokenizer, skip_prompt=True, skip_special_tokens=True
93
- )
94
  generate_kwargs = dict(
95
  **inputs,
96
  streamer=streamer,
@@ -99,64 +61,31 @@ def respond(
99
  top_p=top_p,
100
  do_sample=True,
101
  )
 
102
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
103
  thread.start()
104
 
105
- response = ""
106
- for new_text in streamer:
107
- response += new_text
108
- yield response
109
 
110
- # 6. Gradio ChatInterface
111
  demo = gr.ChatInterface(
112
- fn=respond,
113
- title="Café Eleven Assistant",
114
- description="Friendly café assistant based on real menu loaded from Word document!",
115
  examples=[
116
- [
117
- "What kinds of burgers do you have?",
118
- SYSTEM_PROMPT.strip(),
119
- 512,
120
- 0.7,
121
- 0.95,
122
- ],
123
- [
124
- "Do you have gluten-free pastries?",
125
- SYSTEM_PROMPT.strip(),
126
- 512,
127
- 0.7,
128
- 0.95,
129
- ],
130
  ],
131
  additional_inputs=[
132
- gr.Textbox(
133
- value=SYSTEM_PROMPT.strip(),
134
- label="System message"
135
- ),
136
- gr.Slider(
137
- minimum=1,
138
- maximum=2048,
139
- value=512,
140
- step=1,
141
- label="Max new tokens"
142
- ),
143
- gr.Slider(
144
- minimum=0.1,
145
- maximum=4.0,
146
- value=0.7,
147
- step=0.1,
148
- label="Temperature"
149
- ),
150
- gr.Slider(
151
- minimum=0.1,
152
- maximum=1.0,
153
- value=0.95,
154
- step=0.05,
155
- label="Top-p (nucleus sampling)"
156
- ),
157
- ],
158
  )
159
 
160
- # 7. Launch
161
  if __name__ == "__main__":
162
  demo.launch(share=True)
 
 
1
  import os
2
+ import faiss
 
 
 
3
  import torch
4
+ import threading
 
 
 
 
 
 
 
5
  import gradio as gr
 
 
6
 
7
+ from docx import Document
8
+ from sentence_transformers import SentenceTransformer
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
 
 
 
 
 
 
 
10
 
11
+ # === Configuration ===
12
+ MODEL_ID = "microsoft/phi-2"
13
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+ SYSTEM_PROMPT = """You are a friendly café assistant. Help customers place orders, check ingredients, and provide warm service."""
16
 
17
+ # === Load LLM ===
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)
 
 
 
 
20
 
21
+ # === Load Embedder ===
22
+ embedder = SentenceTransformer(EMBED_MODEL)
23
 
24
# === Load Menu Text ===
def load_menu(docx_path):
    """Read a Word document and return its non-empty paragraph texts.

    Each stripped, non-blank paragraph becomes one retrievable menu chunk.
    """
    document = Document(docx_path)
    chunks = []
    for paragraph in document.paragraphs:
        text = paragraph.text.strip()
        if text:
            chunks.append(text)
    return chunks
28
+
29
# Load menu paragraphs and embed them once at startup.
menu_chunks = load_menu("menu.docx")
# Cast to contiguous float32 explicitly: FAISS requires float32, and the
# query path in retrieve_context_faiss already casts — keep both consistent
# rather than relying on the embedder's default dtype.
chunk_embeddings = (
    embedder.encode(menu_chunks, convert_to_tensor=True)
    .cpu()
    .numpy()
    .astype("float32")
)

# === Build FAISS Index ===
# Exact (brute-force) L2 index over the menu-chunk embeddings.
dimension = chunk_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(chunk_embeddings)
36
 
37
# === Retrieval ===
def retrieve_context_faiss(query, top_k=3):
    """Return up to ``top_k`` menu chunks most similar to ``query``.

    Embeds the query, searches the module-level FAISS L2 index, and joins
    the matching chunks with newlines.
    """
    query_vec = embedder.encode([query]).astype("float32")
    _, indices = index.search(query_vec, top_k)
    # FAISS pads results with -1 when fewer than top_k vectors are stored;
    # the original indexed menu_chunks[-1] in that case. Skip the padding.
    hits = [menu_chunks[i] for i in indices[0] if i != -1]
    return "\n".join(hits)
 
 
 
42
 
43
# === Generate LLM Response ===
def generate_response(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a menu-grounded chat reply.

    Retrieves relevant menu chunks via FAISS, rebuilds the conversation in
    chat-template format, and runs ``model.generate`` on a worker thread so
    tokens can be streamed back to gr.ChatInterface.

    Yields the accumulated response text after each new token.
    """
    context = retrieve_context_faiss(message)

    # Rebuild the transcript; skip empty turns (gr.ChatInterface may hand
    # back history pairs with a missing side, e.g. while streaming).
    messages = [{"role": "system", "content": system_message}]
    for user, bot in history:
        if user:
            messages.append({"role": "user", "content": user})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": f"{message}\n\nRelevant info:\n{context}"})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # TextIteratorStreamer yields decoded text pieces as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation off the main thread; iterating the streamer below blocks
    # until tokens arrive and finishes when generation completes.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output = ""
    for token in streamer:
        output += token
        yield output
72
 
73
# === UI ===
# Chat UI: the extra inputs map positionally onto generate_response's
# (system_message, max_tokens, temperature, top_p) parameters, and each
# example row supplies values for message plus those four inputs.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Café Eleven RAG Assistant",
    description="LLM + FAISS powered café chatbot with real-time Word document lookup.",
    examples=[
        ["Do you have vegetarian options?", SYSTEM_PROMPT, 512, 0.7, 0.9],
        ["What's in the turkey sandwich?", SYSTEM_PROMPT, 512, 0.7, 0.9],
    ],
    additional_inputs=[
        gr.Textbox(value=SYSTEM_PROMPT, label="System Prompt"),
        gr.Slider(minimum=1, maximum=1024, value=512, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p"),
    ]
)

if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link.
    demo.launch(share=True)