# app.py - Streamlit arXiv article classifier.
# Repository metadata: last commit "Update app.py" (a7e2485, verified) by Akeb0n0; file size 3.56 kB.
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from PIL import Image
import requests
from io import BytesIO
@st.cache_data
def load_header_image():
    """Download the arXiv logo and return it as a PIL image.

    The function is wrapped in ``st.cache_data``, so the download is not
    repeated on every script rerun.

    Returns:
        The decoded logo as a ``PIL.Image.Image``.

    Raises:
        requests.RequestException: If the request times out, the connection
            fails, or the server answers with an HTTP error status.
    """
    response = requests.get(
        "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bc/ArXiv_logo_2022.svg/512px-ArXiv_logo_2022.svg.png",
        # Without a timeout a stalled connection would hang the whole app.
        timeout=10,
    )
    # Surface HTTP failures as a clear error instead of handing an error page
    # to PIL, which would fail with an unrelated "cannot identify image" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
@st.cache_resource
def load_model():
    """Load the fine-tuned TinyBERT classifier, its tokenizer and label map.

    Reads the local checkpoint ``TinyBERT_cls_model.pt``, builds a
    sequence-classification model on top of the
    ``huawei-noah/TinyBERT_General_4L_312D`` base with one output per entry in
    the checkpoint's ``idx_to_category`` mapping, and restores the saved
    weights. Wrapped in ``st.cache_resource`` so the model is not rebuilt on
    every script rerun.

    Returns:
        A ``(model, tokenizer, idx_to_category)`` tuple, where ``tokenizer``
        and ``idx_to_category`` are taken as-is from the checkpoint.
    """
    # The checkpoint holds more than tensors (a tokenizer object is read from
    # it below), so it needs full unpickling. torch >= 2.6 defaults to
    # weights_only=True, which would reject it; be explicit so the behaviour
    # is the same on every torch version.
    # NOTE(security): full unpickling can execute arbitrary code - only ever
    # load a checkpoint file that comes from a trusted source.
    checkpoint = torch.load(
        'TinyBERT_cls_model.pt',
        map_location='cpu',
        weights_only=False,
    )
    idx_to_category = checkpoint['idx_to_category']

    model = AutoModelForSequenceClassification.from_pretrained(
        "huawei-noah/TinyBERT_General_4L_312D",
        num_labels=len(idx_to_category),
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # The model is only used for inference; make evaluation mode explicit.
    model.eval()

    tokenizer = checkpoint['tokenizer']
    return model, tokenizer, idx_to_category
def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
    """Return the top categories whose cumulative probability reaches ``threshold``.

    The title (and the abstract, when non-empty) is tokenized and passed
    through the model; the logits of the first batch item are turned into
    probabilities with softmax. Categories are then taken in order of
    decreasing probability until the running sum is at least ``threshold``.

    Args:
        title: Article title.
        abstract: Article abstract; when falsy, only the title is used.
        model: Callable invoked as ``model(**inputs)`` whose result exposes
            a ``logits`` tensor.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True,
            max_length=512)`` returning a mapping of model inputs.
        idx_to_category: Mapping from integer class index to category name.
        threshold: Cumulative probability at which to stop adding categories.

    Returns:
        A ``(results, cumulative_prob)`` tuple. ``results`` is a list of
        ``{"category": ..., "probability": ...}`` dicts ordered by decreasing
        probability; ``cumulative_prob`` is the sum of their probabilities.
    """
    # NOTE(review): " /n " looks like a typo for a newline separator. It is
    # kept unchanged because the separator presumably has to match the text
    # format used when the model was trained - confirm before changing.
    text = f"{title} /n {abstract}" if abstract else title
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    results = []
    cumulative_prob = 0.0
    # Convert both tensors to plain Python lists once instead of calling
    # .item() on every element inside an index loop.
    for prob, index in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        if cumulative_prob >= threshold:
            break
        results.append({
            "category": idx_to_category[index],
            "probability": prob,
        })
        cumulative_prob += prob
    return results, cumulative_prob
def main():
    """Render the Streamlit page: header, input form and classification results.

    On submit with a non-empty title, the title and optional abstract are
    passed to ``predict`` and the returned categories are shown with their
    probabilities. Submitting without a title shows an error message.
    """
    # Page configuration comes first: older Streamlit releases require
    # set_page_config to be the first Streamlit command of the script, and the
    # cached loaders below may emit UI (their spinner) on a cache miss.
    st.set_page_config(page_title="arXiv Classifier", layout="wide")

    model, tokenizer, idx_to_category = load_model()
    header_img = load_header_image()

    logo_col, heading_col = st.columns([1, 4])
    with logo_col:
        st.image(header_img, width=100)
    with heading_col:
        st.title("arXiv Article Classifier")
        st.markdown("Определение тематики научных статей по названию и аннотации")

    with st.form("input_form"):
        title = st.text_input("Название статьи*", placeholder="Введите название...")
        abstract = st.text_area(
            "Аннотация",
            placeholder="Введите текст аннотации (необязательно)...",
            height=150,
        )
        submitted = st.form_submit_button("Классифицировать")

    if not submitted:
        return

    # The title is the only mandatory field.
    if not title:
        st.error("Пожалуйста, введите название статьи")
        return

    with st.spinner("Анализируем статью..."):
        results, total_prob = predict(
            title=title,
            abstract=abstract,
            model=model,
            tokenizer=tokenizer,
            idx_to_category=idx_to_category,
        )

    st.success("Результаты классификации:")
    st.metric("Общая вероятность", f"{total_prob*100:.1f}%")

    # One row per category: rank and percentage on the left, bar on the right.
    for rank, res in enumerate(results, 1):
        rank_col, bar_col = st.columns([1, 4])
        with rank_col:
            st.metric(f"Топ {rank}", f"{res['probability']*100:.1f}%")
        with bar_col:
            st.progress(res['probability'], text=res['category'])
# Build the page only when this file is executed as the main script
# (not when it is imported as a module).
if __name__ == "__main__":
    main()