File size: 3,175 Bytes
cd123bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import torch
from nltk.tokenize import word_tokenize
import argparse

from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier

SAVE_DIR = "pretrained"
MODEL_PATH = f"{SAVE_DIR}/best_model.pth"
CONFIG_PATH = f"{SAVE_DIR}/config.json"
VOCAB_PATH = f"{SAVE_DIR}/vocab.json"

ID_TO_LABEL = {0: "Negative", 1: "Positive"}

def load_artifacts():
    with open(CONFIG_PATH, 'r') as f:
        config = json.load(f)
    with open(VOCAB_PATH, 'r') as f:
        vocab = json.load(f)

    model_type = config['model_type']
    model_params = config['model_params']

    if model_type == 'Transformer':
        model = TransformerClassifier(**model_params)
    elif model_type == 'Mamba':
        model = MambaClassifier(**model_params)
    elif model_type == 'LSTM':
        model = LSTMClassifier(**model_params)
    else:
        raise ValueError("Неизвестный тип модели в файле конфигурации.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()

    return model, vocab, config, device

def preprocess_text(text, vocab, max_len):
    tokens = word_tokenize(text.lower())
    ids = [vocab.get(token, vocab['<UNK>']) for token in tokens]
    if len(ids) < max_len:
        ids.extend([vocab['<PAD>']] * (max_len - len(ids)))
    else:
        ids = ids[:max_len]
    return torch.tensor(ids).unsqueeze(0)

def predict(text, model, vocab, config, device):
    input_tensor = preprocess_text(text, vocab, config['max_seq_len'])
    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        prediction_id = torch.argmax(probabilities, dim=1).item()
    
    predicted_label = ID_TO_LABEL[prediction_id]
    confidence = probabilities[0][prediction_id].item()

    return predicted_label, confidence

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Предсказать тональность текста с помощью обученной модели.")
    parser.add_argument("text", type=str, help="Текст для анализа (в кавычках).")
    args = parser.parse_args()

    print("Загрузка модели и артефактов...")
    try:
        loaded_model, loaded_vocab, loaded_config, device = load_artifacts()
        print(f"Модель '{loaded_config['model_type']}' успешно загружена на устройство {device}.")
    except FileNotFoundError:
        print("\nОШИБКА: Файлы модели не найдены!")
        print("Сначала запустите скрипт train.py для обучения и сохранения модели.")
        exit()

    label, conf = predict(args.text, loaded_model, loaded_vocab, loaded_config, device)

    print("\n--- Результат предсказания ---")
    print(f"Текст: '{args.text}'")
    print(f"Тональность: {label}")
    print(f"Уверенность: {conf:.2%}")