Spaces: null and void
Build error
Update app.py
committed on
app.py
CHANGED
@@ -27,15 +27,35 @@ class ConversationManager:
             return self.models[model_name]
         except Exception as e:
             print(f"Failed to load model {model_name}: {e}")
+            print(f"Error type: {type(e).__name__}")
+            print(f"Error details: {str(e)}")
             return None

     def generate_response(self, model_name, prompt):
         model, tokenizer = self.load_model(model_name)
-
+
+        # Format the prompt based on the model
+        if "llama" in model_name.lower():
+            formatted_prompt = self.format_llama2_prompt(prompt)
+        else:
+            formatted_prompt = self.format_general_prompt(prompt)
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
         with torch.no_grad():
             outputs = model.generate(**inputs, max_length=200, num_return_sequences=1, do_sample=True)
         return tokenizer.decode(outputs[0], skip_special_tokens=True)

+    def format_llama2_prompt(self, prompt):
+        B_INST, E_INST = "[INST]", "[/INST]"
+        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+        system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."
+
+        formatted_prompt = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}"
+        return formatted_prompt
+
+    def format_general_prompt(self, prompt):
+        # A general format that might work for other models
+        return f"Human: {prompt.strip()}\n\nAssistant:"
     def add_to_conversation(self, model_name, response):
         self.conversation.append((model_name, response))
         if "task complete?" in response.lower():  # Check for task completion marker
@@ -72,48 +92,56 @@ def get_model(dropdown, custom):
     return (model, model)  # Return a tuple (label, value)

 def chat(model1, model2, user_input, history, inserted_response=""):
-
-
-
-    if not manager.conversation:
-        manager.initial_prompt = user_input
-        manager.clear_conversation()
-        manager.add_to_conversation("User", user_input)
-
-    models = [model1, model2]
-    current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
-
-    while not manager.task_complete:  # Continue until task is complete
-        if manager.is_paused:
-            yield history, "Conversation paused."
-            return
-
-        model = models[current_model_index]
-        manager.current_model = model
-
-        if inserted_response and current_model_index == 0:
-            response = inserted_response
-            inserted_response = ""
-        else:
-            prompt = manager.get_conversation_history() + "\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
-            response = manager.generate_response(model, prompt)
+    try:
+        model1 = get_model(model1, model1_custom.value)[0]
+        model2 = get_model(model2, model2_custom.value)[0]

-        manager.
-
+        if not manager.load_model(model1) or not manager.load_model(model2):
+            return "Error: Failed to load one or both models. Please check the model names and try again.", ""

-
-
-
+        if not manager.conversation:
+            manager.initial_prompt = user_input
+            manager.clear_conversation()
+            manager.add_to_conversation("User", user_input)

-
+        models = [model1, model2]
+        current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1

-
-
-
+        while not manager.task_complete:
+            if manager.is_paused:
+                yield history, "Conversation paused."
+                return
+
+            model = models[current_model_index]
+            manager.current_model = model
+
+            if inserted_response and current_model_index == 0:
+                response = inserted_response
+                inserted_response = ""
+            else:
+                conversation_history = manager.get_conversation_history()
+                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
+                response = manager.generate_response(model, prompt)
+
+            manager.add_to_conversation(model, response)
+            history = manager.get_conversation_history()
+
+            for i in range(manager.delay, 0, -1):
+                yield history, f"{model} is writing... {i}"
+                time.sleep(1)
+
+            yield history, ""
+
+            if manager.task_complete:
+                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
+                return
+
+            current_model_index = (current_model_index + 1) % 2

-
-
-
+        return history, "Conversation completed."
+    except Exception as e:
+        print(f"Error in chat function: {str(e)}")
+        return f"An error occurred: {str(e)}", ""

 def user_satisfaction(satisfied, history):
     if satisfied.lower() == 'yes':
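For reference, the prompt string built by the new format_llama2_prompt helper follows the Llama-2 chat layout: the system text sits inside a <<SYS>> ... <</SYS>> block and the whole turn is wrapped in [INST] ... [/INST]. A minimal standalone sketch of that formatting (the example prompt and system text below are illustrative, not values used by the Space):

# Standalone sketch of the Llama-2 prompt layout mirrored from format_llama2_prompt.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def build_llama2_prompt(user_prompt, system_prompt):
    # Wrap the system text in <<SYS>> markers and the whole turn in [INST] ... [/INST].
    return f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{user_prompt.strip()} {E_INST}"

print(build_llama2_prompt("Summarize the conversation so far.",  # illustrative input
                          "You are a helpful AI assistant."))
# Prints:
# [INST] <<SYS>>
# You are a helpful AI assistant.
# <</SYS>>
#
# Summarize the conversation so far. [/INST]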
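The generate_response path in the first hunk follows the usual Hugging Face tokenize -> generate -> decode sequence. A self-contained sketch of that sequence, assuming a small causal LM such as gpt2 as a stand-in for whatever models the Space loads; max_new_tokens is used here so only the generated continuation is capped:

# Minimal sketch of the tokenize -> generate -> decode path, with gpt2 as a placeholder model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumption: any small causal LM works for the illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Human: Suggest a title for a short story about two chatbots.\n\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

with torch.no_grad():  # inference only, no gradients needed
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,      # caps only the generated continuation
        do_sample=True,
        num_return_sequences=1,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))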
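The rewritten chat function is a generator: it repeatedly yields (history, status) pairs so the interface can show a countdown while a model is writing and then clear the status. A toy sketch of that pattern with made-up data (the real function drives the manager and two models):

import time

def fake_chat(delay=3):
    # Yield (history, status) pairs, mimicking the countdown loop in chat().
    history = "User: hello"
    for i in range(delay, 0, -1):
        yield history, f"Model 1 is writing... {i}"
        time.sleep(1)
    history += "\nModel 1: hi there"
    yield history, ""  # empty status clears the "is writing" message

for history, status in fake_chat():
    print(status or history)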