# Hugging Face Space (status: Running)
import streamlit as st | |
import torch | |
import torch.serialization | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
@st.cache_resource
def load_model(trained_model='distilbert-base_cls_model.pt',
               base_model='distilbert-base-uncased'):
    """Load the fine-tuned classifier, its tokenizer and its label mapping.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without the cache the checkpoint and
    the base model would be reloaded on each rerun.

    Args:
        trained_model: Path to the training checkpoint. It is read as a dict
            with ``'model_state_dict'``, ``'tokenizer'`` and
            ``'idx_to_category'`` entries.
        base_model: Hugging Face model id whose architecture matches the
            checkpoint. An alternative configuration used
            ``'TinyBERT_cls_model.pt'`` with
            ``'huawei-noah/TinyBERT_General_4L_312D'``.

    Returns:
        ``(model, tokenizer, idx_to_category)``, with the model in eval mode.
    """
    # weights_only=False unpickles arbitrary Python objects (the checkpoint
    # stores a tokenizer object), so only load checkpoints from a trusted source.
    checkpoint = torch.load(trained_model,
                            map_location='cpu',
                            weights_only=False)
    idx_to_category = checkpoint['idx_to_category']
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(idx_to_category),
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Make sure the model is in inference mode (dropout disabled).
    model.eval()
    tokenizer = checkpoint['tokenizer']
    return model, tokenizer, idx_to_category
def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
    """Classify an article and return its most likely categories.

    The title (plus the abstract, when a non-blank one is given) is run
    through the model once. Categories are then taken in order of decreasing
    probability until their cumulative probability reaches ``threshold``.

    Args:
        title: Article title.
        abstract: Article abstract. ``None``, empty or whitespace-only means
            "classify by title alone".
        model: Classifier called as ``model(**inputs)``; its output must
            expose ``.logits``.
        tokenizer: Callable turning text into model inputs (called with
            ``return_tensors="pt"``, truncation to 512 tokens).
        idx_to_category: Mapping from integer class index to category name.
        threshold: Target cumulative probability of the returned categories.

    Returns:
        ``(results, cumulative_prob)``:
        ``results`` is a list of ``{"category": ..., "probability": float}``
        dicts in decreasing probability order, and ``cumulative_prob`` is
        the sum of their probabilities.
    """
    # Join title and abstract with a real newline; the literal "/n" used
    # previously injected stray "/" and "n" tokens into the input.
    # NOTE(review): if training preprocessing used the same "/n" separator,
    # restore it so inference input matches training input.
    if abstract and abstract.strip():
        text = f"{title}\n{abstract}"
    else:
        text = title
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Single text in the batch, so take row 0 of the class distribution.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    results = []
    cumulative_prob = 0.0
    # Convert to Python scalars once instead of calling .item() per element.
    for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        if cumulative_prob >= threshold:
            break
        results.append({
            "category": idx_to_category[idx],
            "probability": prob,
        })
        cumulative_prob += prob
    return results, cumulative_prob
def main():
    """Render the Streamlit article-classifier page.

    Shows a form with a required title and an optional abstract. On submit,
    runs ``predict`` and displays the overall probability, then each returned
    category with its own probability.
    """
    model, tokenizer, idx_to_category = load_model()

    st.title("Классификатор статей")
    st.markdown("Определение тематики научных статей по названию и аннотации")

    with st.form("input_form"):
        title = st.text_input("Название статьи*", placeholder="Введите название...")
        abstract = st.text_area(
            "Аннотация",
            placeholder="Введите текст аннотации (необязательно)...",
            height=150,
        )
        submitted = st.form_submit_button("Классифицировать")

    if not submitted:
        return
    # Treat a whitespace-only title as missing so blank input never reaches
    # the model.
    if not (title and title.strip()):
        st.error("Пожалуйста, введите название статьи")
        return

    with st.spinner("Анализируем статью..."):
        results, total_prob = predict(
            title=title.strip(),
            abstract=abstract,
            model=model,
            tokenizer=tokenizer,
            idx_to_category=idx_to_category,
        )

    st.success("Результаты классификации:")
    st.metric("Общая вероятность", f"{total_prob*100:.1f}%")
    for rank, res in enumerate(results, 1):
        col1, col2 = st.columns([1, 4])
        with col1:
            st.metric(f"Топ {rank}", f"{res['probability']*100:.1f}%")
        with col2:
            st.progress(res['probability'], text=res['category'])
# Script entry point: build and render the app.
if __name__ == "__main__":
    main()