import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random

# The image preprocessor and the classification model are both loaded from the
# same pretrained checkpoint name so their configurations match.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_VIT_CHECKPOINT)


# Map the first space-separated token of each line in LOC_synset_mapping.txt to
# the first comma-separated name after it; e.g. the line
# "n01440764 tench, Tinca tinca" yields the entry "n01440764" -> "tench".
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as mapping_file:
    for line in mapping_file:
        # partition() gives the same key/rest split as split(' ')[0] and
        # ' '.join(split(' ')[1:]), including lines with no space at all.
        synset, _sep, names = line.partition(' ')
        classdict[synset] = names.replace('\n', '').split(',')[0]
# Display names in file order; the quiz code indexes this list with
# int(label) for the labels read from image_labels.csv below.
classes = list(classdict.values())

# image_name -> image_label, both kept as the strings read from the CSV.
imagedict = {}
# newline='' is what the csv module requires for files passed to its readers.
with open('image_labels.csv', 'r', newline='') as csv_file:
    reader = csv.DictReader(csv_file)
    for row in reader:
        imagedict[row['image_name']] = row['image_label']
images = list(imagedict.keys())
# Distinct labels. Sorted so the list order is reproducible across runs (a
# set's iteration order over strings varies with hash randomization).
labels = sorted(set(imagedict.values()))

def model_classify(radio, im):
    """Run the image classifier on the displayed image once an option is picked.

    Args:
        radio: The option currently selected in the radio group, or ``None``
            when nothing is selected.
        im: The currently displayed image, passed unchanged to the feature
            extractor.

    Returns:
        ``(label, class_index, True)`` where ``label`` is the first
        comma-separated name of the model's ``id2label`` entry for the argmax
        class and ``class_index`` is that class index as an int.
        When ``radio`` is ``None`` returns ``(None, None, False)`` instead,
        which clears the "has guessed" flag it is wired to.
    """
    if radio is None:
        return None, None, False

    # Local import: torch is already required by return_tensors="pt"; importing
    # it here keeps the file's top-level imports unchanged.
    import torch

    inputs = feature_extractor(images=im, return_tensors="pt")
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    modelclass = model.config.id2label[predicted_class_idx]
    return modelclass.split(',')[0], predicted_class_idx, True

def random_image():
    """Pick a random quiz image and build its four shuffled answer options.

    Returns:
        ``(image, label, radio_update, None)``, which matches the Gradio
        outputs ``[image, image_label, radio, prediction]``:

        - ``image``: the opened image.
        - ``label``: its ground-truth label, a string index into ``classes``.
        - ``radio_update``: clears the selection and offers the correct class
          name plus three distinct distractor names, in random order.
        - ``None``: clears the model-prediction display.

    Raises:
        ValueError: from ``random.sample`` if fewer than three labels other
            than the chosen one exist.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Draw distractors from a filtered copy. Mutating the shared module-level
    # `labels` list (labels.remove) would permanently shrink the pool on every
    # call, and would raise ValueError the next time an image with an
    # already-removed label was drawn.
    distractors = [other for other in labels if other != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    choices = [classes[int(option)] for option in options]
    return im, label, gr.Radio.update(value=None, choices=choices), None

def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the running quiz score after the user selects an option.

    Only the first answer per image counts: while ``has_guessed`` is truthy
    both counters pass through unchanged. Otherwise:

    - a pick equal to ``classes[int(truth)]`` adds one to both counters;
    - any other non-``None`` pick adds one to the total only;
    - ``None`` (no pick) changes nothing.

    Returns:
        ``(user_score, message, total_score)`` for the Gradio outputs.
    """
    correct = current_score
    attempted = total_score
    if not has_guessed:
        answer = classes[int(truth)]
        if pred == answer:
            correct += 1
            attempted += 1
        elif pred is not None:
            attempted += 1
    return correct, f"Your score is {correct} out of {attempted}!", attempted



def compare_score(userclass, truth):
    """Return the feedback message for the user's pick.

    ``userclass`` is the selected class name (``None`` when nothing is
    selected) and ``truth`` is the ground-truth label, a string index into
    ``classes``.
    """
    if userclass is None:
        return "Try guessing a category!"
    expected = classes[int(truth)]
    if userclass == expected:
        return "Great! You guessed it right"
    return "The right answer was " + str(expected) + "! Try guessing the next image."

# Quiz UI. Per-session values live in gr.State components; the event wiring at
# the bottom connects the radio group and the "Next image" button to the
# functions defined above.
with gr.Blocks() as demo:
    # Number of correct answers (output of check_score).
    user_score = gr.State(0)
    # NOTE(review): never read or written by any handler in this file — looks unused.
    model_score = gr.State(0)
    # Ground-truth label of the current image, a string index into `classes`
    # (output of random_image).
    image_label = gr.State()
    # Class index predicted by the model (output of model_classify).
    # NOTE(review): not read by any handler in this file.
    model_class = gr.State()
    # Number of answered images (output of check_score).
    total_score = gr.State(0)
    # True once the current image has been scored; check_score leaves the
    # scores unchanged while it is True, so only the first answer counts.
    has_guessed = gr.State(False)

    gr.Markdown("# ImageNet Quiz")
    gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
    gr.Markdown("### But many of its categories are hard to guess, even for humans.")
    gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
    with gr.Row():

        with gr.Column(min_width= 900):
            image = gr.Image(shape=(600, 600))
            # Placeholder choices; random_image replaces them with four class
            # names (the correct one plus three distractors) for each image.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="The AI model predicts:")
            score = gr.Label(label="Your Score")
            message = gr.Label(label="Did you guess it right?")

    btn = gr.Button("Next image")

    # Show the first image (and its options) when the page loads.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    # On every change of the selected option:
    # - run the model;
    # - update the score;
    # - show feedback.
    # random_image resets the radio value to None, which presumably also
    # triggers these listeners; confirm for the Gradio version in use.
    # NOTE(review): model_classify writes has_guessed=True while check_score
    # reads has_guessed on the same event. Whether check_score sees the old or
    # the new value depends on how Gradio schedules these listeners — confirm.
    radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    radio.change(compare_score, [radio, image_label], message)
    # Load a new image and clear the "already scored" flag.
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda :False, None, has_guessed)


# Start the Gradio app server.
demo.launch()