import json import torch from nltk.tokenize import word_tokenize import argparse from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier SAVE_DIR = "pretrained" MODEL_PATH = f"{SAVE_DIR}/best_model.pth" CONFIG_PATH = f"{SAVE_DIR}/config.json" VOCAB_PATH = f"{SAVE_DIR}/vocab.json" ID_TO_LABEL = {0: "Negative", 1: "Positive"} def load_artifacts(): with open(CONFIG_PATH, 'r') as f: config = json.load(f) with open(VOCAB_PATH, 'r') as f: vocab = json.load(f) model_type = config['model_type'] model_params = config['model_params'] if model_type == 'Transformer': model = TransformerClassifier(**model_params) elif model_type == 'Mamba': model = MambaClassifier(**model_params) elif model_type == 'LSTM': model = LSTMClassifier(**model_params) else: raise ValueError("Неизвестный тип модели в файле конфигурации.") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) model.to(device) model.eval() return model, vocab, config, device def preprocess_text(text, vocab, max_len): tokens = word_tokenize(text.lower()) ids = [vocab.get(token, vocab['']) for token in tokens] if len(ids) < max_len: ids.extend([vocab['']] * (max_len - len(ids))) else: ids = ids[:max_len] return torch.tensor(ids).unsqueeze(0) def predict(text, model, vocab, config, device): input_tensor = preprocess_text(text, vocab, config['max_seq_len']) input_tensor = input_tensor.to(device) with torch.no_grad(): outputs = model(input_tensor) probabilities = torch.softmax(outputs, dim=1) prediction_id = torch.argmax(probabilities, dim=1).item() predicted_label = ID_TO_LABEL[prediction_id] confidence = probabilities[0][prediction_id].item() return predicted_label, confidence if __name__ == "__main__": parser = argparse.ArgumentParser(description="Предсказать тональность текста с помощью обученной модели.") parser.add_argument("text", type=str, help="Текст для анализа (в кавычках).") args = parser.parse_args() print("Загрузка модели и артефактов...") try: loaded_model, loaded_vocab, loaded_config, device = load_artifacts() print(f"Модель '{loaded_config['model_type']}' успешно загружена на устройство {device}.") except FileNotFoundError: print("\nОШИБКА: Файлы модели не найдены!") print("Сначала запустите скрипт train.py для обучения и сохранения модели.") exit() label, conf = predict(args.text, loaded_model, loaded_vocab, loaded_config, device) print("\n--- Результат предсказания ---") print(f"Текст: '{args.text}'") print(f"Тональность: {label}") print(f"Уверенность: {conf:.2%}")