Upload 6 files
Browse files- app.py +16 -0
- bert.py +37 -0
- inference_time.png +0 -0
- model_comparison.csv +4 -0
- perv.py +294 -0
- training_metrics.png +0 -0
app.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import bert
|
3 |
+
import perv
|
4 |
+
|
5 |
+
# Entry point of the combined Streamlit application: configures the page,
# shows a sidebar menu and hands rendering over to the selected module.
st.set_page_config(page_title="Объединённое NLP-приложение", layout="wide")

st.sidebar.title("Меню")
choice = st.sidebar.radio("Выберите модуль:", [
    "Оценка токсичности",
    "Анализ отзывов",
])

if choice == "Оценка токсичности":
    bert.run()
elif choice == "Анализ отзывов":
    # The review-analysis page is provided by the imported ``perv`` module;
    # no name ``analysis`` exists in this file, so calling it raised NameError.
    perv.run()
|
bert.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
# Directory (or hub id) of the fine-tuned toxicity model.
MODEL_PATH = "rubert-finetuned"

# Loaded once at import time; Python's module cache keeps these objects alive
# across Streamlit reruns of the importing script.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()


def run():
    """Render the toxicity-scoring page.

    Draws a text area and a button; when the button is pressed and the text
    is not blank, the text is scored by the model and the result is shown
    together with a severity banner.

    The page configuration is intentionally NOT set here so that a host
    application which already called ``st.set_page_config`` can embed
    this page.
    """
    st.title("💬 Оценка токсичности текста")

    text = st.text_area("Введите сообщение", "Ты ужасный человек!")
    submit = st.button("Проверить токсичность")

    # Nothing to score until the user submits some non-blank text.
    if not (submit and text.strip()):
        return

    # Tokenisation
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    # Prediction
    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): ``.item()`` requires a single-element tensor, i.e. this
    # assumes the model has exactly one output logit — confirm against the
    # fine-tuned checkpoint.
    score = torch.sigmoid(logits).item()  # degree of toxicity

    # Output
    st.subheader("Результат:")
    st.write(f"**Степень токсичности:** `{score:.3f}`")

    if score > 0.8:
        st.error("⚠️ Высокая токсичность!")
    elif score > 0.4:
        st.warning("⚠️ Средняя токсичность")
    else:
        st.success("✅ Низкая токсичность")


if __name__ == "__main__":
    # Standalone launch (``streamlit run bert.py``): configure the page here,
    # before any other Streamlit command, then render.
    st.set_page_config(page_title="Оценка токсичности", layout="centered")
    run()
|
inference_time.png
ADDED
![]() |
model_comparison.csv
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,f1_macro,accuracy,training_time,inference_time
|
2 |
+
BERT,0.8967324592773303,0.9004249291784703,682.2076306343079,39.23531460762024
|
3 |
+
Classical ML,0.883260602673595,0.886543909348442,1.1419353485107422,0.15742778778076172
|
4 |
+
LSTM,0.873906501409314,0.8779036827195468,220.66053867340088,14.189206600189209
|
perv.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
from pathlib import Path
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
import pickle
|
9 |
+
from transformers import AutoTokenizer, BertForSequenceClassification
|
10 |
+
from sklearn.pipeline import Pipeline
|
11 |
+
from sklearn.preprocessing import LabelEncoder
|
12 |
+
from sklearn.metrics import f1_score, accuracy_score
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
15 |
+
import json
|
16 |
+
from torch.serialization import safe_globals
|
17 |
+
from sklearn.preprocessing import LabelEncoder
|
18 |
+
|
19 |
+
def preprocess_text(text):
    """Return *text* lower-cased, with line breaks replaced by spaces and
    surrounding whitespace stripped; non-string input yields ``""``."""
    if not isinstance(text, str):
        return ""
    return text.lower().replace('\n', ' ').replace('\r', ' ').strip()


class ClassicalML:
    """Thin wrapper around a fitted pipeline and its label encoder."""

    def __init__(self):
        self.pipeline = None
        self.label_encoder = None

    def predict(self, X):
        """Return ``(decoded_labels, elapsed_seconds)`` for the inputs *X*."""
        start_time = time.time()
        preds = self.pipeline.predict(X)
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


