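"""ConversAI Playground.

A Gradio app that loads two Hugging Face causal language models and lets them
converse with each other from a single initial prompt, with controls to pause,
resume, rewind, edit, restart, and clear the conversation.
"""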
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import bitsandbytes as bnb
print(f"bitsandbytes version: {bnb.__version__}")
print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
print(f"Current CUDA device: {torch.cuda.current_device()}")
print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
class ConversationManager:
    def __init__(self):
        self.models = {}
        self.conversation = []
        self.delay = 3
        self.is_paused = False
        self.current_model = None
        self.initial_prompt = ""
        self.task_complete = False
    def load_model(self, model_name):
        if not model_name:
            print("Error: Empty model name provided")
            return None
        if model_name in self.models:
            return self.models[model_name]
        try:
            print(f"Attempting to load model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            try:
                # Try to load the model with 8-bit quantization
                model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
            except RuntimeError as e:
                print(f"8-bit quantization not available, falling back to full precision: {e}")
                if torch.cuda.is_available():
                    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name)
            self.models[model_name] = (model, tokenizer)
            print(f"Successfully loaded model: {model_name}")
            return self.models[model_name]
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {str(e)}")
            return None
    def generate_response(self, model_name, prompt):
        loaded = self.load_model(model_name)
        if loaded is None:
            return f"[Error: model {model_name} could not be loaded]"
        model, tokenizer = loaded
        formatted_prompt = f"Human: {prompt.strip()}\n\nAssistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
        # Move the inputs onto the same device as the model before generating
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            # Use max_new_tokens so a long prompt does not exhaust the generation budget
            outputs = model.generate(**inputs, max_new_tokens=200, num_return_sequences=1, do_sample=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    def add_to_conversation(self, model_name, response):
        self.conversation.append((model_name, response))
        if "task complete?" in response.lower():
            self.task_complete = True

    def get_conversation_history(self):
        return "\n".join([f"{model}: {msg}" for model, msg in self.conversation])
    def clear_conversation(self):
        self.conversation = []
        self.initial_prompt = ""
        self.models = {}
        self.current_model = None
        self.task_complete = False

    def rewind_conversation(self, steps):
        if steps > 0:
            # conversation[:-0] would wipe the whole history, so only slice for positive steps
            self.conversation = self.conversation[:-steps]
        self.task_complete = False
    def rewind_and_insert(self, steps, inserted_response):
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        if inserted_response.strip():
            last_model = self.conversation[-1][0] if self.conversation else "User"
            next_model = "Model 1" if last_model in ("Model 2", "User") else "Model 2"
            self.conversation.append((next_model, inserted_response))
            self.current_model = last_model
        self.task_complete = False
manager = ConversationManager()
def get_model(dropdown, custom):
    # Prefer a non-empty custom model name over the dropdown selection
    return custom if custom and custom.strip() else dropdown
def chat(model1, model2, user_input, history, inserted_response=""):
    try:
        print(f"Starting chat with models: {model1}, {model2}")
        print(f"User input: {user_input}")
        # Note: Textbox.value is the component's initial value, not the user's live input
        model1 = get_model(model1, model1_custom.value)
        model2 = get_model(model2, model2_custom.value)
        print(f"Selected models: {model1}, {model2}")
        if not manager.load_model(model1) or not manager.load_model(model2):
            # chat is a generator, so errors are yielded rather than returned
            yield "Error: Failed to load one or both models. Please check the model names and try again.", ""
            return
        if not manager.conversation:
            # Clear first, then record the prompt, so clear_conversation() does not wipe it
            manager.clear_conversation()
            manager.initial_prompt = user_input
            manager.add_to_conversation("User", user_input)
        models = [model1, model2]
        current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
        while not manager.task_complete:
            if manager.is_paused:
                yield history, "Conversation paused."
                return
            model = models[current_model_index]
            manager.current_model = model
            if inserted_response and current_model_index == 0:
                response = inserted_response
                inserted_response = ""
            else:
                conversation_history = manager.get_conversation_history()
                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
                response = manager.generate_response(model, prompt)
            manager.add_to_conversation(model, response)
            history = manager.get_conversation_history()
            for i in range(int(manager.delay), 0, -1):
                yield history, f"{model} is writing... {i}"
                time.sleep(1)
            yield history, ""
            if manager.task_complete:
                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
                return
            current_model_index = (current_model_index + 1) % 2
        yield history, "Conversation completed."
    except Exception as e:
        print(f"Error in chat function: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        yield f"An error occurred: {str(e)}", ""
def user_satisfaction(satisfied, history):
    if satisfied.lower() == 'yes':
        return history, "Task completed successfully."
    else:
        manager.task_complete = False
        return history, "Continuing the conversation..."

def pause_conversation():
    manager.is_paused = True
    return "Conversation paused. Press Resume to continue."

def resume_conversation():
    manager.is_paused = False
    return "Conversation resumed."

def edit_response(edited_text):
    if manager.conversation:
        manager.conversation[-1] = (manager.current_model, edited_text)
        manager.task_complete = False
    return manager.get_conversation_history()
def restart_conversation(model1, model2, user_input):
    manager.clear_conversation()
    # Delegate to the chat generator so Gradio keeps streaming its updates
    yield from chat(model1, model2, user_input, "")
def rewind_and_insert(steps, inserted_response, history):
    manager.rewind_and_insert(int(steps), inserted_response)
    return manager.get_conversation_history(), ""
open_source_models = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigcode/starcoder2-15b",
    "bigcode/starcoder2-3b",
    "tiiuae/falcon-7b",
    "EleutherAI/gpt-neox-20b",
    "google/flan-ul2",
    "stabilityai/stablelm-zephyr-3b",
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/phi-2",
    "google/gemma-7b-it",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "mosaicml/mpt-7b-chat",
    "databricks/dolly-v2-12b",
    "TheBloke/Wizard-Vicuna-13B-Uncensored-HF",
    "bigscience/bloom-560m"
]
with gr.Blocks() as demo:
    gr.Markdown("# ConversAI Playground")
    with gr.Row():
        with gr.Column(scale=1):
            model1_dropdown = gr.Dropdown(choices=open_source_models, label="Model 1")
            model1_custom = gr.Textbox(label="Custom Model 1")
        with gr.Column(scale=1):
            model2_dropdown = gr.Dropdown(choices=open_source_models, label="Model 2")
            model2_custom = gr.Textbox(label="Custom Model 2")
    user_input = gr.Textbox(label="Initial prompt", lines=2)
    chat_history = gr.Textbox(label="Conversation", lines=20)
    current_response = gr.Textbox(label="Current model response", lines=3)
    with gr.Row():
        pause_btn = gr.Button("Pause")
        edit_btn = gr.Button("Edit")
        rewind_btn = gr.Button("Rewind")
        resume_btn = gr.Button("Resume")
        restart_btn = gr.Button("Restart")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        rewind_steps = gr.Slider(0, 10, 1, label="Steps to rewind")
        inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)
    delay_slider = gr.Slider(0, 10, 3, label="Response Delay (seconds)")
    user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)
    gr.Markdown("""
## Button Descriptions
- **Pause**: Temporarily stops the conversation. The current model will finish its response.
- **Edit**: Allows you to modify the last response in the conversation.
- **Rewind**: Removes the specified number of last responses from the conversation.
- **Resume**: Continues the conversation from where it was paused.
- **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.
- **Clear**: Resets everything, including loaded models, conversation history, and initial prompt.
""")
    def on_chat_update(history, response):
        if response and "Models believe the task is complete" in response:
            return gr.update(visible=True), gr.update(visible=False)
        return gr.update(visible=False), gr.update(visible=True)
    start_btn = gr.Button("Start Conversation")

    chat_output = start_btn.click(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history
        ],
        outputs=[chat_history, current_response]
    )

    chat_output.then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn]
    )

    user_satisfaction_input.submit(
        user_satisfaction,
        inputs=[user_satisfaction_input, chat_history],
        outputs=[chat_history, current_response]
    ).then(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history
        ],
        outputs=[chat_history, current_response]
    )

    pause_btn.click(pause_conversation, outputs=[current_response])
    # Clear the pause flag first, then re-enter the chat loop so the conversation actually resumes
    resume_btn.click(resume_conversation, outputs=[current_response]).then(
        chat,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input,
            chat_history,
            inserted_response
        ],
        outputs=[chat_history, current_response]
    )
    edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
    rewind_btn.click(rewind_and_insert, inputs=[rewind_steps, inserted_response, chat_history], outputs=[chat_history, current_response])

    restart_btn.click(
        restart_conversation,
        inputs=[
            model1_dropdown,
            model2_dropdown,
            user_input
        ],
        outputs=[chat_history, current_response]
    )
    def clear_all():
        # manager.clear_conversation() returns None, so wrap it to also reset the three output boxes
        manager.clear_conversation()
        return "", "", ""

    clear_btn.click(clear_all, outputs=[chat_history, current_response, user_input])
    delay_slider.change(lambda x: setattr(manager, 'delay', int(x)), inputs=[delay_slider])
if __name__ == "__main__":
    demo.launch()
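
# Usage note (assumed typical setup): run `python app.py` and open the local URL that Gradio prints.
# 8-bit loading via bitsandbytes needs a CUDA GPU; without one, load_model falls back to full precision.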