File size: 3,175 Bytes
cd123bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import json
import torch
from nltk.tokenize import word_tokenize
import argparse
from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
SAVE_DIR = "pretrained"
MODEL_PATH = f"{SAVE_DIR}/best_model.pth"
CONFIG_PATH = f"{SAVE_DIR}/config.json"
VOCAB_PATH = f"{SAVE_DIR}/vocab.json"
ID_TO_LABEL = {0: "Negative", 1: "Positive"}
def load_artifacts():
with open(CONFIG_PATH, 'r') as f:
config = json.load(f)
with open(VOCAB_PATH, 'r') as f:
vocab = json.load(f)
model_type = config['model_type']
model_params = config['model_params']
if model_type == 'Transformer':
model = TransformerClassifier(**model_params)
elif model_type == 'Mamba':
model = MambaClassifier(**model_params)
elif model_type == 'LSTM':
model = LSTMClassifier(**model_params)
else:
raise ValueError("Неизвестный тип модели в файле конфигурации.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()
return model, vocab, config, device
def preprocess_text(text, vocab, max_len):
tokens = word_tokenize(text.lower())
ids = [vocab.get(token, vocab['<UNK>']) for token in tokens]
if len(ids) < max_len:
ids.extend([vocab['<PAD>']] * (max_len - len(ids)))
else:
ids = ids[:max_len]
return torch.tensor(ids).unsqueeze(0)
def predict(text, model, vocab, config, device):
input_tensor = preprocess_text(text, vocab, config['max_seq_len'])
input_tensor = input_tensor.to(device)
with torch.no_grad():
outputs = model(input_tensor)
probabilities = torch.softmax(outputs, dim=1)
prediction_id = torch.argmax(probabilities, dim=1).item()
predicted_label = ID_TO_LABEL[prediction_id]
confidence = probabilities[0][prediction_id].item()
return predicted_label, confidence
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Предсказать тональность текста с помощью обученной модели.")
parser.add_argument("text", type=str, help="Текст для анализа (в кавычках).")
args = parser.parse_args()
print("Загрузка модели и артефактов...")
try:
loaded_model, loaded_vocab, loaded_config, device = load_artifacts()
print(f"Модель '{loaded_config['model_type']}' успешно загружена на устройство {device}.")
except FileNotFoundError:
print("\nОШИБКА: Файлы модели не найдены!")
print("Сначала запустите скрипт train.py для обучения и сохранения модели.")
exit()
label, conf = predict(args.text, loaded_model, loaded_vocab, loaded_config, device)
print("\n--- Результат предсказания ---")
print(f"Текст: '{args.text}'")
print(f"Тональность: {label}")
print(f"Уверенность: {conf:.2%}")
|