import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from datasets import load_dataset
from peft import PeftModel
import os

# UI metadata strings rendered by the gr.Interface at the bottom of this file.
title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on SciQ
"""

# Shown below the interface; links to the project repository.
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"

# Gemma-2b is a gated model: an HF access token must be provided through the
# HF_TOKEN environment variable. Read it once instead of per call.
hf_token = os.environ.get("HF_TOKEN")

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)

# Load the PEFT adapter stored alongside this script ("./") and merge its
# weights into the base model so inference runs on a single plain model.
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()


# Build example inputs for the Gradio UI from the SciQ test split.
# NOTE(review): the original variable was named "random_test_samples", but
# select(range(5)) deterministically takes the first five rows — nothing is
# random here (the `random` import at the top of the file is unused).
dataset = load_dataset("allenai/sciq")
preview_samples = dataset["test"].select(range(5))

# Two examples per row: the context alone (model must infer an answer), and
# the context paired with the gold answer.
examples = []
for row in preview_samples:
    examples.append([row['support'], ""])
    examples.append([row['support'], row['correct_answer']])


def predict(context="", answer=""):
    """Generate a question and three distractors for a context.

    Builds a newline-separated prompt — the flattened context, then the
    flattened answer if one was given — and asks the fine-tuned model to
    complete it. The full decoded output is expected to contain exactly six
    newline-separated fields: context, answer, question, and three
    distractors.

    Args:
        context: Supporting passage the question should be about.
        answer: Optional gold answer; when empty the model picks its own.

    Returns:
        A 6-tuple (context, answer, question, distractor1..3) for the Gradio
        output widgets. If the completion does not split into exactly six
        lines, the first element carries "ERROR: " + the raw decoded output
        and the remaining five are None.
    """
    # Flatten embedded newlines so each prompt field occupies exactly one
    # line; the original code built `formatted` twice when an answer was
    # given, discarding the first result.
    parts = [context.replace('\n', ' ')]
    if answer != "":
        parts.append(answer.replace('\n', ' '))
    formatted = "\n".join(parts) + "\n"

    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    # outputs[0] includes the prompt tokens, so the decoded text starts with
    # the prompt's own lines — hence six fields total, not just the new ones.
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    split_outputs = decoded_outputs.split("\n")

    if len(split_outputs) == 6:
        return tuple(split_outputs)

    return ("ERROR: " + decoded_outputs, None, None, None, None, None)
    

# Input widgets: a multi-line context box and an optional single-line answer.
support_gr = gr.TextArea(label="Context", value="Bananas are yellow and curved.")
answer_gr = gr.Text(label="Answer (optional)", value="yellow")

# Output widgets — one text box per field the model is expected to emit.
(
    context_output_gr,
    answer_output_gr,
    question_output_gr,
    distractor1_output_gr,
    distractor2_output_gr,
    distractor3_output_gr,
) = (
    gr.Text(label=label)
    for label in (
        "Context",
        "Answer",
        "Question",
        "Distractor 1",
        "Distractor 2",
        "Distractor 3",
    )
)

# Wire everything together and serve the demo: the two inputs feed predict(),
# whose six return values populate the six output text boxes.
demo = gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[
        context_output_gr,
        answer_output_gr,
        question_output_gr,
        distractor1_output_gr,
        distractor2_output_gr,
        distractor3_output_gr,
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()