nlp-bert-team / pages /comments.py
VerVelVel's picture
bot and new weights
cdb0abe
raw
history blame
2.3 kB
import streamlit as st
import torch
import sys
from pathlib import Path
import requests
import time
import cv2
import numpy as np
from transformers import AutoTokenizer
st.write("# Оценка степени токсичности пользовательского сообщения")
# st.write("Здесь вы можете загрузить картинку со своего устройства, либо при помощи ссылки")
# Добавление пути к проекту и моделям
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model2.preprocess_text import TextPreprocessorBERT
from models.model2.model import BERTClassifier
device = 'cpu'
# Загрузка модели и словаря
@st.cache_resource
def load_model():
model = BERTClassifier()
weights_path = models_path / 'model2' / 'model_weights_new.pth'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
@st.cache_resource
def load_tokenizer():
return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')
model = load_model()
tokenizer = load_tokenizer()
input_text = st.text_area('Введите текст сообщения')
if st.button('Предсказать'):
# Применяем предобработку
preprocessor = TextPreprocessorBERT()
preprocessed_text = preprocessor.transform(input_text)
# Токенизация
tokens = tokenizer.encode_plus(
preprocessed_text,
add_special_tokens=True,
truncation=True,
max_length=100,
padding='max_length',
return_tensors='pt'
)
# Получаем input_ids и attention_mask из токенов
input_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
# Предсказание
with torch.no_grad():
output = model(input_ids, attention_mask=attention_mask)
# Интерпретация результата
prediction = torch.sigmoid(output).item()
st.write(f'Предсказанный класс токсичности: {prediction:.4f}')