import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


class TwitterEmotionClassifier:
    """Classify the emotion of a tweet with Hugging Face text-classification pipelines.

    The checkpoint passed to the constructor is loaded eagerly as the primary
    pipeline. A DeBERTa checkpoint is loaded lazily the first time ``"deberta"``
    is requested through ``get_model``.
    """

    # Hub checkpoint used for the lazily-loaded "deberta" option.
    DEBERTA_MODEL_NAME = "Emanuel/twitter-emotion-deberta-v3-base"

    # Output labels, in the positional order predict() reads the scores in.
    # NOTE(review): this assumes every checkpoint emits its classes in the
    # LABEL_0..LABEL_5 order listed in ``self.emotions``. Confirm against each
    # model's ``config.id2label``.
    DISPLAY_LABELS = (
        "Sadness 😢",
        "Joy 😂",
        "Love 💛",
        "Anger 😠",
        "Fear 😱",
        "Surprise 😮",
    )

    def __init__(self, model_name: str, model_type: str):
        """Load ``model_name`` as the primary pipeline.

        Args:
            model_name: Hub id or local path of the primary checkpoint.
            model_type: Name under which the primary pipeline is selected
                (e.g. ``"bertweet"``). Any requested model name other than
                ``"deberta"`` falls back to this pipeline.
        """
        self.is_gpu = False
        self.model_type = model_type
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        # The primary pipeline keeps its historical attribute name.
        self.bertweet = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=self._pipeline_device(),
        )
        self.deberta = None  # built on first request in get_model()
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def _pipeline_device(self) -> int:
        """Return the ``pipeline`` device index: 0 (first GPU) if ``is_gpu``, else -1 (CPU)."""
        return 0 if self.is_gpu else -1

    def get_model(self, model_type: str):
        """Return the pipeline to use for ``model_type``.

        ``"deberta"`` returns the DeBERTa pipeline, building it on first use.
        Every other value returns the primary pipeline: the primary
        ``self.model_type``, an unknown name, or ``None`` (sent when the UI radio
        is left unselected). Previously those cases returned ``None``, and
        ``predict`` then crashed.
        """
        if model_type == "deberta":
            if self.deberta is None:
                model = AutoModelForSequenceClassification.from_pretrained(
                    self.DEBERTA_MODEL_NAME
                )
                tokenizer = AutoTokenizer.from_pretrained(self.DEBERTA_MODEL_NAME)
                self.deberta = pipeline(
                    "text-classification",
                    model=model,
                    tokenizer=tokenizer,
                    device=self._pipeline_device(),
                )
            return self.deberta
        return self.bertweet

    def predict(self, twitter: str, model_type: str):
        """Score ``twitter`` against every emotion class.

        Args:
            twitter: Tweet text to classify.
            model_type: Model selector forwarded to ``get_model``.

        Returns:
            A dict mapping each entry of ``DISPLAY_LABELS`` to its score, or
            ``None`` if the pipeline returned nothing.
        """
        classifier = self.get_model(model_type)
        # With return_all_scores=True and a single string, the pipeline returns
        # a one-element list that holds one {"label", "score"} dict per class.
        preds = classifier(twitter, return_all_scores=True)
        if not preds:
            return None
        scores = preds[0]
        return {
            label: scores[index]["score"]
            for index, label in enumerate(self.DISPLAY_LABELS)
        }


def main():
    """Build the Gradio emotion-classification demo and launch it (blocking)."""
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.inputs.Textbox(
                placeholder="What's happening?", label="Tweet content", lines=5
            ),
            # Pre-select a model: an unselected radio passes None to predict().
            gr.inputs.Radio(
                ["bertweet", "deberta"], label="Model", default="bertweet"
            ),
        ],
        outputs=gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet"),
        verbose=True,
        examples=[
            ["This GOT show just remember LOTR times!", "bertweet"],
            [
                "Man, can't believe that my 30 days of training just got a NaN loss",
                "bertweet",
            ],
            ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
            [
                "There is nothing better than a soul-warming coffee in the morning",
                "bertweet",
            ],
            ["I fear the vanishing gradient", "deberta"],
        ],
        title="Emotion classification 🤖",
        description="",
        theme="huggingface",
    )
    interface.launch()


if __name__ == "__main__":
    # Launch the demo only when run as a script, not when imported.
    main()