# Author: sheonhan
# Commit: fc40822 ("Update Gradio")
import os
import subprocess
import sys

# Replace the Gradio that ships with the environment by the custom build
# vendored next to this file. This has to happen before `import gradio` below.
# pip is invoked through the running interpreter (`sys.executable -m pip`) so
# the packages land in the environment that actually runs this app; a bare
# `pip` on PATH may belong to a different Python. Arguments are passed as a
# list, so no shell is involved.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "gradio", "-y"])
print(f"pwd is: {os.getcwd()}", flush=True)
# check=True: fail here with a clear error rather than later at `import gradio`.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "./gradio-12.34.57.tar.gz"],
    check=True,
)
import json
import os
import shutil
import threading
import gradio as gr
from dialogues import DialogueTemplate
from huggingface_hub import Repository
from text_generation import Client
from utils import get_full_text, wrap_html_code
STYLE = """
// "done" class is injected when user has made
// decision between two candidate generated answers
.message.bot.done {
animation: colorTransition 2s ease-in-out;
}
// fade out animation effect when user selects a choice
@keyframes colorTransition {
0% {
background-color: var(--checkbox-background-color-selected);
}
100% {
background-color: var(--background-fill-secondary);
}
}
"""
HF_TOKEN = os.environ.get("HF_TOKEN", None)
REPO_ID = "sheonhan/rm-test-data"
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"
LABELER_ID = "labeler_123"
SESSION_ID = "session_123"
client = Client(
API_URL,
headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
# Local clone of the dataset repo that labeling results are appended to.
# Stays None when no HF_TOKEN is available.
repo = None
if HF_TOKEN:
    # Start from a fresh clone. Removal is best-effort (e.g. the directory may
    # not exist yet); ignore_errors avoids a bare `except:`, which would also
    # swallow KeyboardInterrupt and SystemExit.
    shutil.rmtree("./data/", ignore_errors=True)
    print("Pulling repo...")
    repo = Repository(
        local_dir="./data/",
        clone_from=REPO_ID,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

# System prompt passed to every DialogueTemplate (currently empty).
system_message = ""
def generate(user_message, history):
    """Produce two candidate replies to ``user_message`` for pairwise labeling.

    The conversation so far is rebuilt into one prompt, which is sent to the
    inference endpoint twice: once at temperature 0.1 and once at 0.9.

    Args:
        user_message: Text the labeler just submitted.
        history: Chatbot history. Every earlier entry is unpacked as
            ``(user_text, bot_text)``, so previous turns are presumably already
            reduced to the chosen candidate by ``on_select``. ``None`` is
            treated as an empty history.

    Returns:
        ``("", history)``: an empty string to clear the textbox, and the
        history with a ``(user_message, "A: ...", "B: ...")`` entry appended.
    """
    if history is None:
        history = []

    past_messages = []
    for user_data, model_data in history:
        past_messages.append({"role": "user", "content": user_data})
        past_messages.append({"role": "assistant", "content": model_data.rstrip()})

    # An empty history simply yields a single-message dialogue, so the first
    # turn needs no special case.
    dialogue_template = DialogueTemplate(
        system=system_message,
        messages=past_messages + [{"role": "user", "content": user_message}],
        end_token="<|endoftext|>",
    )
    prompt = dialogue_template.get_inference_prompt()

    # NOTE(review): the template ends turns with "<|endoftext|>" while
    # generation stops on "<|end|>"; confirm which delimiter the model emits.
    response_1 = client.generate_stream(
        prompt, temperature=0.1, stop_sequences=["<|end|>"]
    )
    response_2 = client.generate_stream(
        prompt, temperature=0.9, stop_sequences=["<|end|>"]
    )

    # Labels "A: " / "B: " are 3 characters long; save_labeling_data strips
    # exactly that many characters back off.
    option_a = f"A: {wrap_html_code(get_full_text(response_1).strip())}"
    option_b = f"B: {wrap_html_code(get_full_text(response_2).strip())}"

    history.append((user_message, option_a, option_b))
    return "", history
# Serializes the pull -> append -> push cycle. on_select starts one thread per
# choice, and concurrent git operations on the same clone would collide.
_SAVE_LOCK = threading.Lock()


def save_labeling_data(last_dialogue, score):
    """Append one labeled comparison to ``data/data.jsonl`` and push it.

    Does nothing when there is no dataset clone (``repo is None``, i.e.
    HF_TOKEN was not set), since there is nowhere to push the record.

    Args:
        last_dialogue: ``(prompt, response_1, response_2)`` where each
            response still carries its 3-character display label
            (``"A: "`` / ``"B: "``).
        score: Value reported by the chatbot select event; stored as-is.
            NOTE(review): presumably the selected candidate's text — confirm.
    """
    if repo is None:
        return

    prompt, response_1, response_2 = last_dialogue
    record = {
        "labeler_id": LABELER_ID,
        "session_id": SESSION_ID,
        "prompt": prompt,
        "response_1": response_1[3:],  # Remove label "A: "
        "response_2": response_2[3:],  # Remove label "B: "
        "score": score,
    }
    file_name = "data.jsonl"

    with _SAVE_LOCK:
        # Rebase onto the remote first so the push below does not conflict.
        repo.git_pull(rebase=True)
        with open(os.path.join("data", file_name), "a", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write("\n")
        repo.push_to_hub()
def on_select(event: gr.SelectData, history):
    """Record the labeler's choice and drop one candidate from the last turn.

    Starts a background thread that saves the last history entry together
    with ``event.value``, then deletes the element at ``event.index`` from
    that entry so only two elements remain.

    Args:
        event: Gradio select event. NOTE(review): the meaning of
            ``event.index`` depends on the custom Gradio build — confirm it is
            the position of the candidate to discard.
        history: Chatbot history whose last entry is
            ``[user_message, option_a, option_b]``.

    Returns:
        The mutated history.
    """
    score = event.value
    index_to_delete = event.index
    # Hand the worker a snapshot: the `del` below mutates history[-1] in place,
    # and the thread may not have unpacked it yet.
    last_dialogue = list(history[-1])
    threading.Thread(target=save_labeling_data, args=(last_dialogue, score)).start()
    del history[-1][index_to_delete]
    return history
# Client-side hooks. querySelector returns null when nothing matches (e.g. no
# message is marked "done" before the first choice), so both scripts guard
# against it instead of throwing a TypeError.
_CLEAR_DONE_JS = """()=>{
    const last_elem = document.querySelector("div.message.bot.done");
    if (last_elem) {
        last_elem.classList.remove("done");
    }
}
"""

_MARK_DONE_JS = """()=>{
    const last_elem = document.querySelector("div.message.bot.latest");
    if (last_elem) {
        last_elem.classList.remove("latest");
        last_elem.classList.add("done");
    }
}
"""

with gr.Blocks(css=STYLE) as demo:
    chatbot = gr.Chatbot()
    user_message = gr.Textbox()
    clear = gr.Button("Clear")

    # Submit: generate two candidates, then remove the "done" highlight from
    # the previously labeled message.
    user_message.submit(
        generate,
        [user_message, chatbot],
        [user_message, chatbot],
        queue=False,
    ).then(
        None,
        None,
        None,
        _js=_CLEAR_DONE_JS,
    )

    # Selecting a candidate saves the label, then swaps the "latest" class for
    # "done" so the fade-out animation from STYLE plays.
    chatbot.select(on_select, chatbot, chatbot).then(
        None,
        None,
        None,
        _js=_MARK_DONE_JS,
    )

    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()