import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


@st.cache_resource
def load_model():
    """Load the fine-tuned classifier checkpoint once and cache it across reruns."""
    # trained_model = 'TinyBERT_cls_model.pt'
    # base_model = 'huawei-noah/TinyBERT_General_4L_312D'
    trained_model = 'distilbert-base_cls_model.pt'
    base_model = 'distilbert-base-uncased'

    # weights_only=False because the checkpoint also stores the tokenizer and label mapping
    checkpoint = torch.load(trained_model, map_location='cpu', weights_only=False)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(checkpoint['idx_to_category'])
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    tokenizer = checkpoint['tokenizer']
    idx_to_category = checkpoint['idx_to_category']

    return model, tokenizer, idx_to_category


def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
    """Return the smallest set of top categories whose cumulative probability reaches `threshold`."""
    text = f"{title}\n{abstract}" if abstract else title
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    results = []
    cumulative_prob = 0.0
    for i in range(len(sorted_probs)):
        if cumulative_prob >= threshold:
            break
        prob = sorted_probs[i].item()
        results.append({
            "category": idx_to_category[sorted_indices[i].item()],
            "probability": prob
        })
        cumulative_prob += prob

    return results, cumulative_prob


def main():
    model, tokenizer, idx_to_category = load_model()

    st.title("Article classifier")
    st.markdown("Identify the topic of a scientific article from its title and abstract")

    with st.form("input_form"):
        title = st.text_input("Article title*", placeholder="Enter the title...")
        abstract = st.text_area(
            "Abstract",
            placeholder="Enter the abstract text (optional)...",
            height=150
        )
        submitted = st.form_submit_button("Classify")

    if submitted and not title:
        st.error("Please enter the article title")

    if submitted and title:
        with st.spinner("Analyzing the article..."):
            results, total_prob = predict(
                title=title,
                abstract=abstract,
                model=model,
                tokenizer=tokenizer,
                idx_to_category=idx_to_category
            )

        st.success("Classification results:")
        st.metric("Total probability", f"{total_prob*100:.1f}%")

        for i, res in enumerate(results, 1):
            col1, col2 = st.columns([1, 4])
            with col1:
                st.metric(f"Top {i}", f"{res['probability']*100:.1f}%")
            with col2:
                st.progress(res['probability'], text=res['category'])


if __name__ == "__main__":
    main()
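
# Usage note (assumes this script is saved as app.py -- the filename is hypothetical)
# and that the checkpoint 'distilbert-base_cls_model.pt' sits next to it:
#
#   streamlit run app.py
#
# Streamlit then serves the classifier UI locally (by default at http://localhost:8501).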