import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

id_to_cat = {0: 'High Energy Physics - Theory',
             1: 'Category Theory',
             2: 'Methodology',
             3: 'Formal Languages and Automata Theory',
             4: 'Robotics',
             5: 'Fluid Dynamics',
             6: 'Spectral Theory',
             7: 'Econometrics',
             8: 'Programming Languages',
             9: 'Discrete Mathematics',
             10: 'Networking and Internet Architecture',
             11: 'Quantum Gases',
             12: 'Data Structures and Algorithms',
             13: 'Databases',
             14: 'Earth and Planetary Astrophysics',
             15: 'Optimization and Control',
             16: 'Biomolecules',
             17: 'Cryptography and Security',
             18: 'Geometric Topology',
             19: 'Other Condensed Matter',
             20: 'Statistical Mechanics',
             21: 'Analysis of PDEs',
             22: 'Quantitative Methods',
             23: 'Artificial Intelligence',
             24: 'Classical Analysis and ODEs',
             25: 'Machine Learning',
             26: 'Combinatorics',
             27: 'Pattern Formation and Solitons',
             28: 'Solar and Stellar Astrophysics',
             29: 'Audio and Speech Processing',
             30: 'Computer Science and Game Theory',
             31: 'Mesoscale and Nanoscale Physics',
             32: 'Instrumentation and Methods for Astrophysics',
             33: 'Logic',
             34: 'General Relativity and Quantum Cosmology',
             35: 'Differential Geometry',
             36: 'Graphics',
             37: 'Logic in Computer Science',
             38: 'Materials Science',
             39: 'Computational Finance',
             40: 'General Literature',
             41: 'Tissues and Organs',
             42: 'Digital Libraries',
             43: 'Sound',
             44: 'Computational Engineering, Finance, and Science',
             45: 'Biological Physics',
             46: 'Algebraic Geometry',
             47: 'Genomics',
             48: 'Algebraic Topology',
             49: 'Mathematical Software',
             50: 'Cosmology and Nongalactic Astrophysics',
             51: 'Probability',
             52: 'Data Analysis, Statistics and Probability',
             53: 'Classical Physics',
             54: 'Image and Video Processing',
             55: 'Neural and Evolutionary Computing',
             56: 'History and Philosophy of Physics',
             57: 'Astrophysics of Galaxies',
             58: 'Molecular Networks',
             59: 'Cellular Automata and Lattice Gases',
             60: 'Optics',
             61: 'General Finance',
             62: 'Mathematical Physics',
             63: 'Multimedia',
             64: 'Computational Physics',
             65: 'Performance',
             66: 'History and Overview',
             67: 'Instrumentation and Detectors',
             68: 'Computer Vision and Pattern Recognition',
             69: 'Medical Physics',
             70: 'Quantum Physics',
             71: 'Number Theory',
             72: 'Social and Information Networks',
             73: 'Populations and Evolution',
             74: 'High Energy Physics - Lattice',
             75: 'Pricing of Securities',
             76: 'Nuclear Theory',
             77: 'Human-Computer Interaction',
             78: 'Representation Theory',
             79: 'Geophysics',
             80: 'Operator Algebras',
             81: 'Computational Complexity',
             82: 'Distributed, Parallel, and Cluster Computing',
             83: 'Software Engineering',
             84: 'Computational Geometry',
             85: 'Cell Behavior',
             86: 'Quantum Algebra',
             87: 'Hardware Architecture',
             88: 'Strongly Correlated Electrons',
             89: 'Portfolio Management',
             90: 'General Topology',
             91: 'Statistical Finance',
             92: 'Computation and Language',
             93: 'Atmospheric and Oceanic Physics',
             94: 'Multiagent Systems',
             95: 'Rings and Algebras',
             96: 'Nuclear Experiment',
             97: 'Space Physics',
             98: 'Risk Management',
             99: 'General Mathematics',
             100: 'Other Statistics',
             101: 'Symbolic Computation',
             102: 'High Energy Physics - Phenomenology',
             103: 'Popular Physics',
             104: 'Functional Analysis',
             105: 'Economics',
             106: 'Computation',
             107: 'Operating Systems',
             108: 'Complex Variables',
             109: 'Applications',
             110: 'Information Theory',
             111: 'Physics and Society',
             112: 'Other Computer Science',
             113: 'Metric Geometry',
             114: 'Signal Processing',
             115: 'Information Retrieval',
             116: 'Numerical Analysis',
             117: 'Chemical Physics',
             118: 'Trading and Market Microstructure',
             119: 'Soft Condensed Matter',
             120: 'Computers and Society',
             121: 'General Physics',
             122: 'Superconductivity',
             123: 'Statistics Theory',
             124: 'Emerging Technologies',
             125: 'High Energy Astrophysical Phenomena',
             126: 'Other Quantitative Biology',
             127: 'High Energy Physics - Experiment',
             128: 'Commutative Algebra',
             129: 'Applied Physics',
             130: 'Dynamical Systems',
             131: 'Adaptation and Self-Organizing Systems',
             132: 'Neurons and Cognition',
             133: 'Subcellular Processes',
             134: 'Chaotic Dynamics',
             135: 'Group Theory',
             136: 'Systems and Control',
             137: 'Disordered Systems and Neural Networks'
            }

@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
    model = AutoModelForSequenceClassification.from_pretrained(
        'checkpoint',
        num_labels=len(id_to_cat),
        problem_type="multi_label_classification"
    )
    return model, tokenizer

try:
    model, tokenizer = load_model()
except OSError as e:
    st.error(f"Ошибка при загрузке модели: {e}")
    st.stop()

def classify_text(title, description):
    text = f"{title.strip()} {description.strip()}"
    try:
        classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=len(id_to_cat))
        results = classifier(text)
    except Exception as e:
        st.error(f"Ошибка при классификации текста: {e}")
        return []

    res = [
        (id_to_cat[int(entry['label'].split('_')[1])], entry['score'])
        for entry in results[0]
    ]
    total = sum(score for _, score in res)
    return [(label, score / total) for label, score in res]

st.title("🔬 Классификация англоязычных научных статей")
st.markdown("Введите заголовок и краткое описание научной статьи, чтобы определить её тематические категории.")

title = st.text_input("📝 Заголовок статьи", placeholder="Например: Deep Learning for Image Recognition")
description = st.text_area("🧾 Краткое описание статьи", height=150, placeholder="Кратко опишите содержание статьи...")
top_percent = st.text_input("📊 Порог суммарной вероятности (например, 95 или 0.95 для top 95%)", value="95")

if st.button("🚀 Классифицировать"):
    if not title and not description:
        st.warning("Пожалуйста, введите заголовок или описание статьи.")
    else:
        try:
            t = float(top_percent)
            if t > 1:
                t = t / 100
            if not (0 < t <= 1):
                raise ValueError()
        except ValueError:
            st.warning("Некорректное значение для порога вероятности. Используем значение по умолчанию: 95%.")
            t = 0.95

        with st.spinner("🔍 Классификация..."):
            results = classify_text(title, description)

            if results:
                cumulative_prob = 0.0
                st.subheader(f"📚 Топ категорий (до {int(t*100)}% совокупной вероятности):")
                for label, score in results:
                    st.write(f"- **{label}**: {score*100:.2f}%")
                    cumulative_prob += score
                    if cumulative_prob >= t:
                        break
            else:
                st.info("Не удалось получить результаты классификации.")
elif title or description:
    st.warning("Нажмите кнопку 'Классифицировать', чтобы получить результат.")