# Page header captured with this file: "Spaces: Build error"
import subprocess

# Replace the installed gradio with the wheel/sdist shipped next to this file.
# This runs before `import gradio` further down, so the import picks up the
# freshly installed build.
# shell=True is intentional: the command string relies on `;` chaining and
# `$(pwd)` substitution, and it is a fixed literal (no external input).
commands = 'pip uninstall gradio -y; echo "pwd is: $(pwd)"; pip install ./gradio-12.34.57.tar.gz'
subprocess.run(commands, shell=True)
import json
import os
import shutil
import threading

import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

from dialogues import DialogueTemplate
from utils import get_full_text, wrap_html_code
# Custom CSS handed to gr.Blocks(css=...).
# CSS only supports /* ... */ comments; `//` is not a comment and makes the
# parser fold the text into the following rule, invalidating it.
STYLE = """
/* "done" class is injected when user has made
   decision between two candidate generated answers */
.message.bot.done {
    animation: colorTransition 2s ease-in-out;
}
/* fade out animation effect when user selects a choice */
@keyframes colorTransition {
    0% {
        background-color: var(--checkbox-background-color-selected);
    }
    100% {
        background-color: var(--background-fill-secondary);
    }
}
"""
# --- Configuration -----------------------------------------------------------
# Hub token; when unset, no dataset repo is cloned and `repo` stays None.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Dataset repository that receives the labeled comparisons.
REPO_ID = "sheonhan/rm-test-data"
# Inference endpoint used to generate both candidate answers.
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"
# Fixed identifiers written into every saved record.
LABELER_ID = "labeler_123"
SESSION_ID = "session_123"

# Only send an Authorization header when a token exists; otherwise the header
# would literally read "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

repo = None
if HF_TOKEN:
    # Best-effort removal of a stale checkout; a missing directory is fine.
    shutil.rmtree("./data/", ignore_errors=True)

    print("Pulling repo...")
    repo = Repository(
        local_dir="./data/",
        clone_from=REPO_ID,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

# System prompt passed to every DialogueTemplate (currently empty).
system_message = ""
def generate(user_message, history):
    """Produce two candidate answers for ``user_message`` and append them to the chat.

    The prompt is built from every prior turn in ``history`` plus the new user
    message, then sent to the inference client twice: once at temperature 0.1
    and once at temperature 0.9.

    Args:
        user_message: Text the user just submitted.
        history: Chat history. Each existing entry is unpacked as a
            ``(user_text, model_text)`` pair. The list is mutated in place.

    Returns:
        A tuple ``("", history)``: an empty string followed by the history with
        a new ``(user_message, "A: ...", "B: ...")`` entry appended.
    """
    # NOTE(review): the two-value unpacking below raises ValueError for an
    # entry that still holds both candidates (i.e. no choice was made yet) --
    # confirm the UI prevents submitting in that state.
    past_messages = []
    for user_data, model_data in history:
        past_messages.extend(
            [
                {"role": "user", "content": user_data},
                {"role": "assistant", "content": model_data.rstrip()},
            ]
        )

    # With an empty history this is just the new user message, so a single
    # code path covers both the first turn and follow-up turns.
    messages = past_messages + [{"role": "user", "content": user_message}]
    dialogue_template = DialogueTemplate(
        system=system_message,
        messages=messages,
        end_token="<|endoftext|>",
    )
    prompt = dialogue_template.get_inference_prompt()

    # Same prompt, two sampling temperatures, to get two different candidates.
    response_1 = client.generate_stream(
        prompt, temperature=0.1, stop_sequences=["<|end|>"]
    )
    response_2 = client.generate_stream(
        prompt, temperature=0.9, stop_sequences=["<|end|>"]
    )

    response_1_text = get_full_text(response_1)
    response_2_text = get_full_text(response_2)

    # The 3-character "A: " / "B: " labels are stripped again in
    # save_labeling_data, so keep their length in sync.
    option_a = f"A: {wrap_html_code(response_1_text.strip())}"
    option_b = f"B: {wrap_html_code(response_2_text.strip())}"

    history.append((user_message, option_a, option_b))
    return "", history
def save_labeling_data(last_dialogue, score):
    """Append one labeled comparison to ``data/data.jsonl`` and push it to the hub.

    Does nothing beyond unpacking its input when the module-level ``repo`` is
    ``None`` (no dataset repository was cloned).

    Args:
        last_dialogue: Three-item sequence ``(prompt, response_1, response_2)``
            where both responses start with a 3-character label ("A: " / "B: ").
        score: Value stored unchanged under the ``"score"`` key.

    Raises:
        ValueError: If ``last_dialogue`` does not contain exactly three items.
    """
    prompt, response_1, response_2 = last_dialogue
    response_1 = response_1[3:]  # Remove label "A: "
    response_2 = response_2[3:]  # Remove label "B: "

    # Without a cloned repository there is nowhere to pull from or push to.
    if repo is None:
        return

    file_name = "data.jsonl"
    record = {
        "labeler_id": LABELER_ID,
        "session_id": SESSION_ID,
        "prompt": prompt,
        "response_1": response_1,
        "response_2": response_2,
        "score": score,
    }

    # Rebase onto the remote first so the append lands on top of the latest data.
    repo.git_pull(rebase=True)
    with open(os.path.join("data", file_name), "a", encoding="utf-8") as f:
        # One JSON object per line (JSON Lines).
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    repo.push_to_hub()
def on_select(event: gr.SelectData, history):
    """Record the user's choice and drop one candidate from the last chat entry.

    Args:
        event: Selection event; ``event.value`` is saved as the score and
            ``event.index`` is the position removed from the last entry.
        history: Chat history whose last entry holds the prompt and both
            candidates. Mutated in place.

    Returns:
        The updated history.
    """
    score = event.value
    index_to_delete = event.index

    # Snapshot the entry before mutating it: the worker thread unpacks three
    # items, and handing it the live list would race with the `del` below.
    last_dialogue = tuple(history[-1])
    threading.Thread(target=save_labeling_data, args=(last_dialogue, score)).start()

    del history[-1][index_to_delete]
    return history
# --- UI wiring ---------------------------------------------------------------
with gr.Blocks(css=STYLE) as demo:
    chatbot = gr.Chatbot()
    user_message = gr.Textbox()
    clear = gr.Button("Clear")

    # Submitting text generates two candidates; afterwards the "done" class is
    # removed from the matching bot message. The element may not exist (e.g.
    # on the first submit), hence the null guard.
    user_message.submit(
        generate,
        [user_message, chatbot],
        [user_message, chatbot],
        queue=False,
    ).then(
        None,
        None,
        None,
        _js="""()=>{
            let last_elem = document.querySelector("div.message.bot.done");
            if (last_elem) {
                last_elem.classList.remove("done");
            }
        }
        """,
    )

    # Selecting a candidate saves the label, then swaps the "latest" class for
    # "done" on the matching bot message, which triggers the fade animation
    # defined in STYLE.
    # NOTE(review): the "latest" class is not set anywhere in this file --
    # presumably added by the custom gradio build; confirm.
    chatbot.select(on_select, chatbot, chatbot).then(
        None,
        None,
        None,
        _js="""()=>{
            let last_elem = document.querySelector("div.message.bot.latest");
            if (last_elem) {
                last_elem.classList.remove("latest");
                last_elem.classList.add("done");
            }
        }
        """,
    )

    # Clearing resets the chatbot value to None.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()