# NOTE(review): runtime pip re-install is a Hugging Face Spaces deployment
# hack (swaps in a patched local gradio build). The shell=True command is a
# fixed literal with no user input, so there is no injection surface — but
# it must stay a constant string.
import subprocess

commands = 'pip uninstall gradio -y; echo "pwd is: $(pwd)"; pip install ./gradio-12.34.57.tar.gz'
subprocess.run(commands, shell=True)

import json
import os
import shutil
import threading

import gradio as gr
from dialogues import DialogueTemplate
from huggingface_hub import Repository
from text_generation import Client

from utils import get_full_text, wrap_html_code

# CSS injected into the Blocks app.
# FIX: CSS has no "//" line comments — with the original "//" lines the
# parser read "// ... .message.bot.done" as one invalid selector and dropped
# the whole rule, so the fade-out animation never ran. Use /* ... */ instead.
STYLE = """
/* "done" class is injected when user has made
   decision between two candidate generated answers */
.message.bot.done {
    animation: colorTransition 2s ease-in-out;
}

/* fade out animation effect when user selects a choice */
@keyframes colorTransition {
    0% {
        background-color: var(--checkbox-background-color-selected);
    }
    100% {
        background-color: var(--background-fill-secondary);
    }
}
"""

HF_TOKEN = os.environ.get("HF_TOKEN", None)
REPO_ID = "sheonhan/rm-test-data"
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"

# Identifiers stamped onto every saved comparison record.
LABELER_ID = "labeler_123"
SESSION_ID = "session_123"

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

# Clone the label-collection dataset repo when a token is available;
# otherwise `repo` stays None and saving becomes a no-op.
repo = None
if HF_TOKEN:
    # Start from a clean checkout. Narrowed from a bare `except:` (which
    # also swallowed KeyboardInterrupt/SystemExit) to the OSError that
    # rmtree actually raises when ./data/ is missing or locked.
    try:
        shutil.rmtree("./data/")
    except OSError:
        pass

    print("Pulling repo...")
    repo = Repository(
        local_dir="./data/",
        clone_from=REPO_ID,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

system_message = ""


def generate(user_message, history):
    """Generate two candidate answers for the latest user message.

    Builds an inference prompt from the prior (user, assistant) turns plus
    the new message, samples the model twice at different temperatures, and
    appends a (user_message, "A: ...", "B: ...") triple to `history` for the
    user to choose between.

    Returns ("", history): the empty string clears the input textbox.
    """
    past_messages = []
    for user_data, model_data in history:
        past_messages.extend(
            [
                {"role": "user", "content": user_data},
                {"role": "assistant", "content": model_data.rstrip()},
            ]
        )

    # `past_messages` is [] on the first turn, so this single construction
    # covers both the fresh-conversation and continuation cases that the
    # original spelled out as separate, otherwise-identical branches.
    dialogue_template = DialogueTemplate(
        system=system_message,
        messages=past_messages + [{"role": "user", "content": user_message}],
        end_token="<|endoftext|>",
    )
    prompt = dialogue_template.get_inference_prompt()

    # Two samples from the same prompt: a conservative one (t=0.1) and a
    # more exploratory one (t=0.9), so the pair is worth comparing.
    response_1 = client.generate_stream(
        prompt, temperature=0.1, stop_sequences=["<|end|>"]
    )
    response_2 = client.generate_stream(
        prompt, temperature=0.9, stop_sequences=["<|end|>"]
    )
    response_1_text = get_full_text(response_1)
    response_2_text = get_full_text(response_2)

    option_a = wrap_html_code(response_1_text.strip())
    option_b = wrap_html_code(response_2_text.strip())
    option_a = f"A: {option_a}"
    option_b = f"B: {option_b}"

    # Deliberately a 3-item row (prompt, A, B): on_select() later deletes
    # the rejected option, collapsing the row back into the (user, bot)
    # pair shape that gr.Chatbot renders.
    history.append((user_message, option_a, option_b))

    return "", history


def save_labeling_data(last_dialogue, score):
    """Append one preference record to data.jsonl and push it to the Hub.

    `last_dialogue` is the (prompt, "A: ...", "B: ...") triple appended by
    generate(); `score` is the text of the option the labeler clicked.
    Runs on a background thread (see on_select) so the git network
    round-trips do not block the UI. No-op when the dataset repo was not
    cloned (missing HF_TOKEN).
    """
    prompt, response_1, response_2 = last_dialogue
    response_1 = response_1[3:]  # strip the "A: " label prefix
    response_2 = response_2[3:]  # strip the "B: " label prefix

    file_name = "data.jsonl"
    if repo is not None:
        # Rebase onto remote first so concurrent labelers don't conflict.
        repo.git_pull(rebase=True)

        with open(os.path.join("data", file_name), "a", encoding="utf-8") as f:
            data = {
                "labeler_id": LABELER_ID,
                "session_id": SESSION_ID,
                "prompt": prompt,
                "response_1": response_1,
                "response_2": response_2,
                "score": score,
            }
            json.dump(data, f, ensure_ascii=False)
            f.write("\n")

        repo.push_to_hub()


def on_select(event: gr.SelectData, history):
    """Record the labeler's choice and collapse the last chatbot row.

    Saves asynchronously so the click feels instant, then removes the
    clicked entry from the final row.
    """
    score = event.value
    index_to_delete = event.index

    threading.Thread(target=save_labeling_data, args=(history[-1], score)).start()

    # Gradio round-trips chatbot rows as lists, so the last row is mutable
    # here. NOTE(review): assumes event.index addresses an element of that
    # 3-item row — confirm against the patched gradio build installed above.
    del history[-1][index_to_delete]

    return history


with gr.Blocks(css=STYLE) as demo:
    chatbot = gr.Chatbot()
    user_message = gr.Textbox()
    clear = gr.Button("Clear")

    # After a generation, clear the "done" class so the fade-out animation
    # can replay on the next labeled answer.
    user_message.submit(
        generate,
        [user_message, chatbot],
        [user_message, chatbot],
        queue=False,
    ).then(
        None,
        None,
        None,
        _js="""()=>{
    let last_elem = document.querySelector("div.message.bot.done");
    last_elem.classList.remove("done");
}
""",
    )

    # After a choice is made, swap "latest" for "done" to trigger the
    # colorTransition animation defined in STYLE.
    chatbot.select(on_select, chatbot, chatbot).then(
        None,
        None,
        None,
        _js="""()=>{
    let last_elem = document.querySelector("div.message.bot.latest");
    last_elem.classList.remove("latest");
    last_elem.classList.add("done");
}
""",
    )

    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()