nlp-bert-team / pages /1_policlinic.py
VerVelVel's picture
images
961ee03
raw
history blame
5.06 kB
import streamlit as st
import joblib
import pandas as pd
from models.model1.Custom_class import TextPreprocessor
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec
from models.model1.lstm_model import LSTMConcatAttention
# Load the trained pipeline
pipeline = joblib.load('models/model1/logistic_regression_pipeline.pkl')
# Streamlit application
st.title('Классификация отзывов на русском языке')
input_text = st.text_area('Введите текст отзыва')
device = 'cpu'
# Загрузка модели LSTM и словаря
@st.cache_resource
def load_lstm_model():
model = LSTMConcatAttention()
weights_path = models_path / 'model1' / 'lstm_weights'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
lstm_model = load_lstm_model()
@st.cache_resource
def load_int_to_vocab():
vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl'
vocab_to_int = joblib.load(vocab_path)
int_to_vocab = {j:i for i, j in vocab_to_int.items()}
return int_to_vocab
int_to_vocab = load_int_to_vocab()
def plot_and_predict_lstm(input_text):
preprocessor_lstm = TextPreprocessorWord2Vec()
preprocessed = preprocessor_lstm.transform(input_text)
lstm_model.eval()
with torch.inference_mode():
pred, att_scores = lstm_model(preprocessed.long().unsqueeze(0))
lstm_pred = pred.sigmoid().item()
# Получить индексы слов, которые не равны <pad> и не имеют индекс 0
valid_indices = [i for i, x in enumerate(preprocessed) if x.item() != 0 and int_to_vocab[x.item()] != "<pad>"]
# Получить соответствующие оценки внимания и метки слов
valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices]
valid_labels = [int_to_vocab[preprocessed[i].item()] for i in valid_indices]
# Упорядочить метки и оценки внимания по убыванию веса смысла
sorted_indices = np.argsort(valid_att_scores)
sorted_labels = [valid_labels[i] for i in sorted_indices]
sorted_att_scores = valid_att_scores[sorted_indices]
# Построить график с учетом только валидных меток
plt.figure(figsize=(4, 8))
plt.barh(np.arange(len(sorted_indices)), sorted_att_scores)
plt.yticks(ticks=np.arange(len(sorted_indices)), labels=sorted_labels)
return lstm_pred, plt
if st.button('Предсказать'):
start_time_lr = time.time()
prediction = pipeline.predict(pd.Series([input_text]))
pred_probe = pipeline.predict_proba(pd.Series([input_text]))
pred_proba_rounded = np.round(pred_probe, 2).flatten()
if prediction[0] == 0:
predicted_class = "POSITIVE"
else:
predicted_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf')
end_time_lr = time.time()
time_lr = end_time_lr - start_time_lr
st.write(f'**{predicted_class}** с вероятностью {pred_proba_rounded[0]}')
st.write(f'Время выполнения расчетов {time_lr:.4f} секунд')
start_time_lstm = time.time()
lstm_pred, lstm_plot = plot_and_predict_lstm(input_text)
if lstm_pred > 0.5:
predicted_lstm_class = "POSITIVE"
else:
predicted_lstm_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:')
end_time_lstm = time.time()
time_lstm = end_time_lstm - start_time_lstm
st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_pred, 3)}')
st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд')
st.pyplot(lstm_plot)
st.write("# Информация об обучении модели логистической регрессии и tf-idf:")
st.image(str(project_root / 'images/pipeline_logreg.png'))
st.write("Модель обучалась на предсказание 1 класса")
st.write("Размер датасета - 70597 текстов отзывов")
st.write("Проведена предобработка текста")
st.write("Метрики:")
st.image(str(project_root / 'images/log_reg_metrics.png'))
st.write("# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:")
st.write("Время обучения модели - 10 эпох")
st.write("Метрики на 10 эпохе:")
st.write("Train f1: 0.95, Val f1: 0.93")
st.write("Train accuracy: 0.94, Val accuracy: 0.92")