class Attention(nn.Module):
    """Single-linear-layer attention pooling over a sequence."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)  # plain linear scoring layer

    def forward(self, lstm_output):
        # lstm_output shape: [batch_size, seq_len, hidden_dim]
        attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
        # Weighted sum over the sequence axis -> [batch_size, hidden_dim]
        context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
        return context


class LSTMModel(nn.Module):
    """Embedding -> LSTM -> attention pooling -> dropout -> linear classifier."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, text, lengths):
        """Return class logits for a batch of token-id sequences.

        ``text`` holds token ids (batch first); ``lengths`` holds the true
        length of every sequence in the batch.
        """
        embedded = self.embedding(text)
        packed = pack_padded_sequence(
            embedded,
            lengths.cpu(),  # pack_padded_sequence needs the lengths on CPU
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        context = self.attention(output)
        return self.fc(self.dropout(context))


def _load_lstm_checkpoint(checkpoint_path):
    """Load the LSTM checkpoint onto the CPU, preferring the restricted unpickler."""
    cpu = torch.device('cpu')
    try:
        # The allowlist only has an effect together with weights_only=True.
        with safe_globals([LabelEncoder]):
            return torch.load(checkpoint_path, map_location=cpu, weights_only=True)
    except pickle.UnpicklingError:
        # SECURITY: full unpickling can execute arbitrary code. This fallback
        # must only ever be used with checkpoint files from a trusted source.
        return torch.load(checkpoint_path, map_location=cpu, weights_only=False)


def _adapt_state_dict_keys(state_dict, expected_keys):
    """Rename attention keys only where that makes them match the model.

    Handles checkpoints whose attention scorer was saved either as a bare
    layer (``attention.attention.weight``) or as the first element of a
    ``Sequential`` (``attention.attention.0.weight``). Keys that already
    match, or that match under neither spelling, are passed through as-is.
    """
    expected = set(expected_keys)
    adapted = {}
    for key, value in state_dict.items():
        if key not in expected:
            candidates = (
                key.replace('attention.attention.', 'attention.attention.0.', 1),
                key.replace('attention.attention.0.', 'attention.attention.', 1),
            )
            key = next((c for c in candidates if c in expected), key)
        adapted[key] = value
    return adapted


class LSTMTrainer:
    """Inference wrapper holding the LSTM model, vocabulary and label encoder."""

    def __init__(self):
        self.model = None
        self.vocab = None
        self.label_encoder = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self, X):
        """Return ``(decoded_labels, elapsed_seconds)`` for the texts in *X*."""
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                tokens = preprocess_text(text).split()
                # Tokens missing from the vocabulary map to index 0.
                seq = [self.vocab.get(token, 0) for token in tokens]
                if not seq:
                    # Packing rejects zero-length sequences, so feed one
                    # placeholder token instead.
                    seq = [0]
                text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
                length_tensor = torch.tensor([len(seq)], dtype=torch.long)
                output = self.model(text_tensor, length_tensor)
                preds.append(torch.argmax(output).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time

    @classmethod
    def load(cls, path='models/lstm', *, strict=False):
        """Build a ready-to-use trainer from ``<path>/model.pt``.

        The checkpoint is expected to contain the keys ``vocab``,
        ``label_encoder``, ``embed_dim``, ``hidden_dim`` and
        ``model_state_dict``.

        Args:
            path: Directory containing ``model.pt``.
            strict: Forwarded to ``load_state_dict``. With the default
                ``False`` mismatching keys do not raise, but are reported
                through a warning instead of being dropped silently.
        """
        checkpoint = _load_lstm_checkpoint(Path(path) / 'model.pt')

        trainer = cls()
        trainer.vocab = checkpoint['vocab']
        trainer.label_encoder = checkpoint['label_encoder']

        # Model initialisation
        trainer.model = LSTMModel(
            len(trainer.vocab),
            checkpoint['embed_dim'],
            checkpoint['hidden_dim'],
            len(trainer.label_encoder.classes_),
        ).to(trainer.device)

        # State-dict adaptation: rename keys only when the rename is what
        # makes them line up with the freshly built model.
        state_dict = _adapt_state_dict_keys(
            checkpoint['model_state_dict'],
            trainer.model.state_dict().keys(),
        )
        result = trainer.model.load_state_dict(state_dict, strict=strict)
        if result.missing_keys or result.unexpected_keys:
            import warnings

            warnings.warn(
                "LSTM checkpoint does not fully match the model: "
                f"missing={list(result.missing_keys)}, "
                f"unexpected={list(result.unexpected_keys)}"
            )
        return trainer


class BERTClassifier:
    """Inference wrapper holding the BERT model, tokenizer and label encoder."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        self.label_encoder = None

    def predict(self, X):
        """Return ``(decoded_labels, elapsed_seconds)`` for the texts in *X*."""
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                inputs = self.tokenizer(
                    text,
                    padding=True,
                    truncation=True,
                    max_length=128,
                    return_tensors="pt",
                ).to(self.device)  # inputs must live on the model's device
                outputs = self.model(**inputs)
                preds.append(torch.argmax(outputs.logits).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


def plot_attention(text, model, tokenizer):
    """Draw a heatmap of the last layer's attention, averaged over heads."""
    # Run on whatever device the model currently lives on.
    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
    with torch.no_grad():  # visualisation only, no gradients needed
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[-1].squeeze(0).mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        attention.detach().cpu().numpy(),
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="YlGnBu",
        ax=ax,
    )
    ax.set_title("Attention Scores")
    st.pyplot(fig)
    plt.close(fig)  # figures would otherwise pile up across reruns


