import subprocess

import os
import sys

# Swap the installed gradio for the pinned local build before `gradio` is
# imported further down. pip is invoked as `sys.executable -m pip` so it
# targets the same environment this interpreter imports from (a bare `pip`
# on PATH may belong to a different Python), and the commands are passed as
# argument lists so no shell is involved.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "gradio", "-y"])
print(f"pwd is: {os.getcwd()}", flush=True)
# check=True: if the local build fails to install, stop here with a clear
# error instead of failing later at `import gradio`.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "./gradio-12.34.57.tar.gz"],
    check=True,
)

import json
import os

import shutil
import threading

import gradio as gr

from dialogues import DialogueTemplate
from huggingface_hub import Repository
from text_generation import Client

from utils import get_full_text, wrap_html_code

# Custom CSS injected into the Blocks app. CSS only understands /* ... */
# comments; `//` "comments" are parsed as part of the following rule's
# selector/prelude, which makes that whole rule invalid and silently dropped.
STYLE = """
/* "done" class is injected when user has made
   decision between two candidate generated answers */
.message.bot.done {
  animation: colorTransition 2s ease-in-out;
}

/* fade out animation effect when user selects a choice */
@keyframes colorTransition {
  0% {
    background-color: var(--checkbox-background-color-selected);
  }
  100% {
    background-color: var(--background-fill-secondary);
  }
}
"""

# Hugging Face access token (None when unset). It authenticates the inference
# client below and gates the dataset sync further down.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Dataset repository the preference labels are pushed to.
REPO_ID = "sheonhan/rm-test-data"
# Inference endpoint queried for the candidate answers.
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"


# Hard-coded identifiers written into every saved record
# (presumably placeholders until real labeler/session tracking exists).
LABELER_ID = "labeler_123"
SESSION_ID = "session_123"

# Only send an Authorization header when a token is configured; otherwise the
# request would carry the literal, invalid header "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

