{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Инициализация глобальных переменных, достаем датасет" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "import json\n", "import torch\n", "import warnings\n", "import numpy as np\n", "import pandas as pd\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n", "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n", "\n", "from src.data_utils.config import DatasetConfig\n", "from src.data_utils.dataset_params import DatasetName\n", "from src.data_utils.dataset_generator import DatasetGenerator\n", "from src.models.models import TransformerClassifier\n", "\n", "MAX_SEQ_LEN = 300\n", "EMBEDDING_DIM = 64 # уменьшили: 128 -> 64, чтобы влезло в гит\n", "BATCH_SIZE = 64 # подняли batch_size: 32 -> 64\n", "LEARNING_RATE = 7e-5\n", "NUM_EPOCHS = 100 # подняли количество эпох: 20 -> 100\n", "NUM_CLASSES = 2\n", "\n", "SAVE_DIR = \"../pretrained\"\n", "os.makedirs(SAVE_DIR, exist_ok=True)\n", "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n", "VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n", "CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "MODEL_TO_TRAIN = 'Transformer' \n", "\n", "config = DatasetConfig(\n", " load_from_disk=True,\n", " path_to_data=\"../datasets\",\n", " train_size=25000, # взяли весь датасет\n", " val_size=12500,\n", " test_size=12500\n", ")\n", "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n", "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n", "VOCAB_SIZE = len(generator.vocab)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Создаем даталоадеры" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def create_dataloader(X, y, batch_size):\n", " dataset = TensorDataset(torch.tensor(X, dtype=torch.long), torch.tensor(y, dtype=torch.long))\n", " return DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n", "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Инициализация модели" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 8, 'num_layers': 4, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n", "model = TransformerClassifier(**model_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Обучение" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Начало обучения модели ---\n", "Эпоха 1/100 | Время: 14.99с | Train Loss: 0.6635 | Val Loss: 0.6622 | Val Acc: 0.6141\n", " -> Модель сохранена, новая лучшая Val Loss: 0.6622\n", "Эпоха 2/100 | Время: 14.15с | Train Loss: 0.6071 | Val Loss: 0.6136 | Val Acc: 0.6771\n", " -> Модель сохранена, новая лучшая Val Loss: 0.6136\n", "Эпоха 3/100 | Время: 14.18с | Train Loss: 0.5288 | Val Loss: 0.5463 | Val Acc: 0.7353\n", " -> Модель сохранена, новая лучшая Val Loss: 0.5463\n", "Эпоха 4/100 | Время: 14.12с | Train Loss: 0.4793 | Val Loss: 0.5079 | Val Acc: 0.7611\n", " -> Модель сохранена, новая лучшая Val Loss: 0.5079\n", "Эпоха 5/100 | Время: 14.09с | Train Loss: 0.4535 | Val Loss: 0.4906 | Val Acc: 0.7718\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4906\n", "Эпоха 6/100 | Время: 14.19с | Train Loss: 0.4266 | Val Loss: 0.4683 | Val Acc: 0.7852\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4683\n", "Эпоха 7/100 | Время: 14.19с | Train Loss: 0.4062 | Val Loss: 0.4531 | Val Acc: 0.7970\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4531\n", "Эпоха 8/100 | Время: 14.28с | Train Loss: 0.3809 | Val Loss: 0.4390 | Val Acc: 0.8087\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4390\n", "Эпоха 9/100 | Время: 14.17с | Train Loss: 0.3641 | Val Loss: 0.4281 | Val Acc: 0.8147\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4281\n", "Эпоха 10/100 | Время: 14.27с | Train Loss: 0.3483 | Val Loss: 0.4213 | Val Acc: 0.8203\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4213\n", "Эпоха 11/100 | Время: 14.09с | Train Loss: 0.3363 | Val Loss: 0.4098 | Val Acc: 0.8258\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4098\n", "Эпоха 12/100 | Время: 14.19с | Train Loss: 0.3245 | Val Loss: 0.4068 | Val Acc: 0.8290\n", " -> Модель сохранена, новая лучшая Val Loss: 0.4068\n", "Эпоха 13/100 | Время: 14.14с | Train Loss: 0.3119 | Val Loss: 0.3983 | Val Acc: 0.8314\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3983\n", "Эпоха 14/100 | Время: 14.16с | Train Loss: 0.3020 | Val Loss: 0.3958 | Val Acc: 0.8333\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3958\n", "Эпоха 15/100 | Время: 14.21с | Train Loss: 0.2893 | Val Loss: 0.3918 | Val Acc: 0.8334\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3918\n", "Эпоха 16/100 | Время: 14.17с | Train Loss: 0.2792 | Val Loss: 0.3913 | Val Acc: 0.8337\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3913\n", "Эпоха 17/100 | Время: 14.21с | Train Loss: 0.2735 | Val Loss: 0.3845 | Val Acc: 0.8372\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3845\n", "Эпоха 18/100 | Время: 14.17с | Train Loss: 0.2606 | Val Loss: 0.3860 | Val Acc: 0.8350\n", "Эпоха 19/100 | Время: 14.14с | Train Loss: 0.2490 | Val Loss: 0.3838 | Val Acc: 0.8396\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3838\n", "Эпоха 20/100 | Время: 14.12с | Train Loss: 0.2438 | Val Loss: 0.3802 | Val Acc: 0.8386\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3802\n", "Эпоха 21/100 | Время: 14.16с | Train Loss: 0.2347 | Val Loss: 0.3926 | Val Acc: 0.8387\n", "Эпоха 22/100 | Время: 14.10с | Train Loss: 0.2240 | Val Loss: 0.3877 | Val Acc: 0.8398\n", "Эпоха 23/100 | Время: 14.13с | Train Loss: 0.2148 | Val Loss: 0.3800 | Val Acc: 0.8399\n", " -> Модель сохранена, новая лучшая Val Loss: 0.3800\n", "Эпоха 24/100 | Время: 14.11с | Train Loss: 0.2091 | Val Loss: 0.4051 | Val Acc: 0.8296\n", "Эпоха 25/100 | Время: 14.17с | Train Loss: 0.1970 | Val Loss: 0.3867 | Val Acc: 0.8405\n", "Эпоха 26/100 | Время: 14.06с | Train Loss: 0.1887 | Val Loss: 0.3928 | Val Acc: 0.8370\n", "Эпоха 27/100 | Время: 14.31с | Train Loss: 0.1776 | Val Loss: 0.3960 | Val Acc: 0.8387\n", "Эпоха 28/100 | Время: 14.09с | Train Loss: 0.1712 | Val Loss: 0.3942 | Val Acc: 0.8416\n", "Эпоха 29/100 | Время: 14.08с | Train Loss: 0.1654 | Val Loss: 0.3925 | Val Acc: 0.8417\n", "Эпоха 30/100 | Время: 14.14с | Train Loss: 0.1529 | Val Loss: 0.4102 | Val Acc: 0.8398\n", "Эпоха 31/100 | Время: 14.12с | Train Loss: 0.1470 | Val Loss: 0.4541 | Val Acc: 0.8267\n", "Эпоха 32/100 | Время: 14.10с | Train Loss: 0.1422 | Val Loss: 0.4121 | Val Acc: 0.8411\n", "Эпоха 33/100 | Время: 14.09с | Train Loss: 0.1344 | Val Loss: 0.4387 | Val Acc: 0.8348\n", "Эпоха 34/100 | Время: 14.10с | Train Loss: 0.1263 | Val Loss: 0.4273 | Val Acc: 0.8394\n", "Эпоха 35/100 | Время: 14.11с | Train Loss: 0.1222 | Val Loss: 0.4251 | Val Acc: 0.8385\n", "Эпоха 36/100 | Время: 14.08с | Train Loss: 0.1146 | Val Loss: 0.4300 | Val Acc: 0.8403\n", "Эпоха 37/100 | Время: 14.10с | Train Loss: 0.1026 | Val Loss: 0.4730 | Val Acc: 0.8326\n", "Эпоха 38/100 | Время: 14.15с | Train Loss: 0.0979 | Val Loss: 0.4533 | Val Acc: 0.8390\n", "Эпоха 39/100 | Время: 14.14с | Train Loss: 0.0923 | Val Loss: 0.4644 | Val Acc: 0.8372\n", "Эпоха 40/100 | Время: 14.07с | Train Loss: 0.0871 | Val Loss: 0.4811 | Val Acc: 0.8366\n", "Эпоха 41/100 | Время: 14.06с | Train Loss: 0.0827 | Val Loss: 0.4794 | Val Acc: 0.8362\n", "Эпоха 42/100 | Время: 14.11с | Train Loss: 0.0776 | Val Loss: 0.5005 | Val Acc: 0.8346\n", "Эпоха 43/100 | Время: 14.10с | Train Loss: 0.0707 | Val Loss: 0.5144 | Val Acc: 0.8345\n", "Эпоха 44/100 | Время: 14.14с | Train Loss: 0.0666 | Val Loss: 0.5409 | Val Acc: 0.8294\n", "Эпоха 45/100 | Время: 14.08с | Train Loss: 0.0628 | Val Loss: 0.5412 | Val Acc: 0.8313\n", "Эпоха 46/100 | Время: 14.10с | Train Loss: 0.0565 | Val Loss: 0.5808 | Val Acc: 0.8253\n", "Эпоха 47/100 | Время: 14.10с | Train Loss: 0.0541 | Val Loss: 0.5608 | Val Acc: 0.8306\n", "Эпоха 48/100 | Время: 14.11с | Train Loss: 0.0474 | Val Loss: 0.6120 | Val Acc: 0.8246\n", "Эпоха 49/100 | Время: 14.12с | Train Loss: 0.0442 | Val Loss: 0.6136 | Val Acc: 0.8258\n", "Эпоха 50/100 | Время: 14.07с | Train Loss: 0.0437 | Val Loss: 0.6455 | Val Acc: 0.8232\n", "Эпоха 51/100 | Время: 14.06с | Train Loss: 0.0380 | Val Loss: 0.6393 | Val Acc: 0.8275\n", "Эпоха 52/100 | Время: 14.07с | Train Loss: 0.0402 | Val Loss: 0.6521 | Val Acc: 0.8249\n", "Эпоха 53/100 | Время: 14.07с | Train Loss: 0.0352 | Val Loss: 0.6490 | Val Acc: 0.8282\n", "Эпоха 54/100 | Время: 14.10с | Train Loss: 0.0327 | Val Loss: 0.6523 | Val Acc: 0.8305\n", "Эпоха 55/100 | Время: 14.10с | Train Loss: 0.0310 | Val Loss: 0.7744 | Val Acc: 0.8110\n", "Эпоха 56/100 | Время: 14.22с | Train Loss: 0.0301 | Val Loss: 0.6825 | Val Acc: 0.8278\n", "Эпоха 57/100 | Время: 14.31с | Train Loss: 0.0270 | Val Loss: 0.6777 | Val Acc: 0.8287\n", "Эпоха 58/100 | Время: 14.14с | Train Loss: 0.0285 | Val Loss: 0.6846 | Val Acc: 0.8278\n", "Эпоха 59/100 | Время: 14.07с | Train Loss: 0.0238 | Val Loss: 0.7758 | Val Acc: 0.8201\n", "Эпоха 60/100 | Время: 14.11с | Train Loss: 0.0251 | Val Loss: 0.7602 | Val Acc: 0.8184\n", "Эпоха 61/100 | Время: 14.08с | Train Loss: 0.0200 | Val Loss: 0.7125 | Val Acc: 0.8260\n", "Эпоха 62/100 | Время: 14.11с | Train Loss: 0.0202 | Val Loss: 0.7320 | Val Acc: 0.8292\n", "Эпоха 63/100 | Время: 14.27с | Train Loss: 0.0188 | Val Loss: 0.7456 | Val Acc: 0.8263\n", "Эпоха 64/100 | Время: 14.15с | Train Loss: 0.0171 | Val Loss: 0.9202 | Val Acc: 0.8054\n", "Эпоха 65/100 | Время: 14.14с | Train Loss: 0.0142 | Val Loss: 0.8072 | Val Acc: 0.8242\n", "Эпоха 66/100 | Время: 14.08с | Train Loss: 0.0179 | Val Loss: 0.7872 | Val Acc: 0.8225\n", "Эпоха 67/100 | Время: 14.07с | Train Loss: 0.0142 | Val Loss: 0.8251 | Val Acc: 0.8242\n", "Эпоха 68/100 | Время: 14.03с | Train Loss: 0.0186 | Val Loss: 0.8471 | Val Acc: 0.8183\n", "Эпоха 69/100 | Время: 14.07с | Train Loss: 0.0135 | Val Loss: 0.7898 | Val Acc: 0.8265\n", "Эпоха 70/100 | Время: 14.28с | Train Loss: 0.0155 | Val Loss: 0.8558 | Val Acc: 0.8210\n", "Эпоха 71/100 | Время: 14.17с | Train Loss: 0.0128 | Val Loss: 0.8150 | Val Acc: 0.8258\n", "Эпоха 72/100 | Время: 14.16с | Train Loss: 0.0110 | Val Loss: 0.7982 | Val Acc: 0.8265\n", "Эпоха 73/100 | Время: 14.14с | Train Loss: 0.0142 | Val Loss: 0.8362 | Val Acc: 0.8244\n", "Эпоха 74/100 | Время: 14.16с | Train Loss: 0.0105 | Val Loss: 0.8301 | Val Acc: 0.8269\n", "Эпоха 75/100 | Время: 14.15с | Train Loss: 0.0106 | Val Loss: 0.8284 | Val Acc: 0.8282\n", "Эпоха 76/100 | Время: 14.16с | Train Loss: 0.0111 | Val Loss: 0.8481 | Val Acc: 0.8266\n", "Эпоха 77/100 | Время: 14.15с | Train Loss: 0.0100 | Val Loss: 0.8781 | Val Acc: 0.8224\n", "Эпоха 78/100 | Время: 14.15с | Train Loss: 0.0097 | Val Loss: 0.8392 | Val Acc: 0.8248\n", "Эпоха 79/100 | Время: 14.13с | Train Loss: 0.0089 | Val Loss: 0.8454 | Val Acc: 0.8265\n", "Эпоха 80/100 | Время: 14.03с | Train Loss: 0.0123 | Val Loss: 0.8227 | Val Acc: 0.8279\n", "Эпоха 81/100 | Время: 14.04с | Train Loss: 0.0087 | Val Loss: 0.8900 | Val Acc: 0.8199\n", "Эпоха 82/100 | Время: 14.04с | Train Loss: 0.0074 | Val Loss: 0.8790 | Val Acc: 0.8229\n", "Эпоха 83/100 | Время: 14.08с | Train Loss: 0.0078 | Val Loss: 0.8883 | Val Acc: 0.8214\n", "Эпоха 84/100 | Время: 14.10с | Train Loss: 0.0099 | Val Loss: 0.8660 | Val Acc: 0.8242\n", "Эпоха 85/100 | Время: 14.09с | Train Loss: 0.0077 | Val Loss: 1.0760 | Val Acc: 0.8019\n", "Эпоха 86/100 | Время: 14.10с | Train Loss: 0.0091 | Val Loss: 0.8935 | Val Acc: 0.8234\n", "Эпоха 87/100 | Время: 14.12с | Train Loss: 0.0082 | Val Loss: 0.9913 | Val Acc: 0.8093\n", "Эпоха 88/100 | Время: 14.15с | Train Loss: 0.0051 | Val Loss: 0.8942 | Val Acc: 0.8241\n", "Эпоха 89/100 | Время: 14.17с | Train Loss: 0.0057 | Val Loss: 0.8939 | Val Acc: 0.8252\n", "Эпоха 90/100 | Время: 14.15с | Train Loss: 0.0064 | Val Loss: 0.8819 | Val Acc: 0.8284\n", "Эпоха 91/100 | Время: 14.15с | Train Loss: 0.0073 | Val Loss: 0.9078 | Val Acc: 0.8258\n", "Эпоха 92/100 | Время: 14.12с | Train Loss: 0.0068 | Val Loss: 0.9801 | Val Acc: 0.8162\n", "Эпоха 93/100 | Время: 14.06с | Train Loss: 0.0102 | Val Loss: 0.8664 | Val Acc: 0.8279\n", "Эпоха 94/100 | Время: 14.07с | Train Loss: 0.0058 | Val Loss: 0.9026 | Val Acc: 0.8246\n", "Эпоха 95/100 | Время: 14.08с | Train Loss: 0.0049 | Val Loss: 0.8784 | Val Acc: 0.8298\n", "Эпоха 96/100 | Время: 14.34с | Train Loss: 0.0051 | Val Loss: 0.9108 | Val Acc: 0.8246\n", "Эпоха 97/100 | Время: 14.18с | Train Loss: 0.0057 | Val Loss: 0.9554 | Val Acc: 0.8155\n", "Эпоха 98/100 | Время: 14.10с | Train Loss: 0.0063 | Val Loss: 0.9046 | Val Acc: 0.8274\n", "Эпоха 99/100 | Время: 14.12с | Train Loss: 0.0039 | Val Loss: 0.9310 | Val Acc: 0.8250\n", "Эпоха 100/100 | Время: 14.11с | Train Loss: 0.0064 | Val Loss: 0.9123 | Val Acc: 0.8269\n" ] } ], "source": [ "model.to(DEVICE)\n", "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "best_val_loss = float('inf')\n", "print(f\"--- Начало обучения модели ---\")\n", "for epoch in range(NUM_EPOCHS):\n", " model.train()\n", " start_time = time.time()\n", " total_train_loss = 0\n", "\n", " for batch_X, batch_y in train_loader:\n", " batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n", " optimizer.zero_grad()\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " loss.backward()\n", " optimizer.step()\n", " total_train_loss += loss.item()\n", " avg_train_loss = total_train_loss / len(train_loader)\n", " \n", " model.eval()\n", " total_val_loss, correct_val, total_val = 0, 0, 0\n", " with torch.no_grad():\n", " for batch_X, batch_y in val_loader:\n", " batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " total_val_loss += loss.item()\n", " _, predicted = torch.max(outputs.data, 1)\n", " total_val += batch_y.size(0)\n", " correct_val += (predicted == batch_y).sum().item()\n", " avg_val_loss = total_val_loss / len(val_loader)\n", " val_accuracy = correct_val / total_val\n", "\n", " epoch_time = time.time() - start_time\n", " print(f\"Эпоха {epoch+1}/{NUM_EPOCHS} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}\")\n", " \n", " if avg_val_loss < best_val_loss:\n", " best_val_loss = avg_val_loss\n", " model.to(\"cpu\")\n", " torch.save(model.state_dict(), MODEL_SAVE_PATH)\n", " model.to(DEVICE)\n", " print(f\" -> Модель сохранена, новая лучшая Val Loss: {best_val_loss:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Снимем качество на тестовых данных из исходного датасета" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Метрики на тестовой выборке (из обучаемого датасета) итоговой модели\n", "{'loss': 0.8911580686666527, 'accuracy': 0.82944, 'precision': 0.8185334158415841, 'recall': 0.84656, 'f1_score': 0.8323108384458078}\n" ] } ], "source": [ "def evaluate_on_test(model, test_loader, device, criterion):\n", " model.eval()\n", " total_test_loss = 0\n", " all_preds = []\n", " all_labels = []\n", "\n", " with torch.no_grad():\n", " for batch_X, batch_y in test_loader:\n", " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", " outputs = model(batch_X)\n", " loss = criterion(outputs, batch_y)\n", " total_test_loss += loss.item()\n", " \n", " _, predicted = torch.max(outputs.data, 1)\n", " all_preds.extend(predicted.cpu().numpy())\n", " all_labels.extend(batch_y.cpu().numpy())\n", " \n", " avg_test_loss = total_test_loss / len(test_loader)\n", " \n", " accuracy = accuracy_score(all_labels, all_preds)\n", " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n", " \n", " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}\n", "\n", "\n", "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)\n", "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n", "print(f\"Метрики на тестовой выборке (из обучаемого датасета) итоговой модели\\n{test_metrics}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Снимем качество на тестовых данных нового датасета. Считаем данные" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "text_processor = generator.get_text_processor()\n", "config_polarity = DatasetConfig(\n", " load_from_disk=True,\n", " path_to_data=\"../datasets\",\n", " train_size=25000, # взяли весь датасет\n", " val_size=12500,\n", " test_size=12500,\n", " build_vocab=False\n", ")\n", "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n", "generator_polarity.vocab = generator.vocab\n", "generator_polarity.id2word = generator.id2word\n", "generator_polarity.text_processor = text_processor\n", "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator_polarity.generate_dataset()\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Посмтрим на метрики" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Метрики на тестовой выборке (из неизвестного датасета) итоговой модели\n", "{'loss': 0.7563912787911843, 'accuracy': 0.7212, 'precision': 0.6816347780814785, 'recall': 0.8344486934353091, 'f1_score': 0.7503402822551759}\n" ] } ], "source": [ "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n", "print(f\"Метрики на тестовой выборке (из неизвестного датасета) итоговой модели\\n{test_metrics}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "В целом видно, что модель что-то, да выучила. Гипотезы по улучшению:\n", " - Больше и разнообразнее данные для обучения\n", " - Чем больше словарь - тем лучше\n", " - Нужно чтобы тестовый датасет был больше похож на обучаемый" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Сохранение итоговой модели" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Конфигурация модели сохранена в: ../pretrained/config.json\n" ] } ], "source": [ "with open(VOCAB_SAVE_PATH, 'w', encoding='utf-8') as f:\n", " json.dump(generator.vocab, f, ensure_ascii=False, indent=4)\n", "\n", "config = {\n", " \"model_type\": MODEL_TO_TRAIN,\n", " \"max_seq_len\": MAX_SEQ_LEN,\n", " \"model_params\": model_params,\n", "}\n", "with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n", " json.dump(config, f, ensure_ascii=False, indent=4)\n", "print(f\"Конфигурация модели сохранена в: {CONFIG_SAVE_PATH}\")" ] } ], "metadata": { "kernelspec": { "display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" } }, "nbformat": 4, "nbformat_minor": 2 }