astart01 commited on
Commit
2b13982
·
verified ·
1 Parent(s): d22a353

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +16 -0
  2. bert.py +37 -0
  3. inference_time.png +0 -0
  4. model_comparison.csv +4 -0
  5. perv.py +294 -0
  6. training_metrics.png +0 -0
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import bert
3
+ import perv
4
+
5
+ st.set_page_config(page_title="Объединённое NLP-приложение", layout="wide")
6
+
7
+ st.sidebar.title("Меню")
8
+ choice = st.sidebar.radio("Выберите модуль:", [
9
+ "Оценка токсичности",
10
+ "Анализ отзывов"
11
+ ])
12
+
13
+ if choice == "Оценка токсичности":
14
+ bert.run()
15
+ elif choice == "Анализ отзывов":
16
+ analysis.run()
bert.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import streamlit as st
3
+ import torch
4
+
5
+
6
+ MODEL_PATH = "rubert-finetuned"
7
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
9
+ model.eval()
10
+
11
+ # === Streamlit UI ===
12
+ st.set_page_config(page_title="Оценка токсичности", layout="centered")
13
+ st.title("💬 Оценка токсичности текста")
14
+
15
+ text = st.text_area("Введите сообщение", "Ты ужасный человек!")
16
+ submit = st.button("Проверить токсичность")
17
+
18
+ if submit and text.strip():
19
+ # Токенизация
20
+ inputs = tokenizer(text, return_tensors="pt", truncation=True)
21
+
22
+ # Предсказание
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ logits = outputs.logits
26
+ score = torch.sigmoid(logits).item() # степень токсичности
27
+
28
+ # Вывод
29
+ st.subheader("Результат:")
30
+ st.write(f"**Степень токсичности:** `{score:.3f}`")
31
+
32
+ if score > 0.8:
33
+ st.error("⚠️ Высокая токсичность!")
34
+ elif score > 0.4:
35
+ st.warning("⚠️ Средняя токсичность")
36
+ else:
37
+ st.success("✅ Низкая токсичность")
inference_time.png ADDED
model_comparison.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ,f1_macro,accuracy,training_time,inference_time
2
+ BERT,0.8967324592773303,0.9004249291784703,682.2076306343079,39.23531460762024
3
+ Classical ML,0.883260602673595,0.886543909348442,1.1419353485107422,0.15742778778076172
4
+ LSTM,0.873906501409314,0.8779036827195468,220.66053867340088,14.189206600189209
perv.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from pathlib import Path
6
+ import time
7
+ import torch
8
+ import pickle
9
+ from transformers import AutoTokenizer, BertForSequenceClassification
10
+ from sklearn.pipeline import Pipeline
11
+ from sklearn.preprocessing import LabelEncoder
12
+ from sklearn.metrics import f1_score, accuracy_score
13
+ import torch.nn as nn
14
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
15
+ import json
16
+ from torch.serialization import safe_globals
17
+ from sklearn.preprocessing import LabelEncoder
18
+
19
+ def run():
20
+ def preprocess_text(text):
21
+ if not isinstance(text, str):
22
+ return ""
23
+ return text.lower().replace('\n', ' ').replace('\r', ' ').strip()
24
+
25
+ # Класс для Classical ML модели
26
+ class ClassicalML:
27
+ def __init__(self):
28
+ self.pipeline = None
29
+ self.label_encoder = None
30
+
31
+ def predict(self, X):
32
+ start_time = time.time()
33
+ preds = self.pipeline.predict(X)
34
+ return self.label_encoder.inverse_transform(preds), time.time() - start_time
35
+
36
+ with torch.serialization.safe_globals([LabelEncoder]):
37
+ checkpoint = torch.load('models/lstm/model.pt', map_location=torch.device('cpu'), weights_only=False)
38
+
39
+ class Attention(nn.Module):
40
+ def __init__(self, hidden_dim):
41
+ super().__init__()
42
+ self.attention = nn.Linear(hidden_dim, 1) # Простой линейный слой
43
+
44
+ def forward(self, lstm_output):
45
+ # lstm_output shape: [batch_size, seq_len, hidden_dim]
46
+ attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
47
+ context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
48
+ return context
49
+
50
+ # Класс для LSTM модели
51
+ class LSTMTrainer:
52
+ def __init__(self):
53
+ self.model = None
54
+ self.vocab = None
55
+ self.label_encoder = None
56
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
+
58
+ def predict(self, X):
59
+ self.model.eval()
60
+ preds = []
61
+ start_time = time.time()
62
+ with torch.no_grad():
63
+ for text in X:
64
+ tokens = preprocess_text(text).split()
65
+ seq = [self.vocab.get(token, 0) for token in tokens]
66
+ if not seq:
67
+ seq = [0]
68
+ text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
69
+ length_tensor = torch.tensor([len(seq)], dtype=torch.long)
70
+ output = self.model(text_tensor, length_tensor)
71
+ preds.append(torch.argmax(output).item())
72
+ return self.label_encoder.inverse_transform(preds), time.time() - start_time
73
+
74
+ @classmethod
75
+ def load(cls, path='models/lstm'):
76
+ checkpoint = torch.load(
77
+ f'{path}/model.pt',
78
+ map_location=torch.device('cpu'),
79
+ weights_only=False
80
+ )
81
+
82
+ model = cls()
83
+ model.vocab = checkpoint['vocab']
84
+ model.label_encoder = checkpoint['label_encoder']
85
+
86
+ # Инициализация модели
87
+ model.model = LSTMModel(
88
+ len(model.vocab),
89
+ checkpoint['embed_dim'],
90
+ checkpoint['hidden_dim'],
91
+ len(model.label_encoder.classes_)
92
+ ).to(model.device)
93
+
94
+ # Адаптация state_dict
95
+ state_dict = checkpoint['model_state_dict']
96
+ new_state_dict = {}
97
+
98
+ for key, value in state_dict.items():
99
+ if key.startswith('attention.attention.'):
100
+ # Преобразуем ключи для соответствия Sequential
101
+ if 'weight' in key:
102
+ new_key = key.replace('attention.attention.', 'attention.attention.0.')
103
+ elif 'bias' in key:
104
+ new_key = key.replace('attention.attention.', 'attention.attention.0.')
105
+ new_state_dict[new_key] = value
106
+ else:
107
+ new_state_dict[key] = value
108
+
109
+ model.model.load_state_dict(new_state_dict, strict=False)
110
+ return model
111
+
112
+ # Класс для BERT модели
113
+ class BERTClassifier:
114
+ def __init__(self):
115
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
116
+ self.tokenizer = None
117
+ self.model = None
118
+ self.label_encoder = None
119
+
120
+ def predict(self, X):
121
+ self.model.eval()
122
+ preds = []
123
+ start_time = time.time()
124
+ with torch.no_grad():
125
+ for text in X:
126
+ inputs = self.tokenizer(
127
+ text,
128
+ padding=True,
129
+ truncation=True,
130
+ max_length=128,
131
+ return_tensors="pt"
132
+ ).to(self.device) # Перемещаем входные данные на то же устройство, что и модель
133
+ outputs = self.model(**inputs)
134
+ preds.append(torch.argmax(outputs.logits).item())
135
+ return self.label_encoder.inverse_transform(preds), time.time() - start_time
136
+
137
+ # Функция для визуализации attention
138
+ def plot_attention(text, model, tokenizer):
139
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
140
+ outputs = model(**inputs, output_attentions=True)
141
+ attention = outputs.attentions[-1].squeeze(0).mean(dim=0)
142
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
143
+
144
+ plt.figure(figsize=(10, 8))
145
+ sns.heatmap(attention.detach().cpu().numpy(),
146
+ xticklabels=tokens,
147
+ yticklabels=tokens,
148
+ cmap="YlGnBu")
149
+ plt.title("Attention Scores")
150
+ st.pyplot(plt)
151
+
152
+ @st.cache_resource
153
+ def load_models():
154
+ # Classical ML
155
+ classical_ml = ClassicalML()
156
+ with open('models/classical_ml/pipeline.pkl', 'rb') as f:
157
+ classical_ml.pipeline = pickle.load(f)
158
+ with open('models/classical_ml/label_encoder.pkl', 'rb') as f:
159
+ classical_ml.label_encoder = pickle.load(f)
160
+
161
+ # LSTM (с обработкой ошибки весов)
162
+ lstm = LSTMTrainer()
163
+ try:
164
+ # Пробуем загрузить с weights_only=True (безопасный вариант)
165
+ checkpoint = torch.load(
166
+ 'models/lstm/model.pt',
167
+ map_location=torch.device('cpu'), # Явно указываем CPU
168
+ weights_only=True
169
+ )
170
+ except:
171
+ # Если не получилось, загружаем с явным разрешением LabelEncoder
172
+ with safe_globals([LabelEncoder]):
173
+ checkpoint = torch.load(
174
+ 'models/lstm/model.pt',
175
+ map_location=torch.device('cpu'), # Явно указываем CPU
176
+ weights_only=False
177
+ )
178
+
179
+ lstm.vocab = checkpoint['vocab']
180
+ lstm.label_encoder = checkpoint['label_encoder']
181
+ lstm.model = LSTMModel(
182
+ len(lstm.vocab),
183
+ checkpoint['embed_dim'],
184
+ checkpoint['hidden_dim'],
185
+ len(lstm.label_encoder.classes_)
186
+ ).to(lstm.device) # Модель будет перенесена на устройство (CPU или GPU)
187
+ lstm.model.load_state_dict(checkpoint['model_state_dict'])
188
+
189
+ # BERT
190
+ bert = BERTClassifier()
191
+ bert.tokenizer = AutoTokenizer.from_pretrained('models/bert')
192
+ bert.model = BertForSequenceClassification.from_pretrained('models/bert')
193
+ bert.model.to(bert.device) # Перемещаем модель на нужное устройство после загрузки
194
+ with open('models/bert/label_encoder.pkl', 'rb') as f:
195
+ bert.label_encoder = pickle.load(f)
196
+
197
+ return classical_ml, lstm, bert
198
+
199
+ # Класс LSTM модели (добавлен для полноты)
200
+ class LSTMModel(nn.Module):
201
+ def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
202
+ super().__init__()
203
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
204
+ self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
205
+ self.attention = Attention(hidden_dim)
206
+ self.fc = nn.Linear(hidden_dim, output_dim)
207
+ self.dropout = nn.Dropout(0.5)
208
+
209
+ def forward(self, text, lengths):
210
+ embedded = self.embedding(text)
211
+ packed = pack_padded_sequence(
212
+ embedded,
213
+ lengths.cpu(), # Убедимся, что lengths на CPU
214
+ batch_first=True,
215
+ enforce_sorted=False
216
+ )
217
+ packed_output, (hidden, cell) = self.lstm(packed)
218
+ output, _ = pad_packed_sequence(packed_output, batch_first=True)
219
+ context = self.attention(output)
220
+ return self.fc(self.dropout(context))
221
+
222
+ # Основное приложение
223
+ def main():
224
+ st.title("Анализ отзывов медицинских учреждений")
225
+
226
+ # Загрузка моделей
227
+ classical_ml, lstm, bert = load_models()
228
+
229
+ # Примерные метрики (замените на реальные из вашего обучения)
230
+ metrics = {
231
+ 'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
232
+ 'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
233
+ 'BERT': {'f1_macro': 0.92, 'inference_time': 0.05}
234
+ }
235
+ metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
236
+
237
+ # Поле ввода текста
238
+ user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")
239
+
240
+ if st.button("Проанализировать отзыв"):
241
+ if user_input:
242
+ # Добавляем категорию для совместимости
243
+ input_with_category = f"Поликлиники стоматологические {user_input}"
244
+
245
+ with st.spinner('Обработка...'):
246
+ # Получаем предсказания
247
+ ml_pred, ml_time = classical_ml.predict([input_with_category])
248
+ lstm_pred, lstm_time = lstm.predict([input_with_category])
249
+ bert_pred, bert_time = bert.predict([input_with_category])
250
+
251
+ # Вывод результатов в три колонки
252
+ col1, col2, col3 = st.columns(3)
253
+
254
+ with col1:
255
+ st.subheader("Classical ML")
256
+ st.metric("Предсказание", ml_pred[0])
257
+ st.metric("Время (сек)", f"{ml_time:.4f}")
258
+
259
+ with col2:
260
+ st.subheader("LSTM")
261
+ st.metric("Предсказание", lstm_pred[0])
262
+ st.metric("Время (сек)", f"{lstm_time:.4f}")
263
+
264
+ with col3:
265
+ st.subheader("BERT")
266
+ st.metric("Предсказание", bert_pred[0])
267
+ st.metric("Время (сек)", f"{bert_time:.4f}")
268
+
269
+ # Визуализация attention для BERT
270
+ st.header("Attention-механизм BERT")
271
+ plot_attention(user_input, bert.model, bert.tokenizer)
272
+
273
+ # Сравнительная таблица метрик
274
+ st.header("Сравнение моделей")
275
+ st.dataframe(metrics_df.style.highlight_max(axis=0))
276
+
277
+ # Графики метрик
278
+ st.header("Визуализация метрик")
279
+ fig, ax = plt.subplots(1, 2, figsize=(15, 5))
280
+
281
+ # График F1-score
282
+ metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
283
+ ax[0].set_title('F1-macro score')
284
+ ax[0].set_ylabel('Score')
285
+
286
+ # График времени предсказания
287
+ metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
288
+ ax[1].set_title('Время предсказания (сек)')
289
+ ax[1].set_ylabel('Seconds')
290
+
291
+ st.pyplot(fig)
292
+
293
+ if __name__ == "__main__":
294
+ main()
training_metrics.png ADDED