Akeb0n0's picture
Update app.py
0ae5cd0 verified
import streamlit as st
import torch
import torch.serialization
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@st.cache_resource
def load_model(trained_model="distilbert-base_cls_model.pt",
               base_model="distilbert-base-uncased"):
    """Load the fine-tuned classifier, its tokenizer and the label mapping.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across script reruns instead of re-reading the checkpoint.

    Args:
        trained_model: Path to a checkpoint produced by ``torch.save``. It must
            contain the keys ``"model_state_dict"``, ``"tokenizer"`` and
            ``"idx_to_category"``.
        base_model: Hugging Face model id whose architecture matches the
            checkpoint. A TinyBERT variant was used previously, e.g.
            ``load_model("TinyBERT_cls_model.pt",
            "huawei-noah/TinyBERT_General_4L_312D")``.

    Returns:
        Tuple ``(model, tokenizer, idx_to_category)``.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects
    # (presumably required because a tokenizer object is stored in the
    # checkpoint). Only load checkpoints from a trusted source.
    checkpoint = torch.load(trained_model, map_location="cpu", weights_only=False)
    idx_to_category = checkpoint["idx_to_category"]

    # The classification head must have one output per known category so the
    # saved state dict fits.
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(idx_to_category),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # Inference only: make sure dropout and similar layers are disabled.
    model.eval()

    return model, checkpoint["tokenizer"], idx_to_category
def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95,
            separator="\n"):
    """Classify an article and return its most probable categories.

    Categories are taken in descending order of probability until their
    cumulative probability reaches ``threshold`` (top-p style selection).

    Args:
        title: Article title (required).
        abstract: Optional abstract. ``None``, empty or whitespace-only values
            are ignored and only the title is classified.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Callable tokenizer returning a mapping of model inputs.
        idx_to_category: Mapping (or sequence) from class index to category name.
        threshold: Cumulative-probability cut-off; selection stops once the
            accumulated probability is >= this value.
        separator: Text placed between title and abstract. Defaults to a real
            newline. Earlier code used the literal ``" /n "`` (a typo for
            ``"\\n"``); pass that explicitly if a checkpoint was trained with it.

    Returns:
        ``(results, cumulative_prob)``. ``results`` is a list of
        ``{"category": ..., "probability": float}`` dicts sorted by descending
        probability. ``cumulative_prob`` is the sum of their probabilities.
    """
    has_abstract = bool(abstract) and bool(abstract.strip())
    text = f"{title}{separator}{abstract}" if has_abstract else title

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # [0]: a single text is classified, so take the first row of the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    results = []
    cumulative_prob = 0.0
    for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        if cumulative_prob >= threshold:
            break
        results.append({"category": idx_to_category[idx], "probability": prob})
        cumulative_prob += prob
    return results, cumulative_prob
def main():
    """Render the Streamlit page that classifies an article by title/abstract.

    The title is required and the abstract is optional. On submission, the
    cached model from ``load_model`` is run through ``predict``. The summed
    probability is shown first, then each returned category with its
    probability.
    """
    model, tokenizer, idx_to_category = load_model()

    st.title("Классификатор статей")
    st.markdown("Определение тематики научных статей по названию и аннотации")

    with st.form("input_form"):
        title = st.text_input("Название статьи*", placeholder="Введите название...")
        abstract = st.text_area(
            "Аннотация",
            placeholder="Введите текст аннотации (необязательно)...",
            height=150,
        )
        submitted = st.form_submit_button("Классифицировать")

    if not submitted:
        return

    # A whitespace-only title counts as missing. Reject it instead of running
    # the model on an effectively empty input.
    title = title.strip()
    if not title:
        st.error("Пожалуйста, введите название статьи")
        return

    with st.spinner("Анализируем статью..."):
        results, total_prob = predict(
            title=title,
            abstract=abstract,
            model=model,
            tokenizer=tokenizer,
            idx_to_category=idx_to_category,
        )

    st.success("Результаты классификации:")
    # Sum of the probabilities of the categories listed below.
    st.metric("Общая вероятность", f"{total_prob*100:.1f}%")
    for i, res in enumerate(results, 1):
        col1, col2 = st.columns([1, 4])
        with col1:
            st.metric(f"Топ {i}", f"{res['probability']*100:.1f}%")
        with col2:
            st.progress(res['probability'], text=res['category'])
# Script entry point: build the page when this module is run as the main program.
if __name__ == "__main__":
    main()