@st.cache_resource
def load_models():
    """Load all three classifiers once and return ``(classical_ml, lstm, bert)``."""
    # Classical ML
    # SECURITY: pickle.load can execute arbitrary code; these files must come
    # from a trusted source.
    classical_ml = ClassicalML()
    with open('models/classical_ml/pipeline.pkl', 'rb') as f:
        classical_ml.pipeline = pickle.load(f)
    with open('models/classical_ml/label_encoder.pkl', 'rb') as f:
        classical_ml.label_encoder = pickle.load(f)

    # LSTM — strict so that a mismatching checkpoint fails loudly here.
    lstm = LSTMTrainer.load('models/lstm', strict=True)

    # BERT
    bert = BERTClassifier()
    bert.tokenizer = AutoTokenizer.from_pretrained('models/bert')
    bert.model = BertForSequenceClassification.from_pretrained('models/bert')
    bert.model.to(bert.device)  # move the model to its device after loading
    with open('models/bert/label_encoder.pkl', 'rb') as f:
        bert.label_encoder = pickle.load(f)  # trusted file only, see above

    return classical_ml, lstm, bert


def main():
    """Render the review-analysis page."""
    st.title("Анализ отзывов медицинских учреждений")

    # Model loading (cached by Streamlit between reruns)
    classical_ml, lstm, bert = load_models()

    # Placeholder metrics — replace with the real values from training.
    metrics = {
        'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
        'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
        'BERT': {'f1_macro': 0.92, 'inference_time': 0.05},
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

    # Text input
    user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")

    if st.button("Проанализировать отзыв") and user_input.strip():
        # A category prefix is prepended for compatibility with the models.
        input_with_category = f"Поликлиники стоматологические {user_input}"

        with st.spinner('Обработка...'):
            # Collect predictions
            ml_pred, ml_time = classical_ml.predict([input_with_category])
            lstm_pred, lstm_time = lstm.predict([input_with_category])
            bert_pred, bert_time = bert.predict([input_with_category])

        # Results in three columns
        col1, col2, col3 = st.columns(3)

        with col1:
            st.subheader("Classical ML")
            st.metric("Предсказание", ml_pred[0])
            st.metric("Время (сек)", f"{ml_time:.4f}")

        with col2:
            st.subheader("LSTM")
            st.metric("Предсказание", lstm_pred[0])
            st.metric("Время (сек)", f"{lstm_time:.4f}")

        with col3:
            st.subheader("BERT")
            st.metric("Предсказание", bert_pred[0])
            st.metric("Время (сек)", f"{bert_time:.4f}")

        # BERT attention visualisation
        st.header("Attention-механизм BERT")
        plot_attention(user_input, bert.model, bert.tokenizer)

    # The static comparison below does not depend on the input, so it is
    # rendered on every run.
    # Comparison table
    st.header("Сравнение моделей")
    st.dataframe(metrics_df.style.highlight_max(axis=0))

    # Metric charts
    st.header("Визуализация метрик")
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))

    # F1-score chart
    metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
    ax[0].set_title('F1-macro score')
    ax[0].set_ylabel('Score')

    # Inference-time chart
    metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
    ax[1].set_title('Время предсказания (сек)')
    ax[1].set_ylabel('Seconds')

    st.pyplot(fig)
    plt.close(fig)  # release the figure once Streamlit has rendered it


def run():
    """Entry point for a host application: render the review-analysis page.

    Calls ``main`` unconditionally, so the page is drawn when this module is
    imported and ``run()`` is invoked, not only when it is executed as a
    script.
    """
    main()


if __name__ == "__main__":
    main()
|
training_metrics.png
ADDED
![]() |