WillHeld committed on
Commit d951e6a · verified · 1 Parent(s): f66cdbc

Update app.py

Files changed (1)
  1. app.py +105 -70
app.py CHANGED
@@ -150,68 +150,17 @@ def push_feedback_to_hub(hf_token=None):
         print(f"Error pushing feedback data to Hub: {e}")
         return False
 
-# Modified predict function to update conversation state
-@spaces.GPU(duration=120)
-def predict(message, history, state, temperature, top_p):
-    # Create a deep copy of history to ensure we don't modify the original
-    current_history = history.copy()
-
-    # Update history with user message
-    current_history.append({"role": "user", "content": message})
-
-    # Update the conversation state with user message
-    if not state:
-        state = []
-    state = current_history.copy()
-
-    input_text = tokenizer.apply_chat_template(current_history, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
-    # Create a streamer
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-    # Set up generation parameters
-    generation_kwargs = {
-        "input_ids": inputs,
-        "max_new_tokens": 1024,
-        "temperature": float(temperature),
-        "top_p": float(top_p),
-        "do_sample": True,
-        "streamer": streamer,
-        "eos_token_id": 128009,
-    }
-
-    # Run generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    # Yield from the streamer as tokens are generated
-    partial_text = ""
-    for new_text in streamer:
-        partial_text += new_text
-
-        # Create a temporary state with partial response
-        temp_history = current_history.copy()
-        temp_history.append({"role": "assistant", "content": partial_text})
-        temp_state = temp_history.copy()
-
-        yield partial_text, temp_state
-
-    # After full generation, update state with assistant's final response
-    current_history.append({"role": "assistant", "content": partial_text})
-    state = current_history.copy()
-
-    # Print debug info
-    print(f"Updated state with {len(state)} messages")
-    print(f"Last message: {state[-1]['role']}: {state[-1]['content'][:30]}...")
-
-    return partial_text, state
-
 # Function to handle the research feedback submission
-def submit_research_feedback(conversation_state, satisfaction, feedback_text):
+def submit_research_feedback(conv_history, satisfaction, feedback_text):
     """Save user feedback both locally and to HuggingFace Hub"""
+    # Print debug information
+    print(f"Saving feedback with conversation history containing {len(conv_history)} messages")
+    if conv_history and len(conv_history) > 0:
+        print(f"First message: {conv_history[0]['role']}: {conv_history[0]['content'][:30]}...")
+        print(f"Last message: {conv_history[-1]['role']}: {conv_history[-1]['content'][:30]}...")
+
     # Save locally first
-    feedback_id = save_feedback_locally(conversation_state, satisfaction, feedback_text)
+    feedback_id = save_feedback_locally(conv_history, satisfaction, feedback_text)
 
     # Get token from environment variable
     env_token = os.environ.get("HF_TOKEN")
@@ -226,25 +175,111 @@ def submit_research_feedback(conversation_state, satisfaction, feedback_text):
 
     return status_msg
 
+# Initial state - set up at app start
+def initialize_state():
+    """Initialize the conversation state - this could load previous sessions or start fresh"""
+    return [] # Start with empty conversation history
+
 # Create the Gradio blocks interface
 with gr.Blocks() as demo:
-    # State to track conversation history
-    conversation_state = gr.State([])
+    # Create state to store full conversation history with proper initialization
+    conv_state = gr.State(initialize_state)
 
     with gr.Row():
         with gr.Column(scale=3):
-            # Custom chat function wrapper to update state
-            def chat_with_state(message, history, state, temperature, top_p):
-                for partial_response, updated_state in predict(message, history, state, temperature, top_p):
-                    # Update our state with each yield
-                    conversation_state.value = updated_state
-                    yield partial_response
+            # Create a custom predict function that updates our state
+            def enhanced_predict(message, history, temperature, top_p, state):
+                # Initialize state if needed
+                if state is None:
+                    state = []
+                    print("Initializing empty state")
+
+                # Copy history to state if state is empty but history exists
+                if len(state) == 0 and len(history) > 0:
+                    state = history.copy()
+                    print(f"Copied {len(history)} messages from history to state")
+
+                # Add user message to state
+                state.append({"role": "user", "content": message})
+
+                # Process with the model (this doesn't modify the original history)
+                input_text = tokenizer.apply_chat_template(state, tokenize=False, add_generation_prompt=True)
+                inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+                # Create a streamer
+                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+                # Set up generation parameters
+                generation_kwargs = {
+                    "input_ids": inputs,
+                    "max_new_tokens": 1024,
+                    "temperature": float(temperature),
+                    "top_p": float(top_p),
+                    "do_sample": True,
+                    "streamer": streamer,
+                    "eos_token_id": 128009,
+                }
+
+                # Run generation in a separate thread
+                thread = Thread(target=model.generate, kwargs=generation_kwargs)
+                thread.start()
+
+                # Yield from the streamer as tokens are generated
+                response = ""
+                for new_text in streamer:
+                    response += new_text
+                    # For each partial response, yield the text only
+                    # We'll update the state after generation is complete
+                    yield response
+
+                # After generation completes, update our state with the final response
+                state.append({"role": "assistant", "content": response})
+
+                # Return the updated state
+                return state
+
+            # Create a wrapper that connects to ChatInterface but also updates our state
+            def chat_with_state(message, history, temperature, top_p):
+                # This function is what interfaces with the ChatInterface
+                nonlocal conv_state
+
+                # Access the current state
+                current_state = conv_state.value if conv_state.value else []
+
+                # Call the main function that generates responses and updates state
+                # This is a generator function, so we need to iterate through its outputs
+                response_gen = enhanced_predict(message, history, temperature, top_p, current_state)
+
+                # For each response, yield it and also update our state at the end
+                last_response = None
+                for response in response_gen:
+                    last_response = response
+                    yield response
+
+                # After generation is complete, update our state
+                if last_response is not None:
+                    # Create a full copy of the history plus the new exchange
+                    updated_state = []
+                    # Add all previous history
+                    for msg in history:
+                        updated_state.append(msg.copy())
+                    # Add new exchange
+                    updated_state.append({"role": "user", "content": message})
+                    updated_state.append({"role": "assistant", "content": last_response})
+
+                    # Store in our state
+                    conv_state.value = updated_state
+
+                    # Debug
+                    print(f"Updated conversation state with {len(updated_state)} messages")
+                    if updated_state:
+                        last_msg = updated_state[-1]
+                        print(f"Last message: {last_msg['role']}: {last_msg['content'][:30]}...")
 
             # Create ChatInterface
             chatbot = gr.ChatInterface(
                 chat_with_state,
                 additional_inputs=[
-                    conversation_state,
                     gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                     gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                 ],
@@ -282,10 +317,10 @@ with gr.Blocks() as demo:
         feedback_modal
     )
 
-    # Connect the submit button to the submit_research_feedback function with the current conversation state
+    # Connect the submit button to the submit_research_feedback function
     submit_button.click(
         submit_research_feedback,
-        inputs=[conversation_state, satisfaction, feedback_text],
+        inputs=[conv_state, satisfaction, feedback_text],
         outputs=response_text
     )
 
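Both the removed `predict` and the new `enhanced_predict` rely on the same transformers streaming idiom: `model.generate` blocks until generation finishes, so it is launched on a worker thread while a `TextIteratorStreamer` is drained on the calling thread, one decoded chunk at a time. Below is a minimal, self-contained sketch of that idiom; `gpt2` and `stream_reply` are illustrative stand-ins, not the Space's actual model or function names.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for the app's model
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=64, do_sample=True, streamer=streamer)
    # generate() blocks until done, so it runs on a worker thread
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:  # the streamer yields decoded text as tokens arrive
        partial += new_text
        yield partial  # each yield hands the UI the full response so far

for text in stream_reply("Hello"):
    print(text)
```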
 
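One caveat worth flagging in the new `chat_with_state`: it persists the transcript by assigning to `conv_state.value` inside the callback. Gradio's documented way to persist a per-session value is to route the `gr.State` component through an event's `inputs` and `outputs` rather than mutating `.value` at runtime. A minimal sketch of that pattern, with an echo function standing in for the model:

```python
import gradio as gr

def respond(message, chat_history, state):
    # The State component rides along as an input and is returned as an
    # output; that round trip is what persists it for this session.
    state = (state or []) + [{"role": "user", "content": message}]
    reply = f"echo: {message}"  # stand-in for real generation
    state.append({"role": "assistant", "content": reply})
    return "", chat_history + [(message, reply)], state

with gr.Blocks() as demo:
    conv_state = gr.State([])  # per-session transcript
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(respond, [msg, chatbot, conv_state], [msg, chatbot, conv_state])

demo.launch()
```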
 
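For reference, `gr.ChatInterface` appends the values from `additional_inputs`, in order, after the `(message, history)` arguments of the chat function, which is why dropping `conversation_state` from that list also shortened `chat_with_state`'s signature in this commit. A runnable sketch of the wiring, with a canned streamed reply in place of the model:

```python
import gradio as gr

def chat_fn(message, history, temperature, top_p):
    # history comes from ChatInterface; the extra args come from additional_inputs
    reply = f"(temperature={temperature}, top_p={top_p}) you said: {message}"
    for i in range(1, len(reply) + 1):
        yield reply[:i]  # yielding growing prefixes streams the reply into the chat

demo = gr.ChatInterface(
    chat_fn,
    additional_inputs=[
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```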