null and void committed on
Commit 0f5e907 · verified · 1 Parent(s): b50cbe0

Update app.py

Files changed (1):
  app.py (+66, -38)
app.py CHANGED
@@ -27,15 +27,35 @@ class ConversationManager:
             return self.models[model_name]
         except Exception as e:
             print(f"Failed to load model {model_name}: {e}")
+            print(f"Error type: {type(e).__name__}")
+            print(f"Error details: {str(e)}")
             return None

     def generate_response(self, model_name, prompt):
         model, tokenizer = self.load_model(model_name)
-        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
+
+        # Format the prompt based on the model
+        if "llama" in model_name.lower():
+            formatted_prompt = self.format_llama2_prompt(prompt)
+        else:
+            formatted_prompt = self.format_general_prompt(prompt)
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
         with torch.no_grad():
             outputs = model.generate(**inputs, max_length=200, num_return_sequences=1, do_sample=True)
         return tokenizer.decode(outputs[0], skip_special_tokens=True)

+    def format_llama2_prompt(self, prompt):
+        B_INST, E_INST = "[INST]", "[/INST]"
+        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+        system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."
+
+        formatted_prompt = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}"
+        return formatted_prompt
+
+    def format_general_prompt(self, prompt):
+        # A general format that might work for other models
+        return f"Human: {prompt.strip()}\n\nAssistant:"
     def add_to_conversation(self, model_name, response):
         self.conversation.append((model_name, response))
         if "task complete?" in response.lower():  # Check for task completion marker
@@ -72,48 +92,56 @@ def get_model(dropdown, custom):
     return (model, model)  # Return a tuple (label, value)

 def chat(model1, model2, user_input, history, inserted_response=""):
-    model1 = get_model(model1, model1_custom.value)[0]
-    model2 = get_model(model2, model2_custom.value)[0]
-
-    if not manager.conversation:
-        manager.initial_prompt = user_input
-        manager.clear_conversation()
-        manager.add_to_conversation("User", user_input)
-
-    models = [model1, model2]
-    current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
-
-    while not manager.task_complete:  # Continue until task is complete
-        if manager.is_paused:
-            yield history, "Conversation paused."
-            return
-
-        model = models[current_model_index]
-        manager.current_model = model
-
-        if inserted_response and current_model_index == 0:
-            response = inserted_response
-            inserted_response = ""
-        else:
-            prompt = manager.get_conversation_history() + "\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
-            response = manager.generate_response(model, prompt)
-
-        manager.add_to_conversation(model, response)
-        history = manager.get_conversation_history()
-
-        for i in range(manager.delay, 0, -1):
-            yield history, f"{model} is writing... {i}"
-            time.sleep(1)
-
-        yield history, ""
-
-        if manager.task_complete:
-            yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
-            return
-
-        current_model_index = (current_model_index + 1) % 2
-
-    return history, "Conversation completed."
+    try:
+        model1 = get_model(model1, model1_custom.value)[0]
+        model2 = get_model(model2, model2_custom.value)[0]
+
+        if not manager.load_model(model1) or not manager.load_model(model2):
+            return "Error: Failed to load one or both models. Please check the model names and try again.", ""
+
+        if not manager.conversation:
+            manager.initial_prompt = user_input
+            manager.clear_conversation()
+            manager.add_to_conversation("User", user_input)
+
+        models = [model1, model2]
+        current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
+
+        while not manager.task_complete:
+            if manager.is_paused:
+                yield history, "Conversation paused."
+                return
+
+            model = models[current_model_index]
+            manager.current_model = model
+
+            if inserted_response and current_model_index == 0:
+                response = inserted_response
+                inserted_response = ""
+            else:
+                conversation_history = manager.get_conversation_history()
+                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
+                response = manager.generate_response(model, prompt)
+
+            manager.add_to_conversation(model, response)
+            history = manager.get_conversation_history()
+
+            for i in range(manager.delay, 0, -1):
+                yield history, f"{model} is writing... {i}"
+                time.sleep(1)
+
+            yield history, ""
+
+            if manager.task_complete:
+                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
+                return
+
+            current_model_index = (current_model_index + 1) % 2
+
+        return history, "Conversation completed."
+    except Exception as e:
+        print(f"Error in chat function: {str(e)}")
+        return f"An error occurred: {str(e)}", ""

 def user_satisfaction(satisfied, history):
     if satisfied.lower() == 'yes':
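Note on the new prompt template: the string produced by format_llama2_prompt can be sanity-checked in isolation. The snippet below is a minimal sketch that copies the template strings from the commit and applies them to a made-up prompt; the sample prompt and the printed output are illustrative only, not part of the commit.

    # Minimal sketch of the Llama-2 chat template added in this commit.
    # Only the template strings come from format_llama2_prompt; the prompt text is made up.
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."

    prompt = "Summarize the plan so far."
    print(f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}")
    # [INST] <<SYS>>
    # You are a helpful AI assistant. Please provide a concise and relevant response.
    # <</SYS>>
    #
    # Summarize the plan so far. [/INST]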
 
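Usage note: chat() is a generator (it yields (history, status) pairs for the countdown and status messages), so callers are expected to iterate over it rather than call it once. A rough consumer loop under that assumption is sketched below; the model names and user input are placeholders, and manager, get_model, model1_custom and model2_custom are assumed to be defined elsewhere in app.py. Keep in mind that values handed to return inside a generator (the "Conversation completed." tuple and the error-path returns above) are not delivered to a plain for loop; in Python they surface only as StopIteration.value, so the yielded pairs carry the visible updates.

    # Rough sketch of driving the chat() generator; the arguments below are placeholders.
    history = ""
    for history, status in chat("model-a", "model-b", "Draft a short plan.", history):
        # Each yield delivers the running transcript plus a status line
        # such as "model-a is writing... 3" or the task-complete question.
        print(status or "(turn finished)")

    print(history)  # final transcript accumulated by the loop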