import json
import pickle
import time
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.serialization import safe_globals
from transformers import AutoTokenizer, BertForSequenceClassification
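
# Streamlit demo comparing three review classifiers for medical facilities
# ("Анализ отзывов медицинских учреждений" — analysis of reviews of medical
# facilities): a pickled scikit-learn pipeline, an LSTM with a simple attention
# layer, and a BERT sequence classifier loaded from a local checkpoint.
# Model artifacts are expected under models/classical_ml, models/lstm and
# models/bert; the training details live in those saved files.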


def run():
    # Thin compatibility wrapper; the page logic lives in main() below.
    main()


def preprocess_text(text):
    """Lower-case the text and collapse newlines and carriage returns."""
    if not isinstance(text, str):
        return ""
    return text.lower().replace('\n', ' ').replace('\r', ' ').strip()


class ClassicalML:
    def __init__(self):
        self.pipeline = None
        self.label_encoder = None

    def predict(self, X):
        start_time = time.time()
        preds = self.pipeline.predict(X)
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_dim) -> one score per timestep,
        # softmax over the sequence, then a weighted sum as the context vector.
        attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
        context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
        return context


class LSTMTrainer:
    def __init__(self):
        self.model = None
        self.vocab = None
        self.label_encoder = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                tokens = preprocess_text(text).split()
                # Out-of-vocabulary tokens (and empty inputs) fall back to index 0.
                seq = [self.vocab.get(token, 0) for token in tokens]
                if not seq:
                    seq = [0]
                text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
                length_tensor = torch.tensor([len(seq)], dtype=torch.long)
                output = self.model(text_tensor, length_tensor)
                preds.append(torch.argmax(output).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time

    @classmethod
    def load(cls, path='models/lstm'):
        checkpoint = torch.load(
            f'{path}/model.pt',
            map_location=torch.device('cpu'),
            weights_only=False
        )

        model = cls()
        model.vocab = checkpoint['vocab']
        model.label_encoder = checkpoint['label_encoder']

        model.model = LSTMModel(
            len(model.vocab),
            checkpoint['embed_dim'],
            checkpoint['hidden_dim'],
            len(model.label_encoder.classes_)
        ).to(model.device)

        # Some checkpoints store the attention weights under 'attention.attention.*',
        # others under the Sequential-style 'attention.attention.0.*'; keep both key
        # spellings so whichever layout the model expects gets filled in
        # (extra keys are ignored because strict=False).
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {}
        for key, value in state_dict.items():
            new_state_dict[key] = value
            if key.startswith('attention.attention.'):
                new_key = key.replace('attention.attention.', 'attention.attention.0.')
                new_state_dict[new_key] = value

        model.model.load_state_dict(new_state_dict, strict=False)
        return model


class BERTClassifier:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        self.label_encoder = None

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                inputs = self.tokenizer(
                    text,
                    padding=True,
                    truncation=True,
                    max_length=128,
                    return_tensors="pt"
                ).to(self.device)
                outputs = self.model(**inputs)
                preds.append(torch.argmax(outputs.logits).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


def plot_attention(text, model, tokenizer):
    # Average the attention heads of the last layer and show them as a
    # token-by-token heatmap.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
    outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[-1].squeeze(0).mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(attention.detach().cpu().numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="YlGnBu")
    plt.title("Attention Scores")
    st.pyplot(fig)


@st.cache_resource
def load_models():
    # Cached by Streamlit across reruns, so the artifacts are only loaded once.
    classical_ml = ClassicalML()
    with open('models/classical_ml/pipeline.pkl', 'rb') as f:
        classical_ml.pipeline = pickle.load(f)
    with open('models/classical_ml/label_encoder.pkl', 'rb') as f:
        classical_ml.label_encoder = pickle.load(f)

    lstm = LSTMTrainer()
    try:
        checkpoint = torch.load(
            'models/lstm/model.pt',
            map_location=torch.device('cpu'),
            weights_only=True
        )
    except Exception:
        # weights_only=True cannot unpickle the stored LabelEncoder, so fall back
        # to a full load (the safe_globals allowlist keeps the trusted classes explicit).
        with safe_globals([LabelEncoder]):
            checkpoint = torch.load(
                'models/lstm/model.pt',
                map_location=torch.device('cpu'),
                weights_only=False
            )

    lstm.vocab = checkpoint['vocab']
    lstm.label_encoder = checkpoint['label_encoder']
    lstm.model = LSTMModel(
        len(lstm.vocab),
        checkpoint['embed_dim'],
        checkpoint['hidden_dim'],
        len(lstm.label_encoder.classes_)
    ).to(lstm.device)
    lstm.model.load_state_dict(checkpoint['model_state_dict'])

    bert = BERTClassifier()
    bert.tokenizer = AutoTokenizer.from_pretrained('models/bert')
    bert.model = BertForSequenceClassification.from_pretrained('models/bert')
    bert.model.to(bert.device)
    with open('models/bert/label_encoder.pkl', 'rb') as f:
        bert.label_encoder = pickle.load(f)

    return classical_ml, lstm, bert


class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, text, lengths):
        embedded = self.embedding(text)
        # pack_padded_sequence expects the lengths tensor on the CPU.
        packed = pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        context = self.attention(output)
        return self.fc(self.dropout(context))


def main():
    st.title("Анализ отзывов медицинских учреждений")  # "Analysis of reviews of medical facilities"

    classical_ml, lstm, bert = load_models()

    # Reference metrics shown in the comparison table (hard-coded, not recomputed here).
    metrics = {
        'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
        'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
        'BERT': {'f1_macro': 0.92, 'inference_time': 0.05}
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

    user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")  # "Enter your review"

    if st.button("Проанализировать отзыв"):  # "Analyze the review"
        if user_input:
            # The models appear to expect a facility-category prefix
            # ("Поликлиники стоматологические" ≈ "dental polyclinics"), so prepend it.
            input_with_category = f"Поликлиники стоматологические {user_input}"

            with st.spinner('Обработка...'):  # "Processing..."
                ml_pred, ml_time = classical_ml.predict([input_with_category])
                lstm_pred, lstm_time = lstm.predict([input_with_category])
                bert_pred, bert_time = bert.predict([input_with_category])

            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("Classical ML")
                st.metric("Предсказание", ml_pred[0])
                st.metric("Время (сек)", f"{ml_time:.4f}")

            with col2:
                st.subheader("LSTM")
                st.metric("Предсказание", lstm_pred[0])
                st.metric("Время (сек)", f"{lstm_time:.4f}")

            with col3:
                st.subheader("BERT")
                st.metric("Предсказание", bert_pred[0])
                st.metric("Время (сек)", f"{bert_time:.4f}")

            st.header("Attention-механизм BERT")
            plot_attention(user_input, bert.model, bert.tokenizer)

            st.header("Сравнение моделей")
            st.dataframe(metrics_df.style.highlight_max(axis=0))

            st.header("Визуализация метрик")
            fig, ax = plt.subplots(1, 2, figsize=(15, 5))

            metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
            ax[0].set_title('F1-macro score')
            ax[0].set_ylabel('Score')

            metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
            ax[1].set_title('Время предсказания (сек)')
            ax[1].set_ylabel('Seconds')

            st.pyplot(fig)


if __name__ == "__main__":
    main()