"""Streamlit app comparing three review-classification models (classical ML, LSTM, BERT) on medical-facility reviews."""

import json
import pickle
import time
from pathlib import Path

import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# safe_globals is used by the checkpoint-loading fallback in load_models (available in recent PyTorch releases)
from torch.serialization import safe_globals

from transformers import AutoTokenizer, BertForSequenceClassification
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score


def preprocess_text(text):
    """Lowercase a review and collapse line breaks into single spaces."""
    if not isinstance(text, str):
        return ""
    return text.lower().replace('\n', ' ').replace('\r', ' ').strip()


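# --- Classical ML baseline ---
# Thin wrapper around a pre-fitted sklearn pipeline and its LabelEncoder;
# both are unpickled from the 'ml/' directory in load_models().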
class ClassicalML:
    def __init__(self):
        self.pipeline = None
        self.label_encoder = None

    def predict(self, X):
        # Return (decoded predictions, wall-clock inference time in seconds)
        start_time = time.time()
        preds = self.pipeline.predict(X)
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


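# --- Attention layer used by the LSTM classifier ---
# Scores every LSTM timestep with a single linear unit and returns the
# softmax-weighted sum of the outputs as a fixed-size context vector.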
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_dim) -> weights: (batch, seq_len)
        attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
        # Weighted sum over timesteps -> context: (batch, hidden_dim)
        context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
        return context


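# --- LSTM classifier wrapper ---
# Holds the trained LSTMModel, its token-to-index vocabulary and label encoder;
# LSTMTrainer.load() restores all three from the 'lstm/model.pt' checkpoint.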
class LSTMTrainer:
    def __init__(self):
        self.model = None
        self.vocab = None
        self.label_encoder = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                tokens = preprocess_text(text).split()
                # Map out-of-vocabulary tokens to index 0; guarantee at least one token
                seq = [self.vocab.get(token, 0) for token in tokens]
                if not seq:
                    seq = [0]
                text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
                length_tensor = torch.tensor([len(seq)], dtype=torch.long)
                output = self.model(text_tensor, length_tensor)
                preds.append(torch.argmax(output).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time

    @classmethod
    def load(cls, path='lstm'):
        checkpoint = torch.load(
            f'{path}/model.pt',
            map_location=torch.device('cpu'),
            weights_only=False
        )

        model = cls()
        model.vocab = checkpoint['vocab']
        model.label_encoder = checkpoint['label_encoder']

        model.model = LSTMModel(
            len(model.vocab),
            checkpoint['embed_dim'],
            checkpoint['hidden_dim'],
            len(model.label_encoder.classes_)
        ).to(model.device)

        # Rename keys saved as 'attention.attention.*' to the Sequential-style
        # 'attention.attention.0.*' layout; strict=False ignores any keys that
        # do not match the current architecture.
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('attention.attention.'):
                new_key = key.replace('attention.attention.', 'attention.attention.0.')
                new_state_dict[new_key] = value
            else:
                new_state_dict[key] = value

        model.model.load_state_dict(new_state_dict, strict=False)
        return model


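# --- BERT classifier wrapper ---
# Runs a fine-tuned BertForSequenceClassification (loaded from 'bert1/')
# one review at a time and decodes the argmax logit with the label encoder.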
class BERTClassifier:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        self.label_encoder = None

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                inputs = self.tokenizer(
                    text,
                    padding=True,
                    truncation=True,
                    max_length=128,
                    return_tensors="pt"
                ).to(self.device)
                outputs = self.model(**inputs)
                preds.append(torch.argmax(outputs.logits).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


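# --- Attention heatmap for the BERT model ---
# Averages the attention heads of the last encoder layer and renders a
# token-by-token heatmap inside the Streamlit page.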
def plot_attention(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    inputs = inputs.to(model.device)
    outputs = model(**inputs, output_attentions=True)
    # Last encoder layer, averaged over heads: (seq_len, seq_len)
    attention = outputs.attentions[-1].squeeze(0).mean(dim=0)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(attention.detach().cpu().numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="YlGnBu")
    plt.title("Attention Scores")
    st.pyplot(fig)


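# --- Model loading (cached across reruns via st.cache_resource) ---
# Restores the sklearn pipeline from 'ml/', the LSTM checkpoint from 'lstm/'
# and the fine-tuned BERT model and tokenizer from 'bert1/'.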
@st.cache_resource
def load_models():
    # Classical ML: pre-fitted pipeline + label encoder
    classical_ml = ClassicalML()
    with open('ml/pipeline.pkl', 'rb') as f:
        classical_ml.pipeline = pickle.load(f)
    with open('ml/label_encoder.pkl', 'rb') as f:
        classical_ml.label_encoder = pickle.load(f)

    # LSTM: try the restricted weights-only load first; if the checkpoint
    # contains pickled objects (it stores a LabelEncoder), fall back to a
    # full load with LabelEncoder explicitly allow-listed.
    lstm = LSTMTrainer()
    try:
        checkpoint = torch.load(
            'lstm/model.pt',
            map_location=torch.device('cpu'),
            weights_only=True
        )
    except Exception:
        with safe_globals([LabelEncoder]):
            checkpoint = torch.load(
                'lstm/model.pt',
                map_location=torch.device('cpu'),
                weights_only=False
            )

    lstm.vocab = checkpoint['vocab']
    lstm.label_encoder = checkpoint['label_encoder']
    lstm.model = LSTMModel(
        len(lstm.vocab),
        checkpoint['embed_dim'],
        checkpoint['hidden_dim'],
        len(lstm.label_encoder.classes_)
    ).to(lstm.device)
    lstm.model.load_state_dict(checkpoint['model_state_dict'])

    # BERT: fine-tuned model and tokenizer saved in 'bert1/'
    bert = BERTClassifier()
    bert.tokenizer = AutoTokenizer.from_pretrained('bert1')
    bert.model = BertForSequenceClassification.from_pretrained('bert1')
    bert.model.to(bert.device)
    with open('bert1/label_encoder.pkl', 'rb') as f:
        bert.label_encoder = pickle.load(f)

    return classical_ml, lstm, bert


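# --- LSTM architecture ---
# Embedding -> single-layer LSTM over a packed padded batch -> attention
# pooling -> dropout -> linear classification head.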
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, text, lengths):
        embedded = self.embedding(text)
        packed = pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        context = self.attention(output)
        return self.fc(self.dropout(context))


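# --- Streamlit UI ---
# Collects a review, prepends a fixed facility-category string, runs all three
# models and displays the predictions, per-model latency, a BERT attention
# heatmap and the offline metric comparison.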
def main():
    st.title("Анализ отзывов медицинских учреждений")  # "Analysis of medical facility reviews"

    classical_ml, lstm, bert = load_models()

    # Offline evaluation results shown for comparison (not recomputed at runtime)
    metrics = {
        'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
        'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
        'BERT': {'f1_macro': 0.92, 'inference_time': 0.05}
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

    user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")  # "Enter your review" / sample review

    if st.button("Проанализировать отзыв"):  # "Analyze review"
        if user_input:
            # Prepend the fixed facility-category string to the review before prediction
            input_with_category = f"Поликлиники стоматологические {user_input}"

            with st.spinner('Обработка...'):  # "Processing..."
                ml_pred, ml_time = classical_ml.predict([input_with_category])
                lstm_pred, lstm_time = lstm.predict([input_with_category])
                bert_pred, bert_time = bert.predict([input_with_category])

            # Per-model prediction and latency
            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("Classical ML")
                st.metric("Предсказание", ml_pred[0])  # "Prediction"
                st.metric("Время (сек)", f"{ml_time:.4f}")  # "Time (s)"

            with col2:
                st.subheader("LSTM")
                st.metric("Предсказание", lstm_pred[0])
                st.metric("Время (сек)", f"{lstm_time:.4f}")

            with col3:
                st.subheader("BERT")
                st.metric("Предсказание", bert_pred[0])
                st.metric("Время (сек)", f"{bert_time:.4f}")

            st.header("Attention-механизм BERT")  # "BERT attention mechanism"
            plot_attention(user_input, bert.model, bert.tokenizer)

            st.header("Сравнение моделей")  # "Model comparison"
            st.dataframe(metrics_df.style.highlight_max(axis=0))

            st.header("Визуализация метрик")  # "Metric visualisation"
            fig, ax = plt.subplots(1, 2, figsize=(15, 5))

            metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
            ax[0].set_title('F1-macro score')
            ax[0].set_ylabel('Score')

            metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
            ax[1].set_title('Время предсказания (сек)')  # "Prediction time (s)"
            ax[1].set_ylabel('Seconds')

            st.pyplot(fig)


if __name__ == "__main__":
    main()