import json
import spaces
import gradio as gr
from huggingface_hub import InferenceClient

# Hugging Face inference client bound to the Mixtral-8x7B-Instruct model used to write quizzes.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

def system_instructions(question_difficulty, tone, topic):
    """Build the Mixtral ``[INST]`` prompt that asks for a 10-question quiz.

    Args:
        question_difficulty: difficulty wording inserted into the prompt
            (the UI offers "easy", "average" and "hard").
        tone: tone wording inserted into the prompt (the UI offers "casual",
            "professional" and "academic").
        topic: free-text quiz subject; interpolated verbatim.

    Returns:
        A prompt string wrapped in ``<s> [INST] ... [/INST]`` that asks for a
        JSON object with question keys ``Q#``, choice keys ``Q#:C1``..``Q#:C4``
        and answer keys ``A#`` whose value names the correct ``Q#:C#`` key.
    """
    return (
        "<s> [INST] You are a great teacher and your task is to create 10 questions "
        f"with 4 choices with a {question_difficulty} difficulty in a {tone} tone "
        f"about {topic}, then create the answers. Index in JSON format, the questions "
        'as "Q#":"" to "Q#":"", the four choices as "Q#:C1":"" to "Q#:C4":"", and the '
        'answers as "A#":"Q#:C#" to "A#":"Q#:C#". '
        # The caller parses the reply as JSON, so ask for nothing but the object.
        "Reply with only the JSON object and no other text. [/INST]"
    )


# Quiz Maker UI: topic / difficulty / tone inputs, a generate button and
# hidden per-question slots that the generate handler fills in.
with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as demo:
    gr.HTML("""
<center>
<h1>Quiz Maker</h1>
<h2>AI-powered Learning Game</h2>
<i>⚠️ Still in development may take a few seconds to generate! ⚠️</i>
</center>
""")

    topic = gr.Textbox(label="Topic", placeholder="Write any topic")

    with gr.Row():
        # Pre-select a value so the prompt never receives ``None`` when the
        # user clicks "Generate Quiz!" without picking an option.
        radio = gr.Radio(
            ["easy", "average", "hard"], value="average", label="How difficult should the quiz be?"
        )

        radio_tone = gr.Radio(
            ["casual", "professional", "academic"], value="casual", label="What tone should the quiz be?"
        )

    generate_quiz_btn = gr.Button("Generate Quiz!🚀")

    # Ten hidden placeholders, one per quiz question, updated by the
    # generate_quiz handler and read back by the score checker.
    question_radios = [gr.Radio(visible=False) for _ in range(10)]

    def _parse_quiz_json(text):
        """Decode the JSON object embedded in the model reply *text*.

        Only the span from the first ``{`` to the last ``}`` is decoded, so any
        prose or Markdown code fence the model may put around the object is
        ignored.

        Raises:
            gr.Error: if *text* contains no decodable JSON object; Gradio shows
                this message to the user instead of a generic failure.
        """
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            raise gr.Error("The model did not return a quiz. Please try again.")
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError as err:
            raise gr.Error("The model returned a malformed quiz. Please try again.") from err

    # NOTE(review): decorators apply bottom-up, so ``generate_quiz_btn.click``
    # registers the undecorated function and @spaces.GPU presumably only wraps
    # the name bound afterwards. The handler itself only makes a remote
    # InferenceClient call; confirm whether the decorator is still needed
    # (e.g. for a ZeroGPU Space to start) before removing or reordering it.
    @spaces.GPU
    @generate_quiz_btn.click(inputs=[radio, radio_tone, topic], outputs=question_radios, api_name="generate_quiz")
    def generate_quiz(question_difficulty, tone, user_prompt):
        """Ask the model for a quiz and turn it into question-radio updates.

        Args:
            question_difficulty: value of the difficulty radio.
            tone: value of the tone radio.
            user_prompt: quiz topic from the topic textbox.

        Returns:
            A list of ten ``gr.Radio`` updates, one per slot in
            ``question_radios``. Each slot is visible with its four choices when
            the model supplied both the question and a resolvable answer, and
            hidden otherwise. Gradio needs exactly one value per output.

        Raises:
            gr.Error: if the model reply contains no decodable JSON object.
        """
        # NOTE(review): quiz_data is module-level state shared by every
        # session, so concurrent users overwrite each other's quiz; a
        # per-session gr.State would avoid that.
        global quiz_data

        formatted_prompt = system_instructions(question_difficulty, tone, user_prompt)

        generate_kwargs = dict(
            temperature=0.1,
            max_new_tokens=2048,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            seed=42,
        )

        response = client.text_generation(
            formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False,
        )

        # Log the raw reply so malformed model output can be diagnosed.
        print(response)

        quiz_data = _parse_quiz_json(response)

        question_radio_list = []
        for question_num in range(1, 11):
            question_key = f"Q{question_num}"
            question = quiz_data.get(question_key)

            # "A#" names the key of the correct choice ("Q#:C#"). JSON object
            # keys are strings, so any other value cannot resolve (and an
            # unhashable one would make dict.get raise TypeError).
            answer_ref = quiz_data.get(f"A{question_num}")
            answer = quiz_data.get(answer_ref) if isinstance(answer_ref, str) else None

            if not question or not answer:
                # Hide the slot instead of dropping it, so the output count and
                # question positions stay aligned with question_radios.
                question_radio_list.append(gr.Radio(visible=False, value=None))
                continue

            choice_list = [
                str(quiz_data.get(f"{question_key}:C{i}", "Choice not found"))
                for i in range(1, 5)
            ]
            # value=None clears any selection left over from a previous quiz.
            question_radio_list.append(
                gr.Radio(choices=choice_list, label=question, value=None,
                         visible=True, interactive=True)
            )

        return question_radio_list

    check_button = gr.Button("Check Score")

    score_textbox = gr.Markdown()

    @check_button.click(inputs=question_radios, outputs=score_textbox)
    def compare_answers(*user_answers):
        """Score the user's selections against the current quiz's answers.

        Args:
            *user_answers: one value per question radio, in the same order as
                ``question_radios`` (slot i holds question ``Q{i+1}``); ``None``
                for questions left unanswered.

        Returns:
            A Markdown string reporting the score over the number of questions
            that have a resolvable answer.
        """
        # quiz_data is a module-level global that only exists once a quiz has
        # been generated; look it up without raising NameError before that.
        current_quiz = globals().get("quiz_data")
        if not current_quiz:
            return "### Generate a quiz first!"

        score = 0
        total = 0
        for question_num, user_answer in enumerate(user_answers, start=1):
            # "A#" names the key of the correct choice ("Q#:C#"); only string
            # values can be dict keys from JSON.
            answer_ref = current_quiz.get(f"A{question_num}")
            correct_answer = current_quiz.get(answer_ref) if isinstance(answer_ref, str) else None
            if not correct_answer:
                continue
            total += 1
            # Compare per question, not by membership in all answers, so a
            # choice is only credited to the question it belongs to. Radio
            # values are strings while JSON answers may be numbers, so compare
            # the answer's string form.
            if user_answer == str(correct_answer):
                score += 1

        return f"### You got {score} over {total}!"

if __name__ == "__main__":
    # Start the Gradio app when run as a script; show_api=False hides the
    # API documentation link in the app footer.
    demo.launch(show_api=False)