null and void committed on
Commit 0af8075 · verified · 1 Parent(s): b963ac6

Upload conversai_playground.py

Files changed (1)
  1. conversai_playground.py +328 -0
conversai_playground.py ADDED
@@ -0,0 +1,328 @@
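"""ConversAI Playground: a Gradio Space in which two Hugging Face causal
language models take turns on a shared task until one signals completion,
with user controls to pause, edit, rewind, and restart the exchange."""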
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
from huggingface_hub import whoami

# Hugging Face login helpers
def hello(profile: gr.OAuthProfile | None) -> str:
    if profile is None:
        return "I don't know you."
    return f"Hello {profile.name}"

def list_organizations(oauth_token: gr.OAuthToken | None) -> str:
    if oauth_token is None:
        return "Please log in to list organizations."
    org_names = [org["name"] for org in whoami(oauth_token.token)["orgs"]]
    return f"You belong to {', '.join(org_names)}."
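
# ConversationManager holds all mutable session state: the loaded-model
# cache, the running transcript, the pacing/pause flags, and the
# task-completion sentinel checked by the chat loop below.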
class ConversationManager:
    def __init__(self):
        self.models = {}
        self.conversation = []
        self.delay = 3
        self.is_paused = False
        self.current_model = None
        self.initial_prompt = ""
        self.task_complete = False  # New attribute for task completion

    def load_model(self, model_name):
        if model_name in self.models:
            return self.models[model_name]

        try:
            print(f"Attempting to load model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # 8-bit loading requires the bitsandbytes package and a CUDA GPU;
            # newer transformers versions prefer passing a BitsAndBytesConfig.
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
            self.models[model_name] = (model, tokenizer)
            print(f"Successfully loaded model: {model_name}")
            return self.models[model_name]
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")
            return None

    def generate_response(self, model_name, prompt):
        loaded = self.load_model(model_name)
        if loaded is None:
            raise RuntimeError(f"Model {model_name} could not be loaded.")
        model, tokenizer = loaded
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            # max_new_tokens caps only the generated continuation;
            # max_length would also count the (up to 1024-token) prompt.
            outputs = model.generate(**inputs, max_new_tokens=200, num_return_sequences=1, do_sample=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
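
    # Task completion is signaled in-band: the chat prompt asks each model to
    # end with "Task complete?" and add_to_conversation watches for that marker.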
    def add_to_conversation(self, model_name, response):
        self.conversation.append((model_name, response))
        if "task complete?" in response.lower():  # Check for task completion marker
            self.task_complete = True

    def get_conversation_history(self):
        return "\n".join([f"{model}: {msg}" for model, msg in self.conversation])

    def clear_conversation(self):
        self.conversation = []
        self.initial_prompt = ""
        self.models = {}
        self.current_model = None
        self.task_complete = False  # Reset task completion status

    def rewind_conversation(self, steps):
        if steps > 0:  # conversation[:-0] would wipe the whole history
            self.conversation = self.conversation[:-steps]
        self.task_complete = False  # Reset task completion status after rewinding

    def rewind_and_insert(self, steps, inserted_response):
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        if inserted_response.strip():
            last_model = self.conversation[-1][0] if self.conversation else "User"
            next_model = "Model 1" if last_model in ("Model 2", "User") else "Model 2"
            self.conversation.append((next_model, inserted_response))
            self.current_model = last_model
        self.task_complete = False  # Reset task completion status after rewinding and inserting


manager = ConversationManager()

def get_model(dropdown, custom):
    return custom if custom.strip() else dropdown
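
# chat is a generator: Gradio streams each yielded (history, status) pair into
# the Conversation and Current-response boxes, alternating between the two
# models until one emits the "Task complete?" marker or the user pauses.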
def chat(model1, model2, user_input, history, inserted_response=""):
    try:
        # Attempt to load models
        manager.load_model(model1)
        manager.load_model(model2)
    except Exception as e:
        yield f"Error loading models: {str(e)}", ""
        return

    if not manager.conversation:
        manager.clear_conversation()  # clear first: it also resets initial_prompt
        manager.initial_prompt = user_input
        manager.add_to_conversation("User", user_input)

    models = [model1, model2]
    current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1

    while not manager.task_complete:  # Continue until task is complete
        if manager.is_paused:
            yield history, "Conversation paused."
            return

        model = models[current_model_index]
        manager.current_model = model

        if inserted_response and current_model_index == 0:
            response = inserted_response
            inserted_response = ""
        else:
            prompt = manager.get_conversation_history() + "\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
            response = manager.generate_response(model, prompt)

        manager.add_to_conversation(model, response)
        history = manager.get_conversation_history()

        for i in range(int(manager.delay), 0, -1):
            yield history, f"{model} is writing... {i}"
            time.sleep(1)

        yield history, ""

        if manager.task_complete:
            yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
            return

        current_model_index = (current_model_index + 1) % 2

    yield history, "Conversation completed."

def user_satisfaction(satisfied, history):
    if satisfied.lower() == 'yes':
        return history, "Task completed successfully."
    else:
        manager.task_complete = False
        return history, "Continuing the conversation..."

def pause_conversation():
    manager.is_paused = True
    return "Conversation paused. Press Resume to continue."

def resume_conversation():
    manager.is_paused = False
    return "Conversation resumed."

def edit_response(edited_text):
    if manager.conversation:
        manager.conversation[-1] = (manager.current_model, edited_text)
        manager.task_complete = False  # Reset task completion status after editing
    return manager.get_conversation_history()

def restart_conversation(model1, model2, user_input):
    manager.clear_conversation()
    yield from chat(model1, model2, user_input, "")  # chat is a generator

def rewind_and_insert(steps, inserted_response, history):
    manager.rewind_and_insert(int(steps), inserted_response)
    return manager.get_conversation_history(), ""

# This list should be populated with the exact model names when available
open_source_models = [
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigcode/starcoder2-15b",
    "bigcode/starcoder2-3b",
    "tiiuae/falcon-7b",
    "tiiuae/falcon-40b",
    "EleutherAI/gpt-neox-20b",
    "google/flan-ul2",
    "stabilityai/stablelm-zephyr-3b",
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/phi-2",
    "google/gemma-7b-it",
]
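
# UI: two model pickers (dropdown or free-text override), the shared
# transcript, and controls for pausing, editing, and rewinding the exchange.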
with gr.Blocks() as demo:
    gr.LoginButton()
    m1 = gr.Markdown()
    m2 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)
    demo.load(list_organizations, inputs=None, outputs=m2)
    gr.Markdown("# ConversAI Playground")

    with gr.Row():
        with gr.Column(scale=1):
            model1_dropdown = gr.Dropdown(open_source_models, label="Model 1")
            model1_custom = gr.Textbox(label="Custom Model 1")
        with gr.Column(scale=1):
            model2_dropdown = gr.Dropdown(open_source_models, label="Model 2")
            model2_custom = gr.Textbox(label="Custom Model 2")

    user_input = gr.Textbox(label="Initial prompt", lines=2)
    chat_history = gr.Textbox(label="Conversation", lines=20)
    current_response = gr.Textbox(label="Current model response", lines=3)

    with gr.Row():
        pause_btn = gr.Button("Pause")
        edit_btn = gr.Button("Edit")
        rewind_btn = gr.Button("Rewind")
        resume_btn = gr.Button("Resume")
        restart_btn = gr.Button("Restart")
        clear_btn = gr.Button("Clear")

    with gr.Row():
        rewind_steps = gr.Slider(0, 10, 1, label="Steps to rewind")
        inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)

    delay_slider = gr.Slider(0, 10, 3, label="Response Delay (seconds)")

    user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)

    gr.Markdown("""
    ## Button Descriptions
    - **Pause**: Temporarily stops the conversation. The current model will finish its response.
    - **Edit**: Allows you to modify the last response in the conversation.
    - **Rewind**: Removes the specified number of last responses from the conversation.
    - **Resume**: Continues the conversation from where it was paused.
    - **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.
    - **Clear**: Resets everything, including loaded models, conversation history, and initial prompt.
    """)

    def on_chat_update(history, response):
        if "Models believe the task is complete" in response:
            return gr.update(visible=True), gr.update(visible=False)
        return gr.update(visible=False), gr.update(visible=True)
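
    # Gradio event inputs must be existing components, so the dropdown/custom
    # pairs are read here and resolved with get_model inside thin wrappers.
    # A minimal sketch: chat_wrapper, restart_wrapper, and model_inputs are
    # names introduced for this fix, not part of the original design.
    def chat_wrapper(m1_dd, m1_custom, m2_dd, m2_custom, user_input, history, inserted=""):
        yield from chat(get_model(m1_dd, m1_custom), get_model(m2_dd, m2_custom), user_input, history, inserted)

    def restart_wrapper(m1_dd, m1_custom, m2_dd, m2_custom, user_input):
        yield from restart_conversation(get_model(m1_dd, m1_custom), get_model(m2_dd, m2_custom), user_input)

    model_inputs = [model1_dropdown, model1_custom, model2_dropdown, model2_custom]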

    start_btn = gr.Button("Start Conversation")
    chat_output = start_btn.click(
        chat_wrapper,
        inputs=model_inputs + [user_input, chat_history],
        outputs=[chat_history, current_response]
    )

    chat_output.then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn]
    )

    user_satisfaction_input.submit(
        user_satisfaction,
        inputs=[user_satisfaction_input, chat_history],
        outputs=[chat_history, current_response]
    ).then(
        chat_wrapper,
        inputs=model_inputs + [user_input, chat_history],
        outputs=[chat_history, current_response]
    )

    pause_btn.click(pause_conversation, outputs=[current_response])
    # Clear the pause flag first, then re-enter the chat loop.
    resume_btn.click(resume_conversation, outputs=[current_response]).then(
        chat_wrapper,
        inputs=model_inputs + [user_input, chat_history, inserted_response],
        outputs=[chat_history, current_response]
    )
    edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
    rewind_btn.click(rewind_and_insert, inputs=[rewind_steps, inserted_response, chat_history], outputs=[chat_history, current_response])
    restart_btn.click(
        restart_wrapper,
        inputs=model_inputs + [user_input],
        outputs=[chat_history, current_response]
    )

    def clear_all():
        # manager.clear_conversation returns None; the three output
        # boxes need explicit empty strings.
        manager.clear_conversation()
        return "", "", ""

    clear_btn.click(clear_all, outputs=[chat_history, current_response, user_input])
    delay_slider.change(lambda x: setattr(manager, 'delay', x), inputs=[delay_slider])

if __name__ == "__main__":
    demo.launch()