WillHeld committed on
Commit 51b5709 · verified · 1 Parent(s): bc12570

Update app.py

Files changed (1)
app.py +31 -20
app.py CHANGED
@@ -93,16 +93,14 @@ def push_feedback_to_hub(hf_token=None):
         print(f"Error pushing feedback data to Hub: {e}")
         return False
 
-# Create a State to store chat history
-chat_history_state = []
-
+# Modified predict function to update conversation state
 @spaces.GPU(duration=120)
-def predict(message, history, temperature, top_p):
-    global chat_history_state
-
-    # Update our chat history state
+def predict(message, history, state, temperature, top_p):
+    # Update history with user message
     history.append({"role": "user", "content": message})
-    chat_history_state = history.copy()
+
+    # Update the conversation state
+    state = history.copy()
 
     input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
@@ -129,18 +127,18 @@ def predict(message, history, temperature, top_p):
     partial_text = ""
     for new_text in streamer:
         partial_text += new_text
-        yield partial_text
+        yield partial_text, state
 
-    # After generation is complete, update chat history state with the assistant response
-    chat_history_state.append({"role": "assistant", "content": partial_text})
+    # After full generation, update state with assistant's response
+    history.append({"role": "assistant", "content": partial_text})
+    state = history.copy()
+    return partial_text, state
 
 # Function to handle the research feedback submission
-def submit_research_feedback(satisfaction, feedback_text):
+def submit_research_feedback(conversation_state, satisfaction, feedback_text):
     """Save user feedback both locally and to HuggingFace Hub"""
-    global chat_history_state
-
     # Save locally first
-    feedback_id = save_feedback_locally(chat_history_state, satisfaction, feedback_text)
+    feedback_id = save_feedback_locally(conversation_state, satisfaction, feedback_text)
 
     # Get token from environment variable
     env_token = os.environ.get("HF_TOKEN")
@@ -155,16 +153,29 @@ def submit_research_feedback(satisfaction, feedback_text):
 
     return status_msg
 
-# Create the Gradio interface
+# Create the Gradio blocks interface
 with gr.Blocks() as demo:
+    # State to track conversation history
+    conversation_state = gr.State([])
+
     with gr.Row():
         with gr.Column(scale=3):
+            # Custom chat function wrapper to update state
+            def chat_with_state(message, history, state, temperature, top_p):
+                for partial_response, updated_state in predict(message, history, state, temperature, top_p):
+                    # Update our state with each yield
+                    state = updated_state
+                    yield partial_response, state
+
+            # Create ChatInterface
             chatbot = gr.ChatInterface(
-                predict,
+                chat_with_state,
                 additional_inputs=[
+                    conversation_state,
                     gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                     gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                 ],
+                additional_outputs=[conversation_state],
                 type="messages"
            )
 
@@ -199,10 +210,10 @@ with gr.Blocks() as demo:
         feedback_modal
     )
 
-    # Connect the submit button to the submit_research_feedback function with the current chat history
+    # Connect the submit button to the submit_research_feedback function with the current conversation state
     submit_button.click(
-        lambda satisfaction, feedback_text: submit_research_feedback(satisfaction, feedback_text),
-        inputs=[satisfaction, feedback_text],
+        submit_research_feedback,
+        inputs=[conversation_state, satisfaction, feedback_text],
         outputs=response_text
     )
 
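For context, the pattern this commit lands on is Gradio's supported way to carry per-session data through a chat UI: a gr.State value is passed into the chat function via additional_inputs and written back via additional_outputs, replacing the module-level global that would have been shared across every user of the Space. Below is a minimal, self-contained sketch of that pattern, assuming a recent Gradio release where gr.ChatInterface supports additional_outputs; the echo-style generator is a stand-in for the real tokenizer/model streaming loop in app.py.

import gradio as gr

def predict(message, history, state, temperature, top_p):
    # Record the user turn; history is the messages-format list
    # that gr.ChatInterface passes in.
    history = history + [{"role": "user", "content": message}]

    # Placeholder generation: the real app streams tokens from a model.
    partial_text = ""
    for token in f"(T={temperature}, P={top_p}) you said: {message}".split():
        partial_text += token + " "
        # Each yield is (visible reply, updated state), matching
        # additional_outputs=[conversation_state] below.
        yield partial_text, state

    # After the final token, persist the full transcript in the state
    # so other handlers (e.g. a feedback button) can read it.
    state = history + [{"role": "assistant", "content": partial_text}]
    yield partial_text, state

with gr.Blocks() as demo:
    # Per-session state: isolated between users, unlike a global variable.
    conversation_state = gr.State([])

    gr.ChatInterface(
        predict,
        additional_inputs=[
            conversation_state,
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
        ],
        additional_outputs=[conversation_state],
        type="messages",
    )

if __name__ == "__main__":
    demo.launch()

The key design choice mirrored from the diff: the conversation state is an input and an output of the chat function rather than a global, so a later submit_button.click(submit_research_feedback, inputs=[conversation_state, ...]) reads the same transcript the chat handler wrote.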