File size: 2,302 Bytes
ecbd4e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdb0abe
ecbd4e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
import torch
import sys
from pathlib import Path
import requests
import time
import cv2
import numpy as np
from transformers import AutoTokenizer


st.write("# Оценка степени токсичности пользовательского сообщения")
# st.write("Здесь вы можете загрузить картинку со своего устройства, либо при помощи ссылки")

# Добавление пути к проекту и моделям
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model2.preprocess_text import TextPreprocessorBERT
from models.model2.model import BERTClassifier

device = 'cpu'

# Загрузка модели и словаря
@st.cache_resource
def load_model():
    model = BERTClassifier()
    weights_path = models_path / 'model2' / 'model_weights_new.pth'
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

@st.cache_resource
def load_tokenizer():
    return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')

model = load_model()
tokenizer = load_tokenizer()

input_text = st.text_area('Введите текст сообщения')

if st.button('Предсказать'):
    # Применяем предобработку
    preprocessor = TextPreprocessorBERT()
    preprocessed_text = preprocessor.transform(input_text)
    
    # Токенизация
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt'
    )
    
    # Получаем input_ids и attention_mask из токенов
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    
    # Предсказание
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    
    # Интерпретация результата
    prediction = torch.sigmoid(output).item()
    st.write(f'Предсказанный класс токсичности: {prediction:.4f}')