{ "cells": [ { "cell_type": "markdown", "id": "9bfb61e1", "metadata": {}, "source": [ "# Сравниваем модели и сохраняем в `src/models/pretrained`" ] }, { "cell_type": "markdown", "id": "f0574ac3", "metadata": {}, "source": [ "- Импорты\n", "- Константы\n", "- Считывание датасетов" ] }, { "cell_type": "code", "execution_count": 1, "id": "5a237c5c", "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "import torch\n", "import warnings\n", "import numpy as np\n", "import pandas as pd\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n", "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n", "\n", "from src.data_utils.config import DatasetConfig\n", "from src.data_utils.dataset_params import DatasetName\n", "from src.data_utils.dataset_generator import DatasetGenerator\n", "from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier\n", "\n", "MAX_SEQ_LEN = 300\n", "EMBEDDING_DIM = 128\n", "BATCH_SIZE = 32\n", "LEARNING_RATE = 1e-4\n", "NUM_EPOCHS = 5 # для быстрого сравнения моделей\n", "NUM_CLASSES = 2\n", "\n", "SAVE_DIR = \"../pretrained_comparison\"\n", "os.makedirs(SAVE_DIR, exist_ok=True)\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "config = DatasetConfig(\n", " load_from_disk=True,\n", " path_to_data=\"../datasets\"\n", ")\n", "\n", "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n", "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n", "VOCAB_SIZE = len(generator.vocab)" ] }, { "cell_type": "markdown", "id": "5b95192d", "metadata": {}, "source": [ "Вспомогательные функции для трейна/валидации/теста" ] }, { "cell_type": "code", "execution_count": 2, "id": "b2a4534c", "metadata": {}, "outputs": [], "source": [ "\n", "def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):\n", " best_val_f1 = 0.0\n", " history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}\n", " \n", " print(f\"--- Начало обучения модели: {model_name} на устройстве {device} ---\")\n", "\n", " for epoch in range(num_epochs):\n", " model.train()\n", " start_time = time.time()\n", " total_train_loss = 0\n", "\n", " for batch_X, batch_y in train_loader:\n", " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", " optimizer.zero_grad()\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " loss.backward()\n", " optimizer.step()\n", " total_train_loss += loss.item()\n", " \n", " avg_train_loss = total_train_loss / len(train_loader)\n", " history['train_loss'].append(avg_train_loss)\n", "\n", " model.eval()\n", " total_val_loss = 0\n", " all_preds = []\n", " all_labels = []\n", "\n", " with torch.no_grad():\n", " for batch_X, batch_y in val_loader:\n", " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " total_val_loss += loss.item()\n", " \n", " _, predicted = torch.max(outputs.data, 1)\n", " all_preds.extend(predicted.cpu().numpy())\n", " all_labels.extend(batch_y.cpu().numpy())\n", " \n", " avg_val_loss = total_val_loss / len(val_loader)\n", " \n", " accuracy = accuracy_score(all_labels, all_preds)\n", " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n", " \n", " history['val_loss'].append(avg_val_loss)\n", " history['val_accuracy'].append(accuracy)\n", " history['val_f1'].append(f1)\n", "\n", " epoch_time = time.time() - start_time\n", " print(f\"Эпоха {epoch+1}/{num_epochs} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | \"\n", " f\"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}\")\n", "\n", " if f1 > best_val_f1:\n", " best_val_f1 = f1\n", " torch.save(model.state_dict(), save_path)\n", " print(f\" -> Модель сохранена, новый лучший Val F1: {best_val_f1:.4f}\")\n", " \n", " print(f\"--- Обучение модели {model_name} завершено ---\")\n", " return history\n", "\n", "def evaluate_on_test(model, test_loader, device, criterion):\n", " model.eval()\n", " total_test_loss = 0\n", " all_preds = []\n", " all_labels = []\n", "\n", " with torch.no_grad():\n", " for batch_X, batch_y in test_loader:\n", " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " total_test_loss += loss.item()\n", " \n", " _, predicted = torch.max(outputs.data, 1)\n", " all_preds.extend(predicted.cpu().numpy())\n", " all_labels.extend(batch_y.cpu().numpy())\n", " \n", " avg_test_loss = total_test_loss / len(test_loader)\n", " \n", " accuracy = accuracy_score(all_labels, all_preds)\n", " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n", " \n", " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}" ] }, { "cell_type": "markdown", "id": "1be50523", "metadata": {}, "source": [ "Создание даталоадера" ] }, { "cell_type": "code", "execution_count": 3, "id": "cccc5bea", "metadata": {}, "outputs": [], "source": [ "def create_dataloader(X, y, batch_size, shuffle=True):\n", " X_tensor = torch.as_tensor(X, dtype=torch.long)\n", " y_tensor = torch.as_tensor(y, dtype=torch.long)\n", " dataset = TensorDataset(X_tensor, y_tensor)\n", " return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n", "\n", "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n", "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE, shuffle=False)\n", "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)" ] }, { "cell_type": "markdown", "id": "4938b9f3", "metadata": {}, "source": [ "Сравнения моделей\n", "\n", "Смотрим первые 5 эпох чтобы выбрать лучшую модель, с которой будем играться дальше" ] }, { "cell_type": "code", "execution_count": 4, "id": "0244aafa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Начало обучения модели: CustomMamba на устройстве cuda ---\n", "Эпоха 1/5 | Время: 337.85с | Train Loss: 0.6768 | Val Loss: 0.6168 | Val Acc: 0.6592 | Val F1: 0.5937\n", " -> Модель сохранена, новый лучший Val F1: 0.5937\n", "Эпоха 2/5 | Время: 345.54с | Train Loss: 0.5266 | Val Loss: 0.4964 | Val Acc: 0.7580 | Val F1: 0.7552\n", " -> Модель сохранена, новый лучший Val F1: 0.7552\n", "Эпоха 3/5 | Время: 343.23с | Train Loss: 0.4329 | Val Loss: 0.4586 | Val Acc: 0.7812 | Val F1: 0.7830\n", " -> Модель сохранена, новый лучший Val F1: 0.7830\n", "Эпоха 4/5 | Время: 342.62с | Train Loss: 0.3730 | Val Loss: 0.4596 | Val Acc: 0.7928 | Val F1: 0.8056\n", " -> Модель сохранена, новый лучший Val F1: 0.8056\n", "Эпоха 5/5 | Время: 340.21с | Train Loss: 0.3127 | Val Loss: 0.4469 | Val Acc: 0.7996 | Val F1: 0.8124\n", " -> Модель сохранена, новый лучший Val F1: 0.8124\n", "--- Обучение модели CustomMamba завершено ---\n", "--- Оценка лучшей модели CustomMamba на тестовых данных ---\n", "Результаты для CustomMamba: {'loss': 0.44949763529239944, 'accuracy': 0.8062, 'precision': 0.778874269005848, 'recall': 0.8541082164328657, 'f1_score': 0.8147581724335691}\n", "------------------------------------------------------------\n", "--- Начало обучения модели: Lib_LSTM на устройстве cuda ---\n", "Эпоха 1/5 | Время: 5.09с | Train Loss: 0.6930 | Val Loss: 0.6922 | Val Acc: 0.5170 | Val F1: 0.4221\n", " -> Модель сохранена, новый лучший Val F1: 0.4221\n", "Эпоха 2/5 | Время: 5.03с | Train Loss: 0.6911 | Val Loss: 0.6899 | Val Acc: 0.5324 | Val F1: 0.4880\n", " -> Модель сохранена, новый лучший Val F1: 0.4880\n", "Эпоха 3/5 | Время: 5.03с | Train Loss: 0.6864 | Val Loss: 0.6837 | Val Acc: 0.5530 | Val F1: 0.5605\n", " -> Модель сохранена, новый лучший Val F1: 0.5605\n", "Эпоха 4/5 | Время: 5.03с | Train Loss: 0.6740 | Val Loss: 0.6589 | Val Acc: 0.6096 | Val F1: 0.6208\n", " -> Модель сохранена, новый лучший Val F1: 0.6208\n", "Эпоха 5/5 | Время: 5.04с | Train Loss: 0.6489 | Val Loss: 0.6501 | Val Acc: 0.6498 | Val F1: 0.6460\n", " -> Модель сохранена, новый лучший Val F1: 0.6460\n", "--- Обучение модели Lib_LSTM завершено ---\n", "--- Оценка лучшей модели Lib_LSTM на тестовых данных ---\n", "Результаты для Lib_LSTM: {'loss': 0.6330309821541902, 'accuracy': 0.6644, 'precision': 0.6724356268467708, 'recall': 0.6384769539078157, 'f1_score': 0.655016447368421}\n", "------------------------------------------------------------\n", "--- Начало обучения модели: Lib_Transformer на устройстве cuda ---\n", "Эпоха 1/5 | Время: 4.28с | Train Loss: 0.6712 | Val Loss: 0.6773 | Val Acc: 0.5292 | Val F1: 0.1729\n", " -> Модель сохранена, новый лучший Val F1: 0.1729\n", "Эпоха 2/5 | Время: 4.14с | Train Loss: 0.5753 | Val Loss: 0.5631 | Val Acc: 0.7308 | Val F1: 0.7701\n", " -> Модель сохранена, новый лучший Val F1: 0.7701\n", "Эпоха 3/5 | Время: 4.17с | Train Loss: 0.4836 | Val Loss: 0.5106 | Val Acc: 0.7622 | Val F1: 0.7830\n", " -> Модель сохранена, новый лучший Val F1: 0.7830\n", "Эпоха 4/5 | Время: 4.16с | Train Loss: 0.4399 | Val Loss: 0.4880 | Val Acc: 0.7814 | Val F1: 0.7763\n", "Эпоха 5/5 | Время: 4.13с | Train Loss: 0.4014 | Val Loss: 0.4611 | Val Acc: 0.7946 | Val F1: 0.8078\n", " -> Модель сохранена, новый лучший Val F1: 0.8078\n", "--- Обучение модели Lib_Transformer завершено ---\n", "--- Оценка лучшей модели Lib_Transformer на тестовых данных ---\n", "Результаты для Lib_Transformer: {'loss': 0.4671077333438169, 'accuracy': 0.7938, 'precision': 0.7661818181818182, 'recall': 0.8444889779559118, 'f1_score': 0.8034318398474738}\n", "------------------------------------------------------------\n", "\n", "\n", "--- Итоговая таблица сравнения моделей на тестовых данных ---\n", " loss accuracy precision recall f1_score\n", "CustomMamba 0.449498 0.8062 0.778874 0.854108 0.814758\n", "Lib_LSTM 0.633031 0.6644 0.672436 0.638477 0.655016\n", "Lib_Transformer 0.467108 0.7938 0.766182 0.844489 0.803432\n" ] } ], "source": [ "model_configs = {\n", " \"CustomMamba\": {\n", " \"class\": CustomMambaClassifier,\n", " \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8, \n", " 'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n", " },\n", "\n", " \"Lib_LSTM\": {\n", " \"class\": LSTMClassifier,\n", " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 128, \n", " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'dropout': 0.5},\n", " },\n", " \"Lib_Transformer\": {\n", " \"class\": TransformerClassifier,\n", " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, \n", " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n", " },\n", "}\n", "\n", "results = {}\n", "for model_name, config in model_configs.items():\n", "\n", " model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n", " \n", " model = config['class'](**config['params']).to(DEVICE)\n", " optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n", " criterion = nn.CrossEntropyLoss()\n", " \n", " train_and_evaluate(\n", " model=model, train_loader=train_loader, val_loader=val_loader,\n", " optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,\n", " device=DEVICE, model_name=model_name, save_path=model_path\n", " )\n", " \n", " print(f\"--- Оценка лучшей модели {model_name} на тестовых данных ---\")\n", " if os.path.exists(model_path):\n", " best_model = config['class'](**config['params']).to(DEVICE)\n", " best_model.load_state_dict(torch.load(model_path))\n", " test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)\n", " results[model_name] = test_metrics\n", " print(f\"Результаты для {model_name}: {test_metrics}\")\n", " else:\n", " print(f\"Файл лучшей модели для {model_name} не найден. Пропускаем оценку.\")\n", "\n", " print(\"-\" * 60)\n", " \n", "if results:\n", " results_df = pd.DataFrame(results).T\n", " print(\"\\n\\n--- Итоговая таблица сравнения моделей на тестовых данных ---\")\n", " print(results_df.to_string())\n", "else:\n", " print(\"Не удалось получить результаты ни для одной модели.\")\n" ] }, { "cell_type": "markdown", "id": "404db766", "metadata": {}, "source": [ "По результатам видно, что LSTM и Transformer обучаются быстро, но Mamba обучается хорошо. Дальнейшие шаги следующие \n", " - Пробуем сравнить Transformer и Mamba более детально, играем с гиперпараметрами\n", " - LSTM проигрывает Transformer и по времени, и по качеству, поэтому в следующий этап сравнения не пойдет\n", " \n", "Цель следующего иследования: найти идеальный баланс между временем и качеством. Поставим больше эпох, меньший lr для обоих моделей, увеличим датасет (в текущем сетапе было 10'000 сэмплов на трейн и по 5'000 на валидацию/тест)" ] } ], "metadata": { "kernelspec": { "display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" } }, "nbformat": 4, "nbformat_minor": 5 }