Spaces:
Sleeping
Sleeping
File size: 2,298 Bytes
ecbd4e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import streamlit as st
import torch
import sys
from pathlib import Path
import requests
import time
import cv2
import numpy as np
from transformers import AutoTokenizer
st.write("# Оценка степени токсичности пользовательского сообщения")
# st.write("Здесь вы можете загрузить картинку со своего устройства, либо при помощи ссылки")
# Добавление пути к проекту и моделям
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model2.preprocess_text import TextPreprocessorBERT
from models.model2.model import BERTClassifier
device = 'cpu'
# Загрузка модели и словаря
@st.cache_resource
def load_model():
model = BERTClassifier()
weights_path = models_path / 'model2' / 'model_weights.pth'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
@st.cache_resource
def load_tokenizer():
return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')
model = load_model()
tokenizer = load_tokenizer()
input_text = st.text_area('Введите текст сообщения')
if st.button('Предсказать'):
# Применяем предобработку
preprocessor = TextPreprocessorBERT()
preprocessed_text = preprocessor.transform(input_text)
# Токенизация
tokens = tokenizer.encode_plus(
preprocessed_text,
add_special_tokens=True,
truncation=True,
max_length=100,
padding='max_length',
return_tensors='pt'
)
# Получаем input_ids и attention_mask из токенов
input_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
# Предсказание
with torch.no_grad():
output = model(input_ids, attention_mask=attention_mask)
# Интерпретация результата
prediction = torch.sigmoid(output).item()
st.write(f'Предсказанный класс токсичности: {prediction:.4f}')
|