import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast
import pandas as pd
import comments
from random import randint
import requests


def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Return one sigmoid probability per model output logit for *sentence*.

    The sentence is tokenized into PyTorch tensors (padded and truncated to
    512 tokens, with an attention mask and no token-type ids), passed through
    *model* with gradient tracking disabled, and each logit is squashed
    independently with a sigmoid (no softmax across outputs).

    Args:
        sentence: Text to score.
        tokenizer: Hugging Face-style tokenizer callable returning a mapping
            with ``input_ids`` and ``attention_mask`` tensors.
        model: Callable whose result exposes a ``.logits`` tensor.

    Returns:
        A flat list of Python floats, one per element of the logits tensor.
    """
    encoded = tokenizer(
        sentence,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    token_ids = encoded['input_ids']
    mask = encoded['attention_mask']

    with torch.no_grad():
        logits = model(token_ids, attention_mask=mask).logits
        # Flatten so every logit becomes one entry of the returned list.
        probabilities = torch.sigmoid(logits.flatten())

    return probabilities.numpy().tolist()

def _load_cyberbullying_model():
    """Load the fine-tuned cyberbullying classifier and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``kingsotn/finetuned_cyberbullying``
        sequence-classification model and the ``distilbert-base-uncased`` fast
        tokenizer.
    """
    model = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return model, tokenizer


def perform_cyberbullying_analysis(tweet):
    """Score *tweet* with the cyberbullying model and return a one-row table.

    Args:
        tweet: The text to analyze.

    Returns:
        A ``pandas.DataFrame`` with a ``comment`` column holding *tweet*,
        followed by one column per name in the module-level ``labels[1:]``.
        Each of those columns holds the model's sigmoid probability for the
        corresponding output.
    """
    # Streamlit re-executes the whole script on every interaction, so a plain
    # call would repeat the slow model load on every click. st.cache_resource
    # (Streamlit >= 1.18) keys its cache on the wrapped function's qualified
    # name and source code, not on the wrapper object. Wrapping it here
    # therefore still reuses one loaded model across reruns. On older
    # Streamlit versions without that API, fall back to loading each time.
    loader = _load_cyberbullying_model
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        loader = cache_resource(show_spinner=False)(loader)

    with st.spinner(text="loading model, wait until spinner ends..."):
        model, tokenizer = loader()

    df = pd.DataFrame({'comment': [tweet]})
    list_probs = predict_cyberbullying_probability(tweet, tokenizer, model)
    # NOTE(review): assumes the model emits at least one probability per
    # label in labels[1:] (otherwise IndexError) — confirm against the model.
    for i, label in enumerate(labels[1:]):
        df[label] = list_probs[i]

    return df

def _load_sentiment_pipeline(model_name):
    """Build a PyTorch ``sentiment-analysis`` pipeline for *model_name*.

    Args:
        model_name: Hugging Face model id used for both tokenizer and model.

    Returns:
        The ``transformers`` pipeline object.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")


def perform_default_analysis(model_name):
    """Render a text box and an Analyze button for a generic sentiment model.

    Intended to be called inside an ``st.form``, since it creates a form
    submit button. On submit, the function does four things:

    * classifies the text with *model_name*;
    * shows the raw pipeline output as JSON;
    * shows balloons and a success banner when the top label is positive;
    * otherwise shows an error banner.

    Args:
        model_name: Hugging Face model id selected by the user.
    """
    tweet = st.text_area(label="Enter Text:",value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")

    # Don't load the model on reruns that only render the form; wait for a
    # submission.
    if not submitted:
        return

    # Streamlit re-executes the script on every interaction. st.cache_resource
    # (Streamlit >= 1.18) keys its cache on the wrapped function's qualified
    # name/source and on its arguments. Wrapping it here therefore reuses one
    # pipeline per model_name across reruns. Older Streamlit versions fall
    # back to building it each time.
    loader = _load_sentiment_pipeline
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        loader = cache_resource(show_spinner=False)(loader)

    with st.spinner(text="loading..."):
        clf = loader(model_name)
        out = clf(tweet)

    st.json(out)

    # Accept either spelling of the positive label.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")


# main -->
st.title("Toxic Tweets Analyzer")
st.write("💡 Toxic Tweets Analyzer uses AI with kingsotn/finetuned_cyberbullying (distilbert) to score tweets for toxicity, threat, and insult.")
image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

# Column names of the cyberbullying result table: the input text column
# followed by one column per model output (read by
# perform_cyberbullying_analysis as labels[1:]).
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Seconds to wait for the quote API before giving up. Without a timeout, a
# stalled request would block this rerun indefinitely.
KANYE_API_TIMEOUT = 10

with st.form("my_form"):
    # select model
    model_name = st.selectbox("Enter a text and select a pre-trained model to get the sentiment analysis", ["kingsotn/finetuned_cyberbullying", "distilbert-base-uncased-finetuned-sst-2-english", "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"])

    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:",value=default)
        submitted = st.form_submit_button("Analyze textbox")
        random_clicked = st.form_submit_button("Get a random 😈😈😈 tweet (warning!!)")
        kanye_clicked = st.form_submit_button("Get a ye quote 🐻🎤🎧🎶")

        if random_clicked:
            # Derive the upper bound from the data instead of hard-coding it,
            # so the pick stays in range (and covers every entry) if the
            # comment list changes size.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True

        if kanye_clicked:
            try:
                response = requests.get('https://api.kanye.rest/', timeout=KANYE_API_TIMEOUT)
            except requests.RequestException as err:
                # Connection errors and timeouts: report them instead of
                # crashing the app. The textbox text is analyzed instead.
                st.error("Error getting Kanye quote | " + str(err))
            else:
                if response.status_code == 200:
                    data = response.json()
                    tweet = data['quote']
                else:
                    st.error("Error getting Kanye quote | status code: " + str(response.status_code))
            st.write(tweet)
            submitted = True

        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)