Akeb0n0 committed
Commit a7e2485 · verified · 1 Parent(s): c685cad

Update app.py

Files changed (1)
  1. app.py +98 -2
app.py CHANGED
@@ -1,4 +1,100 @@
  import streamlit as st
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from PIL import Image
+ import requests
+ from io import BytesIO
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+
+ @st.cache_data
+ def load_header_image():
+     response = requests.get(
+         "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bc/ArXiv_logo_2022.svg/512px-ArXiv_logo_2022.svg.png"
+     )
+     return Image.open(BytesIO(response.content))
+
+
+ @st.cache_resource
+ def load_model():
+     checkpoint = torch.load('TinyBERT_cls_model.pt', map_location='cpu')
+
+     model = AutoModelForSequenceClassification.from_pretrained(
+         "huawei-noah/TinyBERT_General_4L_312D",
+         num_labels=len(checkpoint['idx_to_category'])
+     )
+     model.load_state_dict(checkpoint['model_state_dict'])
+
+     tokenizer = checkpoint['tokenizer']
+     idx_to_category = checkpoint['idx_to_category']
+
+     return model, tokenizer, idx_to_category
+
+
+ def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
+     text = f"{title} \n {abstract}" if abstract else title
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+
+     results = []
+     cumulative_prob = 0.0
+
+     for i in range(len(sorted_probs)):
+         if cumulative_prob >= threshold:
+             break
+         prob = sorted_probs[i].item()
+         results.append({
+             "category": idx_to_category[sorted_indices[i].item()],
+             "probability": prob
+         })
+         cumulative_prob += prob
+
+     return results, cumulative_prob
+
+
+ def main():
+     model, tokenizer, idx_to_category = load_model()
+     header_img = load_header_image()
+
+     st.set_page_config(page_title="arXiv Classifier", layout="wide")
+
+     col1, col2 = st.columns([1, 4])
+     with col1:
+         st.image(header_img, width=100)
+     with col2:
+         st.title("arXiv Article Classifier")
+         st.markdown("Определение тематики научных статей по названию и аннотации")  # "Classify research papers by title and abstract"
+
+     with st.form("input_form"):
+         title = st.text_input("Название статьи*", placeholder="Введите название...")  # "Paper title*" / "Enter a title..."
+         abstract = st.text_area("Аннотация", placeholder="Введите текст аннотации (необязательно)...", height=150)  # "Abstract" / "Enter the abstract text (optional)..."
+         submitted = st.form_submit_button("Классифицировать")  # "Classify"
+
+     if submitted and not title:
+         st.error("Пожалуйста, введите название статьи")  # "Please enter the paper title"
+
+     if submitted and title:
+         with st.spinner("Анализируем статью..."):  # "Analyzing the paper..."
+             results, total_prob = predict(
+                 title=title,
+                 abstract=abstract,
+                 model=model,
+                 tokenizer=tokenizer,
+                 idx_to_category=idx_to_category
+             )
+
+         st.success("Результаты классификации:")  # "Classification results:"
+         st.metric("Общая вероятность", f"{total_prob*100:.1f}%")  # "Total probability"
+         for i, res in enumerate(results, 1):
+             col1, col2 = st.columns([1, 4])
+             with col1:
+                 st.metric(f"Топ {i}", f"{res['probability']*100:.1f}%")  # "Top {i}"
+             with col2:
+                 st.progress(res['probability'], text=res['category'])
+
+
+ if __name__ == "__main__":
+     main()
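
Note: load_model() above expects TinyBERT_cls_model.pt to be a checkpoint dict with the keys model_state_dict, tokenizer, and idx_to_category. A minimal sketch of how a compatible checkpoint could be produced is shown below; the label map and the fine-tuning step are illustrative assumptions and are not part of this commit.

    # Hypothetical sketch: produce a checkpoint with the keys that load_model() reads.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    base = "huawei-noah/TinyBERT_General_4L_312D"
    idx_to_category = {0: "cs.LG", 1: "cs.CL", 2: "stat.ML"}  # placeholder label map

    model = AutoModelForSequenceClassification.from_pretrained(
        base, num_labels=len(idx_to_category)
    )
    tokenizer = AutoTokenizer.from_pretrained(base)

    # ... fine-tune `model` on labelled arXiv titles/abstracts here ...

    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "tokenizer": tokenizer,
            "idx_to_category": idx_to_category,
        },
        "TinyBERT_cls_model.pt",
    )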