astart01 committed
Commit b104726 · verified · 1 Parent(s): 3839ac0

Update perv.py

Files changed (1):
  perv.py (+293 -293)
perv.py CHANGED
Every model-artifact path is repointed from the models/ subdirectory to the repository root (models/lstm -> lstm, models/classical_ml -> ml, models/bert -> bert1); the rest of the file is identical on both sides. The substantive hunks:

@@ -36,2 +36,2 @@
     with torch.serialization.safe_globals([LabelEncoder]):
-        checkpoint = torch.load('models/lstm/model.pt', map_location=torch.device('cpu'), weights_only=False)
+        checkpoint = torch.load('lstm/model.pt', map_location=torch.device('cpu'), weights_only=False)
@@ -74,2 +74,2 @@
         @classmethod
-        def load(cls, path='models/lstm'):
+        def load(cls, path='lstm'):
@@ -156,3 +156,3 @@
-        with open('models/classical_ml/pipeline.pkl', 'rb') as f:
+        with open('ml/pipeline.pkl', 'rb') as f:
             classical_ml.pipeline = pickle.load(f)
-        with open('models/classical_ml/label_encoder.pkl', 'rb') as f:
+        with open('ml/label_encoder.pkl', 'rb') as f:
@@ -166 +166 @@
-                'models/lstm/model.pt',
+                'lstm/model.pt',
@@ -174 +174 @@
-                    'models/lstm/model.pt',
+                    'lstm/model.pt',
@@ -191,4 +191,4 @@
-        bert.tokenizer = AutoTokenizer.from_pretrained('models/bert')
-        bert.model = BertForSequenceClassification.from_pretrained('models/bert')
+        bert.tokenizer = AutoTokenizer.from_pretrained('bert1')
+        bert.model = BertForSequenceClassification.from_pretrained('bert1')
         bert.model.to(bert.device)
-        with open('models/bert/label_encoder.pkl', 'rb') as f:
+        with open('bert1/label_encoder.pkl', 'rb') as f:

perv.py after the change:
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
import torch
import pickle
from transformers import AutoTokenizer, BertForSequenceClassification
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import json
from torch.serialization import safe_globals

def run():
    def preprocess_text(text):
        if not isinstance(text, str):
            return ""
        return text.lower().replace('\n', ' ').replace('\r', ' ').strip()

    # Class for the classical ML model
    class ClassicalML:
        def __init__(self):
            self.pipeline = None
            self.label_encoder = None

        def predict(self, X):
            start_time = time.time()
            preds = self.pipeline.predict(X)
            return self.label_encoder.inverse_transform(preds), time.time() - start_time

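    # Note on ClassicalML.predict above: the pickled pipeline is expected to map
    # raw strings straight to class ids (presumably a text vectorizer such as
    # TF-IDF followed by a classifier), so no manual tokenization happens here.
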
    with torch.serialization.safe_globals([LabelEncoder]):
        checkpoint = torch.load('lstm/model.pt', map_location=torch.device('cpu'), weights_only=False)

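    # Context for the eager torch.load above: PyTorch 2.6 switched the default
    # to weights_only=True, which refuses to unpickle arbitrary classes, and
    # safe_globals([LabelEncoder]) allow-lists the sklearn encoder stored in
    # the checkpoint. The resulting `checkpoint` binding is not used again; the
    # cached load_models() below reloads the same file for actual use.
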
    class Attention(nn.Module):
        def __init__(self, hidden_dim):
            super().__init__()
            self.attention = nn.Linear(hidden_dim, 1)  # a simple linear scoring layer

        def forward(self, lstm_output):
            # lstm_output shape: [batch_size, seq_len, hidden_dim]
            attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
            context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
            return context

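    # Shape walk-through for Attention.forward above (illustrative values,
    # assuming batch_size=2, seq_len=5, hidden_dim=64):
    #   self.attention(lstm_output)         -> [2, 5, 1]  (a score per timestep)
    #   .squeeze(-1) + softmax over dim=1   -> [2, 5]     (weights sum to 1 per sequence)
    #   weights.unsqueeze(1)                -> [2, 1, 5]
    #   torch.bmm(weights, lstm_output)     -> [2, 1, 64]
    #   .squeeze(1)                         -> [2, 64]    (the context vector)
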
    # Class for the LSTM model
    class LSTMTrainer:
        def __init__(self):
            self.model = None
            self.vocab = None
            self.label_encoder = None
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        def predict(self, X):
            self.model.eval()
            preds = []
            start_time = time.time()
            with torch.no_grad():
                for text in X:
                    tokens = preprocess_text(text).split()
                    seq = [self.vocab.get(token, 0) for token in tokens]
                    if not seq:
                        seq = [0]
                    text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
                    length_tensor = torch.tensor([len(seq)], dtype=torch.long)
                    output = self.model(text_tensor, length_tensor)
                    preds.append(torch.argmax(output).item())
            return self.label_encoder.inverse_transform(preds), time.time() - start_time

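    # Note on the vocabulary lookup above: self.vocab.get(token, 0) sends any
    # out-of-vocabulary token to index 0, and an empty input degenerates to the
    # one-element sequence [0]; index 0 is presumably the padding/unknown slot
    # of the embedding table used in training.
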
        @classmethod
        def load(cls, path='lstm'):
            checkpoint = torch.load(
                f'{path}/model.pt',
                map_location=torch.device('cpu'),
                weights_only=False
            )

            model = cls()
            model.vocab = checkpoint['vocab']
            model.label_encoder = checkpoint['label_encoder']

            # Initialize the model
            model.model = LSTMModel(
                len(model.vocab),
                checkpoint['embed_dim'],
                checkpoint['hidden_dim'],
                len(model.label_encoder.classes_)
            ).to(model.device)

            # Adapt the state_dict
            state_dict = checkpoint['model_state_dict']
            new_state_dict = {}

            for key, value in state_dict.items():
                if key.startswith('attention.attention.'):
                    # Rewrite the keys (weight and bias alike) to the Sequential naming
                    new_state_dict[key.replace('attention.attention.', 'attention.attention.0.')] = value
                else:
                    new_state_dict[key] = value

            model.model.load_state_dict(new_state_dict, strict=False)
            return model

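        # Note on the key rewrite in load() above: the renamed keys target a
        # variant of Attention whose linear layer sits inside an nn.Sequential
        # (hence the inserted '0.'). The Attention defined in this file keeps a
        # bare nn.Linear, so the renamed entries do not match it and
        # strict=False lets load_state_dict skip them; load_models() below
        # avoids the issue by loading the raw state_dict directly.
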
    # Class for the BERT model
    class BERTClassifier:
        def __init__(self):
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.tokenizer = None
            self.model = None
            self.label_encoder = None

        def predict(self, X):
            self.model.eval()
            preds = []
            start_time = time.time()
            with torch.no_grad():
                for text in X:
                    inputs = self.tokenizer(
                        text,
                        padding=True,
                        truncation=True,
                        max_length=128,
                        return_tensors="pt"
                    ).to(self.device)  # move the inputs to the same device as the model
                    outputs = self.model(**inputs)
                    preds.append(torch.argmax(outputs.logits).item())
            return self.label_encoder.inverse_transform(preds), time.time() - start_time

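    # Note on BERTClassifier.predict above: texts are tokenized and scored one
    # at a time, which keeps memory flat but leaves the batch dimension unused.
    # A batched variant (a sketch, same tokenizer arguments) would be:
    #   inputs = self.tokenizer(list(X), padding=True, truncation=True,
    #                           max_length=128, return_tensors="pt").to(self.device)
    #   preds = self.model(**inputs).logits.argmax(dim=-1).tolist()
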
    # Attention visualization helper
    def plot_attention(text, model, tokenizer):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
        outputs = model(**inputs, output_attentions=True)
        attention = outputs.attentions[-1].squeeze(0).mean(dim=0)
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        plt.figure(figsize=(10, 8))
        sns.heatmap(attention.detach().cpu().numpy(),
                    xticklabels=tokens,
                    yticklabels=tokens,
                    cmap="YlGnBu")
        plt.title("Attention Scores")
        st.pyplot(plt)

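    # Note on plot_attention above: outputs.attentions[-1] is the last encoder
    # layer, shaped [1, num_heads, seq_len, seq_len]; squeeze(0).mean(dim=0)
    # averages the heads into a single seq_len x seq_len map for the heatmap.
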
    @st.cache_resource
    def load_models():
        # Classical ML
        classical_ml = ClassicalML()
        with open('ml/pipeline.pkl', 'rb') as f:
            classical_ml.pipeline = pickle.load(f)
        with open('ml/label_encoder.pkl', 'rb') as f:
            classical_ml.label_encoder = pickle.load(f)

        # LSTM (with a fallback for the weights-loading error)
        lstm = LSTMTrainer()
        try:
            # Try weights_only=True first (the safe option)
            checkpoint = torch.load(
                'lstm/model.pt',
                map_location=torch.device('cpu'),  # pin to CPU explicitly
                weights_only=True
            )
        except Exception:
            # If that fails, retry with LabelEncoder explicitly allow-listed
            with safe_globals([LabelEncoder]):
                checkpoint = torch.load(
                    'lstm/model.pt',
                    map_location=torch.device('cpu'),  # pin to CPU explicitly
                    weights_only=False
                )

        lstm.vocab = checkpoint['vocab']
        lstm.label_encoder = checkpoint['label_encoder']
        lstm.model = LSTMModel(
            len(lstm.vocab),
            checkpoint['embed_dim'],
            checkpoint['hidden_dim'],
            len(lstm.label_encoder.classes_)
        ).to(lstm.device)  # the model is moved to the selected device (CPU or GPU)
        lstm.model.load_state_dict(checkpoint['model_state_dict'])

        # BERT
        bert = BERTClassifier()
        bert.tokenizer = AutoTokenizer.from_pretrained('bert1')
        bert.model = BertForSequenceClassification.from_pretrained('bert1')
        bert.model.to(bert.device)  # move the model to the target device after loading
        with open('bert1/label_encoder.pkl', 'rb') as f:
            bert.label_encoder = pickle.load(f)

        return classical_ml, lstm, bert

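    # Note on @st.cache_resource above: Streamlit re-executes the script on
    # every interaction, but the decorator caches the returned triple per
    # process, so the pickles and checkpoints are deserialized only once
    # rather than on each rerun.
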
    # The LSTM model class (included for completeness)
    class LSTMModel(nn.Module):
        def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
            self.attention = Attention(hidden_dim)
            self.fc = nn.Linear(hidden_dim, output_dim)
            self.dropout = nn.Dropout(0.5)

        def forward(self, text, lengths):
            embedded = self.embedding(text)
            packed = pack_padded_sequence(
                embedded,
                lengths.cpu(),  # pack_padded_sequence requires lengths on the CPU
                batch_first=True,
                enforce_sorted=False
            )
            packed_output, (hidden, cell) = self.lstm(packed)
            output, _ = pad_packed_sequence(packed_output, batch_first=True)
            context = self.attention(output)
            return self.fc(self.dropout(context))

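    # Data flow through LSTMModel.forward above, for a batch of B sequences
    # padded to length T: embedding gives [B, T, embed_dim]; the pack/pad pair
    # lets the LSTM skip padded positions; the unpacked output [B, T, hidden_dim]
    # goes through Attention to [B, hidden_dim]; dropout + fc yield
    # [B, output_dim] logits.
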
    # The main application
    def main():
        st.title("Анализ отзывов медицинских учреждений")  # "Analysis of medical-facility reviews"

        # Load the models
        classical_ml, lstm, bert = load_models()

        # Placeholder metrics (replace with the real values from your training)
        metrics = {
            'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
            'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
            'BERT': {'f1_macro': 0.92, 'inference_time': 0.05}
        }
        metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

        # Text input field
        user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")

        if st.button("Проанализировать отзыв"):
            if user_input:
                # Prepend the category seen in training, for compatibility
                # ("Поликлиники стоматологические" = dental polyclinics)
                input_with_category = f"Поликлиники стоматологические {user_input}"

                with st.spinner('Обработка...'):
                    # Collect the predictions
                    ml_pred, ml_time = classical_ml.predict([input_with_category])
                    lstm_pred, lstm_time = lstm.predict([input_with_category])
                    bert_pred, bert_time = bert.predict([input_with_category])

                # Show the results in three columns
                col1, col2, col3 = st.columns(3)

                with col1:
                    st.subheader("Classical ML")
                    st.metric("Предсказание", ml_pred[0])
                    st.metric("Время (сек)", f"{ml_time:.4f}")

                with col2:
                    st.subheader("LSTM")
                    st.metric("Предсказание", lstm_pred[0])
                    st.metric("Время (сек)", f"{lstm_time:.4f}")

                with col3:
                    st.subheader("BERT")
                    st.metric("Предсказание", bert_pred[0])
                    st.metric("Время (сек)", f"{bert_time:.4f}")

                # Visualize the BERT attention
                st.header("Attention-механизм BERT")
                plot_attention(user_input, bert.model, bert.tokenizer)

                # Comparative metrics table
                st.header("Сравнение моделей")
                st.dataframe(metrics_df.style.highlight_max(axis=0))

                # Metric charts
                st.header("Визуализация метрик")
                fig, ax = plt.subplots(1, 2, figsize=(15, 5))

                # F1-score chart
                metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
                ax[0].set_title('F1-macro score')
                ax[0].set_ylabel('Score')

                # Inference-time chart
                metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
                ax[1].set_title('Время предсказания (сек)')
                ax[1].set_ylabel('Seconds')

                st.pyplot(fig)

    if __name__ == "__main__":
        main()
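
As committed, every definition, including the __main__ guard, is nested inside run(), so launching perv.py directly (for example with `streamlit run perv.py`) only defines run() and renders nothing, while a plain `import perv; perv.run()` skips main() because the module's __name__ is then "perv". A minimal sketch of a dispatcher that satisfies both conditions (app.py is a hypothetical name, not part of this commit):

# app.py - hypothetical dispatcher, not part of this commit
import runpy

# Execute perv.py with __name__ set to "__main__" so the guard inside run()
# passes, then call the run() it defined; main() then fires from that guard.
page_globals = runpy.run_path("perv.py", run_name="__main__")
page_globals["run"]()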