# Local clone of the labels dataset. It stays None when no token is
# configured, in which case save_labeling_data skips persisting labels.
repo = None
if HF_TOKEN:
    # Start from a fresh clone. The directory may not exist yet, so removal
    # is best-effort; ignore_errors replaces a bare `except: pass`, which
    # would also have swallowed KeyboardInterrupt/SystemExit.
    shutil.rmtree("./data/", ignore_errors=True)
    print("Pulling repo...")
    repo = Repository(
        local_dir="./data/",
        clone_from=REPO_ID,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

# System prompt handed to every DialogueTemplate (currently empty).
system_message = ""


def generate(user_message, history):
    """Produce two candidate answers for ``user_message`` and append them to the chat.

    The conversation so far is replayed into a ``DialogueTemplate`` prompt.
    The model is sampled twice, at temperature 0.1 and at 0.9. Both
    completions are appended to ``history`` as a single
    ``(user_message, "A: ...", "B: ...")`` row so the labeler can pick one.

    Args:
        user_message: Text the user just submitted.
        history: Chatbot value. Each earlier row must unpack into a
            ``(user, assistant)`` pair. NOTE(review): a row that still holds
            both candidates (no choice made yet) fails to unpack here.
            Confirm the UI prevents submitting before choosing.

    Returns:
        ``("", history)``: clears the textbox and refreshes the chatbot.
    """
    past_messages = []
    for user_data, model_data in history:
        # The kept answer still carries the "A: "/"B: " display label added
        # below. Strip it so the label is not replayed to the model as part
        # of its own previous turn.
        if model_data.startswith(("A: ", "B: ")):
            model_data = model_data[3:]

        past_messages.extend(
            [
                {"role": "user", "content": user_data},
                {"role": "assistant", "content": model_data.rstrip()},
            ]
        )

    # With an empty history this is just the single new user turn.
    dialogue_template = DialogueTemplate(
        system=system_message,
        messages=past_messages + [{"role": "user", "content": user_message}],
        end_token="<|endoftext|>",
    )
    prompt = dialogue_template.get_inference_prompt()

    # NOTE(review): the stop sequence "<|end|>" differs from the template's
    # end_token "<|endoftext|>"; confirm which token the model emits.
    response_1 = client.generate_stream(
        prompt, temperature=0.1, stop_sequences=["<|end|>"]
    )

    response_2 = client.generate_stream(
        prompt, temperature=0.9, stop_sequences=["<|end|>"]
    )

    response_1_text = get_full_text(response_1)
    response_2_text = get_full_text(response_2)

    option_a = f"A: {wrap_html_code(response_1_text.strip())}"
    option_b = f"B: {wrap_html_code(response_2_text.strip())}"

    history.append((user_message, option_a, option_b))

    return "", history


# Serializes the pull -> append -> push cycle below. on_select starts a new
# thread for every selection, and concurrent git operations on the same
# working copy would collide.
_SAVE_LOCK = threading.Lock()


def save_labeling_data(last_dialogue, score):
    """Append one preference record to ``data/data.jsonl`` and push it to the Hub.

    Args:
        last_dialogue: ``(prompt, "A: <response_1>", "B: <response_2>")`` row
            as appended by ``generate``.
        score: Stored verbatim as the record's ``score`` field (on_select
            passes the select event's ``value``).

    Does nothing beyond unpacking when no dataset repo is configured
    (``repo is None``).
    """
    (
        prompt,
        response_1,
        response_2,
    ) = last_dialogue

    # Drop the 3-character "A: " / "B: " display labels added by generate.
    response_1 = response_1[3:]
    response_2 = response_2[3:]

    if repo is None:
        return

    data = {
        "labeler_id": LABELER_ID,
        "session_id": SESSION_ID,
        "prompt": prompt,
        "response_1": response_1,
        "response_2": response_2,
        "score": score,
    }

    with _SAVE_LOCK:
        # Rebase onto the latest remote state before appending, then push.
        repo.git_pull(rebase=True)

        with open(os.path.join("data", "data.jsonl"), "a", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
            f.write("\n")

        repo.push_to_hub()


def on_select(event: gr.SelectData, history):
    """Handle the labeler picking one of the two candidates in the last row.

    Starts a background save of the last row together with ``event.value``,
    then deletes the element at ``event.index`` from that row.
    NOTE(review): confirm whether this Gradio build's ``event.index`` refers
    to the chosen or the rejected option.

    Returns:
        The updated ``history`` for the chatbot.
    """
    score = event.value
    index_to_delete = event.index

    # Give the saver thread its own snapshot of the row. The `del` below
    # mutates history[-1] in place, and without a copy the thread could
    # unpack the row after one candidate is already gone (ValueError).
    last_dialogue = tuple(history[-1])
    threading.Thread(target=save_labeling_data, args=(last_dialogue, score)).start()

    del history[-1][index_to_delete]
    return history


with gr.Blocks(css=STYLE) as demo:
    chatbot = gr.Chatbot()
    user_message = gr.Textbox()
    clear = gr.Button("Clear")

    # Submitting generates two candidates, then clears the "done" class left
    # on the previously labeled message. On the first turn no such element
    # exists, so guard against querySelector returning null.
    user_message.submit(
        generate,
        [user_message, chatbot],
        [user_message, chatbot],
        queue=False,
    ).then(
        None,
        None,
        None,
        _js="""()=>{
      let last_elem = document.querySelector("div.message.bot.done");
      if (last_elem) {
        last_elem.classList.remove("done");
      }
}
""",
    )

    # Picking a candidate records the label, then swaps "latest" for "done"
    # on the message, which STYLE's fade-out animation targets.
    chatbot.select(on_select, chatbot, chatbot).then(
        None,
        None,
        None,
        _js="""()=>{
      let last_elem = document.querySelector("div.message.bot.latest");
      if (last_elem) {
        last_elem.classList.remove("latest");
        last_elem.classList.add("done");
      }
}
""",
    )

    # Reset the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()