import streamlit as st
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


@st.cache_resource
def pipeline_getter():
    """Build the inference pipeline once and reuse it across Streamlit reruns.

    ``st.cache_resource`` memoizes the return value, so the Hugging Face
    loading and the CSV read are not repeated on every script rerun.

    Returns:
        A ``(tokenizer, model, mapping)`` tuple: the ``distilbert-base-cased``
        tokenizer, the ``KemmerEdition/my-distill-classifier``
        sequence-classification model, and the squeezed array of values
        read from ``./categories.csv`` (indexed by model output position in
        the prediction code).
    """
    checkpoint_tokenizer = 'distilbert-base-cased'
    checkpoint_model = 'KemmerEdition/my-distill-classifier'

    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint_tokenizer)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint_model)
    # NOTE(review): read_csv treats the first row as a header, so that row is
    # not part of the label array — confirm categories.csv has a header row.
    category_frame = pd.read_csv('./categories.csv')
    label_names = category_frame.values.squeeze()

    return text_tokenizer, classifier, label_names


# Module-level handles read by predict_article_categories_with_confidence;
# the resource cache makes every rerun reuse the same objects.
tokenizer, model, mapping = pipeline_getter()


def predict_article_categories_with_confidence(
    text_data,
    abstract_text=None,
    confidence_level=0.95,
    max_categories=9
):
    """Classify a paper and pick the smallest confident set of top categories.

    The title (and the optional abstract, passed as the tokenizer's second
    segment) is run through the module-level ``tokenizer``/``model``. Logits
    are squashed with an element-wise sigmoid, so the per-label
    probabilities are independent and need not sum to 1.

    Labels are ranked by probability and taken greedily until their running
    sum reaches ``confidence_level``, or ``max_categories`` labels have been
    taken, whichever comes first. The selection is also capped at the number
    of labels, and at least one label is always returned.

    Args:
        text_data: Title text (first tokenizer segment).
        abstract_text: Optional abstract (second segment), or ``None``.
        confidence_level: Threshold on the cumulative sum of the sorted
            probabilities.
        max_categories: Integer upper bound on the number of returned
            labels; values below 1 still yield one label.

    Returns:
        dict with keys:
            ``probabilities`` – flattened sigmoid probabilities, label order;
            ``predicted_categories`` – selected label names, best first;
            ``confidence`` – cumulative probability of the selected labels;
            ``top_category`` – the highest-probability label name;
            ``used_categories`` – number of selected labels.
    """
    tokenized_input = tokenizer(
        text=text_data,
        text_pair=abstract_text,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        logits = model(**tokenized_input).logits
    # NOTE(review): flatten() assumes a single input text; a batch would
    # interleave the probabilities of different texts.
    probs = torch.sigmoid(logits).cpu().numpy().flatten()

    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]
    cumulative_probs = np.cumsum(sorted_probs)

    # Index of the first running sum that reaches the threshold, or
    # len(probs) if none does. The cumsum of non-negative values is
    # non-decreasing, so a binary search is valid here.
    reach_count = int(np.searchsorted(cumulative_probs, confidence_level, side='left')) + 1
    # Clamp to the label count. This way an unreachable threshold combined
    # with a large max_categories still selects labels instead of none.
    n_selected = max(1, min(reach_count, max_categories, len(probs)))
    selected_indices = sorted_indices[:n_selected]

    return {
        'probabilities': probs,
        'predicted_categories': [mapping[idx] for idx in selected_indices],
        'confidence': cumulative_probs[n_selected - 1],
        'top_category': mapping[sorted_indices[0]],
        'used_categories': len(selected_indices),
    }


# Page-wide CSS for the custom classes referenced by the raw HTML snippets
# this app renders with ``unsafe_allow_html=True``.
_PAGE_STYLE = """
<style>
    .header {
        font-size: 36px !important;
        color: #1f77b4;
        margin-bottom: 20px;
    }
    .input-box {
        background-color: #f0f2f6;
        padding: 20px;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .result-box {
        background-color: #e6f3ff;
        padding: 20px;
        border-radius: 10px;
        margin-top: 20px;
    }
    .category-badge {
        display: inline-block;
        background-color: #1f77b4;
        color: white;
        padding: 5px 10px;
        margin: 5px;
        border-radius: 15px;
        font-size: 14px;
    }
</style>
"""

# Page title, styled by the ``.header`` rule above.
_PAGE_HEADER = '<div class="header">Classificator of Paper from arxiv</div>'

st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

st.markdown(_PAGE_HEADER, unsafe_allow_html=True)

# Input widgets: the title is required and the abstract is optional.
#
# The widgets are deliberately not wrapped in raw '<div class="input-box">'
# markup. Each st.markdown call renders as its own isolated element, so an
# opening tag in one call cannot enclose widgets created afterwards. It would
# only draw an empty padded box above them.
with st.container():
    title_input = st.text_input('**Here you can write title:**', placeholder="e.g. Quantum Machine Learning Approaches")
    abstract_input = st.text_area(
        '**Here you can write summary from arxiv:**',
        placeholder="Paste the abstract here for more accurate categorization...",
        height=150,
    )

    col1, col2 = st.columns(2)
    with col1:
        # Percentage; divided by 100 before being passed to the model code.
        confidence_level = st.slider('**Confidence level (%)**', 80, 100, 95)
    with col2:
        max_categories = st.slider('**Maximum categories**', 1, 10, 3)

# Run the classifier when the button is pressed and render the results.
if st.button('**Press F (just press)**', type="primary"):
    # strip() so a whitespace-only title counts as missing.
    title_text = title_input.strip()
    if title_text:
        with st.spinner('Analyzing paper content...'):
            result = predict_article_categories_with_confidence(
                title_text,
                # An empty or whitespace-only abstract becomes None, so only
                # the title is encoded.
                abstract_input.strip() or None,
                confidence_level=confidence_level / 100,
                max_categories=max_categories,
            )

        # No raw '<div class="result-box">' wrapper: a tag opened in one
        # st.markdown call cannot enclose later elements. It only renders an
        # empty padded box.
        with st.container():
            st.subheader("Categorization Results")

            top_prob = result["probabilities"].max()
            st.markdown("**Most likely category:**")
            st.markdown(
                f'<div class="category-badge">{result["top_category"]} (p={top_prob:.3f})</div>',
                unsafe_allow_html=True,
            )

            # The first predicted category is the top one shown above.
            extra_categories = result["predicted_categories"][1:]
            if extra_categories:
                st.markdown("Additional categories:")
                # A single markdown call keeps the inline-block badges on
                # one row instead of one element per badge.
                badges = "".join(
                    f'<div class="category-badge">{category}</div>'
                    for category in extra_categories
                )
                st.markdown(badges, unsafe_allow_html=True)

            st.markdown("---")
    else:
        st.warning("Please enter at least the paper title")