Synced repo using 'sync_with_huggingface' GitHub Action
- notebooks/mamba_vs_transformerts.ipynb +361 -0
- notebooks/models_comparations_second_dataset.ipynb +240 -0
- notebooks/train.ipynb +309 -67
- pretrained/best_model.pth +2 -2
- pretrained/config.json +4 -4
- pretrained/vocab.json +0 -0
- src/data_utils/config.py +2 -0
- src/data_utils/dataset_generator.py +2 -1
notebooks/mamba_vs_transformerts.ipynb
ADDED
@@ -0,0 +1,361 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
    "\n",
    "from src.data_utils.config import DatasetConfig\n",
    "from src.data_utils.dataset_params import DatasetName\n",
    "from src.data_utils.dataset_generator import DatasetGenerator\n",
    "from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier\n",
    "\n",
    "MAX_SEQ_LEN = 300\n",
    "EMBEDDING_DIM = 128\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 7e-5  # reduced the lr: 1e-4 -> 7e-5\n",
    "NUM_EPOCHS = 20  # raised the number of epochs: 5 -> 20\n",
    "NUM_CLASSES = 2\n",
    "\n",
    "SAVE_DIR = \"../best_models/\"\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "config = DatasetConfig(\n",
    "    load_from_disk=True,\n",
    "    path_to_data=\"../datasets\",\n",
    "    train_size=25000,  # increased the number of samples\n",
    "    val_size=12500,\n",
    "    test_size=12500\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
    "VOCAB_SIZE = len(generator.vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also create a generator for the second dataset's test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_processor = generator.get_text_processor()\n",
    "config_polarity = DatasetConfig(\n",
    "    load_from_disk=True,\n",
    "    path_to_data=\"../datasets\",\n",
    "    test_size=10000,\n",
    "    build_vocab=False\n",
    ")\n",
    "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
    "generator_polarity.vocab = generator.vocab\n",
    "generator_polarity.id2word = generator.id2word\n",
    "generator_polarity.text_processor = text_processor\n",
    "(_, _), (_, _), (X_test_polarity, y_test_polarity) = generator_polarity.generate_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):\n",
    "    best_val_f1 = 0.0\n",
    "    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}\n",
    "\n",
    "    print(f\"--- Starting training of model: {model_name} on device {device} ---\")\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        start_time = time.time()\n",
    "        total_train_loss = 0\n",
    "\n",
    "        for batch_X, batch_y in train_loader:\n",
    "            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(batch_X)\n",
    "            loss = criterion(outputs, batch_y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            total_train_loss += loss.item()\n",
    "\n",
    "        avg_train_loss = total_train_loss / len(train_loader)\n",
    "        history['train_loss'].append(avg_train_loss)\n",
    "\n",
    "        model.eval()\n",
    "        total_val_loss = 0\n",
    "        all_preds = []\n",
    "        all_labels = []\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for batch_X, batch_y in val_loader:\n",
    "                batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
    "                outputs = model(batch_X)\n",
    "                loss = criterion(outputs, batch_y)\n",
    "                total_val_loss += loss.item()\n",
    "\n",
    "                _, predicted = torch.max(outputs.data, 1)\n",
    "                all_preds.extend(predicted.cpu().numpy())\n",
    "                all_labels.extend(batch_y.cpu().numpy())\n",
    "\n",
    "        avg_val_loss = total_val_loss / len(val_loader)\n",
    "\n",
    "        accuracy = accuracy_score(all_labels, all_preds)\n",
    "        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
    "\n",
    "        history['val_loss'].append(avg_val_loss)\n",
    "        history['val_accuracy'].append(accuracy)\n",
    "        history['val_f1'].append(f1)\n",
    "\n",
    "        epoch_time = time.time() - start_time\n",
    "        print(f\"Epoch {epoch+1}/{num_epochs} | Time: {epoch_time:.2f}s | Train Loss: {avg_train_loss:.4f} | \"\n",
    "              f\"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}\")\n",
    "\n",
    "        if f1 > best_val_f1:\n",
    "            best_val_f1 = f1\n",
    "            torch.save(model.state_dict(), save_path)\n",
    "            print(f\" -> Model saved, new best Val F1: {best_val_f1:.4f}\")\n",
    "\n",
    "    print(f\"--- Training of model {model_name} finished ---\")\n",
    "    return history\n",
    "\n",
    "def evaluate_on_test(model, test_loader, device, criterion):\n",
    "    model.eval()\n",
    "    total_test_loss = 0\n",
    "    all_preds = []\n",
    "    all_labels = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_X, batch_y in test_loader:\n",
    "            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
    "            outputs = model(batch_X)\n",
    "            loss = criterion(outputs, batch_y)\n",
    "            total_test_loss += loss.item()\n",
    "\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            all_preds.extend(predicted.cpu().numpy())\n",
    "            all_labels.extend(batch_y.cpu().numpy())\n",
    "\n",
    "    avg_test_loss = total_test_loss / len(test_loader)\n",
    "\n",
    "    accuracy = accuracy_score(all_labels, all_preds)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
    "\n",
    "    return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloader(X, y, batch_size, shuffle=True):\n",
    "    X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
    "    y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
    "    dataset = TensorDataset(X_tensor, y_tensor)\n",
    "    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
    "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE, shuffle=False)\n",
    "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)\n",
    "test_loader_polarity = create_dataloader(X_test_polarity, y_test_polarity, BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Starting training of model: CustomMamba on device cuda ---\n",
      "Epoch 1/20 | Time: 897.35s | Train Loss: 0.6261 | Val Loss: 0.5313 | Val Acc: 0.7388 | Val F1: 0.7503\n",
      " -> Model saved, new best Val F1: 0.7503\n",
      "Epoch 2/20 | Time: 903.31s | Train Loss: 0.4748 | Val Loss: 0.4559 | Val Acc: 0.7889 | Val F1: 0.7901\n",
      " -> Model saved, new best Val F1: 0.7901\n",
      "Epoch 3/20 | Time: 930.96s | Train Loss: 0.3955 | Val Loss: 0.4176 | Val Acc: 0.8090 | Val F1: 0.8048\n",
      " -> Model saved, new best Val F1: 0.8048\n",
      "Epoch 4/20 | Time: 952.79s | Train Loss: 0.3429 | Val Loss: 0.3998 | Val Acc: 0.8230 | Val F1: 0.8303\n",
      " -> Model saved, new best Val F1: 0.8303\n",
      "Epoch 5/20 | Time: 904.67s | Train Loss: 0.2984 | Val Loss: 0.4387 | Val Acc: 0.8165 | Val F1: 0.8337\n",
      " -> Model saved, new best Val F1: 0.8337\n",
      "Epoch 6/20 | Time: 935.58s | Train Loss: 0.2609 | Val Loss: 0.4219 | Val Acc: 0.8255 | Val F1: 0.8386\n",
      " -> Model saved, new best Val F1: 0.8386\n",
      "Epoch 7/20 | Time: 902.73s | Train Loss: 0.2266 | Val Loss: 0.4342 | Val Acc: 0.8334 | Val F1: 0.8420\n",
      " -> Model saved, new best Val F1: 0.8420\n",
      "Epoch 8/20 | Time: 921.49s | Train Loss: 0.1953 | Val Loss: 0.4675 | Val Acc: 0.8299 | Val F1: 0.8411\n",
      "Epoch 9/20 | Time: 892.54s | Train Loss: 0.1590 | Val Loss: 0.5205 | Val Acc: 0.8359 | Val F1: 0.8394\n",
      "Epoch 10/20 | Time: 893.22s | Train Loss: 0.1263 | Val Loss: 0.6014 | Val Acc: 0.8303 | Val F1: 0.8398\n",
      "Epoch 11/20 | Time: 953.69s | Train Loss: 0.0945 | Val Loss: 0.8025 | Val Acc: 0.8183 | Val F1: 0.8352\n",
      "Epoch 12/20 | Time: 924.98s | Train Loss: 0.0644 | Val Loss: 0.8539 | Val Acc: 0.8290 | Val F1: 0.8369\n",
      "Epoch 13/20 | Time: 916.92s | Train Loss: 0.0428 | Val Loss: 1.0646 | Val Acc: 0.8286 | Val F1: 0.8266\n",
      "Epoch 14/20 | Time: 904.39s | Train Loss: 0.0266 | Val Loss: 1.5225 | Val Acc: 0.8149 | Val F1: 0.8305\n",
      "Epoch 15/20 | Time: 923.80s | Train Loss: 0.0199 | Val Loss: 1.6176 | Val Acc: 0.8242 | Val F1: 0.8351\n",
      "Epoch 16/20 | Time: 922.03s | Train Loss: 0.0134 | Val Loss: 1.8983 | Val Acc: 0.8258 | Val F1: 0.8354\n",
      "Epoch 17/20 | Time: 914.58s | Train Loss: 0.0069 | Val Loss: 1.7992 | Val Acc: 0.8260 | Val F1: 0.8252\n",
      "Epoch 18/20 | Time: 937.09s | Train Loss: 0.0134 | Val Loss: 2.2935 | Val Acc: 0.8120 | Val F1: 0.8287\n",
      "Epoch 19/20 | Time: 898.26s | Train Loss: 0.0061 | Val Loss: 2.3201 | Val Acc: 0.8294 | Val F1: 0.8338\n",
      "Epoch 20/20 | Time: 910.46s | Train Loss: 0.0065 | Val Loss: 1.7146 | Val Acc: 0.8282 | Val F1: 0.8311\n",
      "--- Training of model CustomMamba finished ---\n",
      "--- Evaluating the best CustomMamba model on the test data ---\n",
      "Results for CustomMamba: {'loss': 0.4370202890042301, 'accuracy': 0.8328, 'precision': 0.8017116333043226, 'recall': 0.88432, 'f1_score': 0.8409920876445527}\n",
      "------------------------------------------------------------\n",
      "--- Starting training of model: Lib_Transformer on device cuda ---\n",
      "Epoch 1/20 | Time: 21.90s | Train Loss: 0.5894 | Val Loss: 0.5326 | Val Acc: 0.7487 | Val F1: 0.7689\n",
      " -> Model saved, new best Val F1: 0.7689\n",
      "Epoch 2/20 | Time: 21.56s | Train Loss: 0.4505 | Val Loss: 0.4653 | Val Acc: 0.7953 | Val F1: 0.7894\n",
      " -> Model saved, new best Val F1: 0.7894\n",
      "Epoch 3/20 | Time: 21.64s | Train Loss: 0.3925 | Val Loss: 0.4553 | Val Acc: 0.8001 | Val F1: 0.7818\n",
      "Epoch 4/20 | Time: 22.27s | Train Loss: 0.3562 | Val Loss: 0.4642 | Val Acc: 0.7836 | Val F1: 0.7480\n",
      "Epoch 5/20 | Time: 21.73s | Train Loss: 0.3294 | Val Loss: 0.4035 | Val Acc: 0.8305 | Val F1: 0.8337\n",
      " -> Model saved, new best Val F1: 0.8337\n",
      "Epoch 6/20 | Time: 21.58s | Train Loss: 0.3022 | Val Loss: 0.3936 | Val Acc: 0.8364 | Val F1: 0.8356\n",
      " -> Model saved, new best Val F1: 0.8356\n",
      "Epoch 7/20 | Time: 21.67s | Train Loss: 0.2752 | Val Loss: 0.3853 | Val Acc: 0.8380 | Val F1: 0.8345\n",
      "Epoch 8/20 | Time: 21.82s | Train Loss: 0.2507 | Val Loss: 0.3882 | Val Acc: 0.8377 | Val F1: 0.8329\n",
      "Epoch 9/20 | Time: 21.81s | Train Loss: 0.2286 | Val Loss: 0.4488 | Val Acc: 0.8118 | Val F1: 0.8333\n",
      "Epoch 10/20 | Time: 21.65s | Train Loss: 0.2056 | Val Loss: 0.3876 | Val Acc: 0.8402 | Val F1: 0.8350\n",
      "Epoch 11/20 | Time: 21.60s | Train Loss: 0.1803 | Val Loss: 0.3949 | Val Acc: 0.8358 | Val F1: 0.8385\n",
      " -> Model saved, new best Val F1: 0.8385\n",
      "Epoch 12/20 | Time: 21.57s | Train Loss: 0.1605 | Val Loss: 0.4024 | Val Acc: 0.8360 | Val F1: 0.8414\n",
      " -> Model saved, new best Val F1: 0.8414\n",
      "Epoch 13/20 | Time: 21.65s | Train Loss: 0.1392 | Val Loss: 0.4087 | Val Acc: 0.8356 | Val F1: 0.8340\n",
      "Epoch 14/20 | Time: 21.57s | Train Loss: 0.1172 | Val Loss: 0.4315 | Val Acc: 0.8323 | Val F1: 0.8297\n",
      "Epoch 15/20 | Time: 21.56s | Train Loss: 0.1005 | Val Loss: 0.4626 | Val Acc: 0.8317 | Val F1: 0.8284\n",
      "Epoch 16/20 | Time: 21.65s | Train Loss: 0.0876 | Val Loss: 0.4680 | Val Acc: 0.8318 | Val F1: 0.8335\n",
      "Epoch 17/20 | Time: 21.73s | Train Loss: 0.0728 | Val Loss: 0.4823 | Val Acc: 0.8317 | Val F1: 0.8326\n",
      "Epoch 18/20 | Time: 21.55s | Train Loss: 0.0656 | Val Loss: 0.5540 | Val Acc: 0.8206 | Val F1: 0.8068\n",
      "Epoch 19/20 | Time: 21.56s | Train Loss: 0.0491 | Val Loss: 0.6002 | Val Acc: 0.8235 | Val F1: 0.8178\n",
      "Epoch 20/20 | Time: 21.60s | Train Loss: 0.0445 | Val Loss: 0.5776 | Val Acc: 0.8314 | Val F1: 0.8318\n",
      "--- Training of model Lib_Transformer finished ---\n",
      "--- Evaluating the best Lib_Transformer model on the test data ---\n",
      "Results for Lib_Transformer: {'loss': 0.3889380347202806, 'accuracy': 0.84488, 'precision': 0.8214765100671141, 'recall': 0.88128, 'f1_score': 0.8503280586646083}\n",
      "------------------------------------------------------------\n",
      "\n",
      "\n",
      "--- Final model comparison table on the test data ---\n",
      "                              loss  accuracy  precision    recall  f1_score\n",
      "CustomMamba               0.437020   0.83280   0.801712  0.884320  0.840992\n",
      "CustomMamba_polarity      0.567920   0.73850   0.688808  0.869522  0.768686\n",
      "Lib_Transformer           0.388938   0.84488   0.821477  0.881280  0.850328\n",
      "Lib_Transformer_polarity  0.543388   0.73980   0.690897  0.867320  0.769122\n"
     ]
    }
   ],
   "source": [
    "model_configs = {\n",
    "    \"CustomMamba\": {\n",
    "        \"class\": CustomMambaClassifier,\n",
    "        \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8,\n",
    "                   'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
    "    },\n",
    "\n",
    "    \"Lib_Transformer\": {\n",
    "        \"class\": TransformerClassifier,\n",
    "        \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 8,\n",
    "                   'num_layers': 4, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
    "        # num_layers: 2 -> 4\n",
    "        # num_heads: 4 -> 8\n",
    "    },\n",
    "}\n",
    "\n",
    "results = {}\n",
    "for model_name, config in model_configs.items():\n",
    "\n",
    "    model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
    "\n",
    "    model = config['class'](**config['params']).to(DEVICE)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    train_and_evaluate(\n",
    "        model=model, train_loader=train_loader, val_loader=val_loader,\n",
    "        optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,\n",
    "        device=DEVICE, model_name=model_name, save_path=model_path\n",
    "    )\n",
    "\n",
    "    print(f\"--- Evaluating the best {model_name} model on the test data ---\")\n",
    "    if os.path.exists(model_path):\n",
    "        best_model = config['class'](**config['params']).to(DEVICE)\n",
    "        best_model.load_state_dict(torch.load(model_path))\n",
    "        test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)\n",
    "        results[model_name] = test_metrics\n",
    "        results[model_name + \"_polarity\"] = evaluate_on_test(best_model, test_loader_polarity, DEVICE, criterion)\n",
    "        print(f\"Results for {model_name}: {test_metrics}\")\n",
    "    else:\n",
    "        print(f\"Best model file for {model_name} not found. Skipping evaluation.\")\n",
    "\n",
    "    print(\"-\" * 60)\n",
    "\n",
    "if results:\n",
    "    results_df = pd.DataFrame(results).T\n",
    "    print(\"\\n\\n--- Final model comparison table on the test data ---\")\n",
    "    print(results_df.to_string())\n",
    "else:\n",
    "    print(\"Could not obtain results for any model.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the Transformer beats Mamba by a wide margin in both training time and quality. Next we will look at how they handle data from other datasets"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
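The logs above show both models overfitting well before epoch 20: training loss approaches zero while validation loss climbs, yet the loop always runs the full epoch budget. A minimal sketch of patience-based early stopping around the same best-F1 checkpoint criterion; this is illustrative only and not part of the repository, and the `train_one_epoch` and `validate` callables are assumed to be supplied by the caller:

import torch

def train_with_early_stopping(model, train_one_epoch, validate, num_epochs,
                              save_path, patience=5):
    """Stop when validation F1 has not improved for `patience` epochs."""
    best_val_f1, stale = 0.0, 0
    for epoch in range(num_epochs):
        train_one_epoch(model)   # one pass over the training loader
        f1 = validate(model)     # returns validation F1
        if f1 > best_val_f1:
            best_val_f1, stale = f1, 0
            torch.save(model.state_dict(), save_path)  # keep the best weights
        else:
            stale += 1
            if stale >= patience:
                break            # no improvement for `patience` epochs
    return best_val_f1

With a patience of 5, the CustomMamba run above would have stopped around epoch 12 instead of 20, saving roughly two hours at ~900s per epoch.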
notebooks/models_comparations_second_dataset.ipynb
ADDED
@@ -0,0 +1,240 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Restore the constants, the vocabulary, and the models from the previous notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "from src.models.models import TransformerClassifier, LSTMClassifier, CustomMambaClassifier, SimpleMambaBlock\n",
    "from src.data_utils.config import DatasetConfig\n",
    "from src.data_utils.dataset_params import DatasetName\n",
    "from src.data_utils.dataset_generator import DatasetGenerator\n",
    "\n",
    "MAX_SEQ_LEN = 300\n",
    "EMBEDDING_DIM = 128\n",
    "BATCH_SIZE = 32\n",
    "NUM_CLASSES = 2\n",
    "SAVE_DIR = \"../pretrained_comparison\"\n",
    "DATA_DIR = \"../datasets\"\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "config = DatasetConfig(load_from_disk=True, path_to_data=DATA_DIR)\n",
    "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
    "\n",
    "_, _, _ = generator.generate_dataset()\n",
    "vocab = generator.vocab\n",
    "VOCAB_SIZE = len(vocab)\n",
    "text_processor = generator.get_text_processor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take the helper function from the previous notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "\n",
    "def evaluate_on_test(model, test_loader, device, criterion):\n",
    "    model.eval()\n",
    "    total_test_loss = 0\n",
    "    all_preds = []\n",
    "    all_labels = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_X, batch_y in test_loader:\n",
    "            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
    "            outputs = model(batch_X)\n",
    "            loss = criterion(outputs, batch_y)\n",
    "            total_test_loss += loss.item()\n",
    "\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            all_preds.extend(predicted.cpu().numpy())\n",
    "            all_labels.extend(batch_y.cpu().numpy())\n",
    "\n",
    "    avg_test_loss = total_test_loss / len(test_loader)\n",
    "\n",
    "    accuracy = accuracy_score(all_labels, all_preds)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
    "\n",
    "    return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a dataset generator, pass it the already-built text processor, and take a dataset from a different distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloader(X, y, batch_size, shuffle=True):\n",
    "    X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
    "    y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
    "    dataset = TensorDataset(X_tensor, y_tensor)\n",
    "    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "text_processor = generator.get_text_processor()\n",
    "config_polarity = DatasetConfig(\n",
    "    load_from_disk=True,\n",
    "    path_to_data=\"../datasets\",\n",
    "    train_size=25000,  # took the whole dataset\n",
    "    val_size=12500,\n",
    "    test_size=12500,\n",
    "    build_vocab=False\n",
    ")\n",
    "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
    "generator_polarity.vocab = generator.vocab\n",
    "generator_polarity.id2word = generator.id2word\n",
    "generator_polarity.text_processor = text_processor\n",
    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator_polarity.generate_dataset()\n",
    "\n",
    "\n",
    "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Restore the model configurations from the previous notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_configs = {\n",
    "    \"CustomMamba\": {\n",
    "        \"class\": CustomMambaClassifier,\n",
    "        \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8,\n",
    "                   'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
    "    },\n",
    "    \"Lib_LSTM\": {\n",
    "        \"class\": LSTMClassifier,\n",
    "        \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 128,\n",
    "                   'num_layers': 2, 'num_classes': NUM_CLASSES, 'dropout': 0.5},\n",
    "    },\n",
    "    \"Lib_Transformer\": {\n",
    "        \"class\": TransformerClassifier,\n",
    "        \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 4,\n",
    "                   'num_layers': 2, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gab1k/.cache/pypoetry/virtualenvs/monkey-coding-dl-project-F4QJzkF_-py3.12/lib/python3.12/site-packages/torch/nn/modules/transformer.py:505: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)\n",
      "  output = torch._nested_tensor_from_mask(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "--- Final model comparison table on the test data ---\n",
      "                     loss  accuracy  precision    recall  f1_score\n",
      "CustomMamba      0.583675   0.70344   0.653410  0.871734  0.746945\n",
      "Lib_LSTM         0.675894   0.59520   0.574803  0.744423  0.648709\n",
      "Lib_Transformer  0.618924   0.66432   0.612190  0.904238  0.730091\n"
     ]
    }
   ],
   "source": [
    "results = {}\n",
    "for model_name, config in model_configs.items():\n",
    "    model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
    "    model = config['class'](**config['params']).to(DEVICE)\n",
    "\n",
    "    model.load_state_dict(torch.load(model_path, map_location=DEVICE))\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
    "    results[model_name] = test_metrics\n",
    "\n",
    "results_df = pd.DataFrame(results).T\n",
    "print(\"\\n\\n--- Final model comparison table on the test data ---\")\n",
    "print(results_df.to_string())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These numbers were taken on \"toy data\". Even on it, it is clear that:\n",
    " - accuracy is highest for Mamba\n",
    " - the Transformer also did quite well\n",
    " - the LSTM lost again\n",
    "\n",
    "In the next notebook we will train Mamba and the Transformer on the full dataset and measure quality on the second one. The model that does better \"goes to production\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
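One way to quantify the distribution shift behind the quality drop in the table above is the out-of-vocabulary rate of the second dataset under the IMDB vocabulary. A small sketch, not in the repository, assuming `vocab` is the token-to-id mapping built by `DatasetGenerator` and that `preprocess_text` returns a list of tokens:

def oov_rate(tokenized_texts, vocab):
    """Fraction of tokens that fall outside the training vocabulary."""
    total, unknown = 0, 0
    for tokens in tokenized_texts:
        for token in tokens:
            total += 1
            unknown += token not in vocab
    return unknown / max(total, 1)

# Hypothetical usage, with `polarity_texts` standing in for the raw POLARITY strings:
# rate = oov_rate([text_processor.preprocess_text(t) for t in polarity_texts], generator.vocab)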
notebooks/train.ipynb
CHANGED
@@ -1,68 +1,73 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Initialize the global variables and fetch the dataset"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
-    "import warnings\n",
-    "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
-    "\n",
     "import os\n",
     "import time\n",
     "import json\n",
     "import torch\n",
+    "import warnings\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
-    "\n",
     "from torch.utils.data import DataLoader, TensorDataset\n",
+    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
+    "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
     "\n",
-    "
-    "from src.
-
-
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "MODEL_TO_TRAIN = 'Transformer' \n",
+    "from src.data_utils.config import DatasetConfig\n",
+    "from src.data_utils.dataset_params import DatasetName\n",
+    "from src.data_utils.dataset_generator import DatasetGenerator\n",
+    "from src.models.models import TransformerClassifier\n",
     "\n",
-    "# Data and model hyperparameters\n",
     "MAX_SEQ_LEN = 300\n",
-    "EMBEDDING_DIM = 128
+    "EMBEDDING_DIM = 64  # reduced: 128 -> 64 so that the checkpoint fits into git\n",
-    "BATCH_SIZE = 32\n",
+    "BATCH_SIZE = 64  # raised batch_size: 32 -> 64\n",
-    "LEARNING_RATE =
+    "LEARNING_RATE = 7e-5\n",
-    "NUM_EPOCHS =
+    "NUM_EPOCHS = 100  # raised the number of epochs: 20 -> 100\n",
+    "NUM_CLASSES = 2\n",
     "\n",
-    "# Paths for saving the artifacts\n",
     "SAVE_DIR = \"../pretrained\"\n",
     "os.makedirs(SAVE_DIR, exist_ok=True)\n",
     "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n",
     "VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n",
     "CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n",
-    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
+    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "MODEL_TO_TRAIN = 'Transformer' \n",
+    "\n",
+    "config = DatasetConfig(\n",
+    "    load_from_disk=True,\n",
+    "    path_to_data=\"../datasets\",\n",
+    "    train_size=25000,  # took the whole dataset\n",
+    "    val_size=12500,\n",
+    "    test_size=12500\n",
+    ")\n",
+    "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
+    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
+    "VOCAB_SIZE = len(generator.vocab)"
    ]
   },
   {
-   "cell_type": "
+   "cell_type": "markdown",
-   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
    "source": [
-    "
-    "from src.data_utils.dataset_params import DatasetName\n",
-    "\n",
-    "generator = DatasetGenerator(DatasetName.IMDB)\n",
-    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
-    "X_train, y_train, X_val, y_val, X_test, y_test = X_train[:1000], y_train[:1000], X_val[:100], y_val[:100], X_test[:100], y_test[:100]"
+    "Create the dataloaders"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -73,45 +78,156 @@
     "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Model initialization"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
-    "model_params = {}\n",
-    "
-
-
-
-
-
-    "
-    "    model = LSTMClassifier(**model_params)\n",
-    "else:\n",
-    "    raise ValueError(\"Unknown model type. Choose 'Transformer', 'Mamba' or 'LSTM'\")"
+    "model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 8, 'num_layers': 4, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
+    "model = TransformerClassifier(**model_params)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Training"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "--- Starting training
-     "Epoch 1/
-     " -> Model saved, new best Val Loss: 0.
-     "Epoch 2/
-     " -> Model saved, new best Val Loss: 0.
-     "Epoch 3/
-     "
-     "
-     "
-     "
+     "--- Starting model training ---\n",
+     "Epoch 1/100 | Time: 14.62s | Train Loss: 0.6563 | Val Loss: 0.6460 | Val Acc: 0.6544\n",
+     " -> Model saved, new best Val Loss: 0.6460\n",
+     "Epoch 2/100 | Time: 14.10s | Train Loss: 0.5749 | Val Loss: 0.5673 | Val Acc: 0.7217\n",
+     " -> Model saved, new best Val Loss: 0.5673\n",
+     "Epoch 3/100 | Time: 14.13s | Train Loss: 0.5058 | Val Loss: 0.5285 | Val Acc: 0.7533\n",
+     " -> Model saved, new best Val Loss: 0.5285\n",
+     "Epoch 4/100 | Time: 14.09s | Train Loss: 0.4664 | Val Loss: 0.4980 | Val Acc: 0.7724\n",
+     " -> Model saved, new best Val Loss: 0.4980\n",
+     "Epoch 5/100 | Time: 14.22s | Train Loss: 0.4382 | Val Loss: 0.4785 | Val Acc: 0.7851\n",
+     " -> Model saved, new best Val Loss: 0.4785\n",
+     "Epoch 6/100 | Time: 14.29s | Train Loss: 0.4166 | Val Loss: 0.4775 | Val Acc: 0.7814\n",
+     " -> Model saved, new best Val Loss: 0.4775\n",
+     "Epoch 7/100 | Time: 14.14s | Train Loss: 0.3974 | Val Loss: 0.4636 | Val Acc: 0.7893\n",
+     " -> Model saved, new best Val Loss: 0.4636\n",
+     "Epoch 8/100 | Time: 14.11s | Train Loss: 0.3778 | Val Loss: 0.4689 | Val Acc: 0.7874\n",
+     "Epoch 9/100 | Time: 14.30s | Train Loss: 0.3595 | Val Loss: 0.4491 | Val Acc: 0.7973\n",
+     " -> Model saved, new best Val Loss: 0.4491\n",
+     "Epoch 10/100 | Time: 14.23s | Train Loss: 0.3438 | Val Loss: 0.4236 | Val Acc: 0.8148\n",
+     " -> Model saved, new best Val Loss: 0.4236\n",
+     "Epoch 11/100 | Time: 14.45s | Train Loss: 0.3301 | Val Loss: 0.4173 | Val Acc: 0.8174\n",
+     " -> Model saved, new best Val Loss: 0.4173\n",
+     "Epoch 12/100 | Time: 14.21s | Train Loss: 0.3202 | Val Loss: 0.4140 | Val Acc: 0.8206\n",
+     " -> Model saved, new best Val Loss: 0.4140\n",
+     "Epoch 13/100 | Time: 14.10s | Train Loss: 0.3076 | Val Loss: 0.4079 | Val Acc: 0.8243\n",
+     " -> Model saved, new best Val Loss: 0.4079\n",
+     "Epoch 14/100 | Time: 14.07s | Train Loss: 0.2959 | Val Loss: 0.4091 | Val Acc: 0.8220\n",
+     "Epoch 15/100 | Time: 14.06s | Train Loss: 0.2875 | Val Loss: 0.4074 | Val Acc: 0.8256\n",
+     " -> Model saved, new best Val Loss: 0.4074\n",
+     "Epoch 16/100 | Time: 14.25s | Train Loss: 0.2758 | Val Loss: 0.4021 | Val Acc: 0.8285\n",
+     " -> Model saved, new best Val Loss: 0.4021\n",
+     "Epoch 17/100 | Time: 14.17s | Train Loss: 0.2658 | Val Loss: 0.3933 | Val Acc: 0.8314\n",
+     " -> Model saved, new best Val Loss: 0.3933\n",
+     "Epoch 18/100 | Time: 14.21s | Train Loss: 0.2558 | Val Loss: 0.4100 | Val Acc: 0.8232\n",
+     "Epoch 19/100 | Time: 14.14s | Train Loss: 0.2518 | Val Loss: 0.3940 | Val Acc: 0.8324\n",
+     "Epoch 20/100 | Time: 14.14s | Train Loss: 0.2365 | Val Loss: 0.3934 | Val Acc: 0.8304\n",
+     "Epoch 21/100 | Time: 14.08s | Train Loss: 0.2283 | Val Loss: 0.3913 | Val Acc: 0.8336\n",
+     " -> Model saved, new best Val Loss: 0.3913\n",
+     "Epoch 22/100 | Time: 14.06s | Train Loss: 0.2215 | Val Loss: 0.4161 | Val Acc: 0.8250\n",
+     "Epoch 23/100 | Time: 14.25s | Train Loss: 0.2100 | Val Loss: 0.3956 | Val Acc: 0.8334\n",
+     "Epoch 24/100 | Time: 14.25s | Train Loss: 0.2018 | Val Loss: 0.3957 | Val Acc: 0.8333\n",
+     "Epoch 25/100 | Time: 14.18s | Train Loss: 0.1941 | Val Loss: 0.3942 | Val Acc: 0.8352\n",
+     "Epoch 26/100 | Time: 14.34s | Train Loss: 0.1811 | Val Loss: 0.3998 | Val Acc: 0.8349\n",
+     "Epoch 27/100 | Time: 14.18s | Train Loss: 0.1797 | Val Loss: 0.4078 | Val Acc: 0.8318\n",
+     "Epoch 28/100 | Time: 14.12s | Train Loss: 0.1667 | Val Loss: 0.4101 | Val Acc: 0.8339\n",
+     "Epoch 29/100 | Time: 14.20s | Train Loss: 0.1610 | Val Loss: 0.4119 | Val Acc: 0.8335\n",
+     "Epoch 30/100 | Time: 14.25s | Train Loss: 0.1507 | Val Loss: 0.4397 | Val Acc: 0.8294\n",
+     "Epoch 31/100 | Time: 14.19s | Train Loss: 0.1400 | Val Loss: 0.4245 | Val Acc: 0.8330\n",
+     "Epoch 32/100 | Time: 14.11s | Train Loss: 0.1361 | Val Loss: 0.4271 | Val Acc: 0.8338\n",
+     "Epoch 33/100 | Time: 14.13s | Train Loss: 0.1254 | Val Loss: 0.4434 | Val Acc: 0.8311\n",
+     "Epoch 34/100 | Time: 14.15s | Train Loss: 0.1190 | Val Loss: 0.4446 | Val Acc: 0.8306\n",
+     "Epoch 35/100 | Time: 14.16s | Train Loss: 0.1125 | Val Loss: 0.4728 | Val Acc: 0.8271\n",
+     "Epoch 36/100 | Time: 14.23s | Train Loss: 0.1068 | Val Loss: 0.4670 | Val Acc: 0.8297\n",
+     "Epoch 37/100 | Time: 14.19s | Train Loss: 0.0976 | Val Loss: 0.5572 | Val Acc: 0.8121\n",
+     "Epoch 38/100 | Time: 14.18s | Train Loss: 0.0923 | Val Loss: 0.4865 | Val Acc: 0.8260\n",
+     "Epoch 39/100 | Time: 14.16s | Train Loss: 0.0865 | Val Loss: 0.5013 | Val Acc: 0.8250\n",
+     "Epoch 40/100 | Time: 14.11s | Train Loss: 0.0822 | Val Loss: 0.5205 | Val Acc: 0.8233\n",
+     "Epoch 41/100 | Time: 14.11s | Train Loss: 0.0731 | Val Loss: 0.5203 | Val Acc: 0.8255\n",
+     "Epoch 42/100 | Time: 14.17s | Train Loss: 0.0693 | Val Loss: 0.5313 | Val Acc: 0.8276\n",
+     "Epoch 43/100 | Time: 14.27s | Train Loss: 0.0645 | Val Loss: 0.5518 | Val Acc: 0.8270\n",
+     "Epoch 44/100 | Time: 14.15s | Train Loss: 0.0577 | Val Loss: 0.5554 | Val Acc: 0.8270\n",
+     "Epoch 45/100 | Time: 14.23s | Train Loss: 0.0561 | Val Loss: 0.5659 | Val Acc: 0.8258\n",
+     "Epoch 46/100 | Time: 14.17s | Train Loss: 0.0526 | Val Loss: 0.5840 | Val Acc: 0.8219\n",
+     "Epoch 47/100 | Time: 14.14s | Train Loss: 0.0457 | Val Loss: 0.6217 | Val Acc: 0.8224\n",
+     "Epoch 48/100 | Time: 14.27s | Train Loss: 0.0420 | Val Loss: 0.6294 | Val Acc: 0.8237\n",
+     "Epoch 49/100 | Time: 14.21s | Train Loss: 0.0411 | Val Loss: 0.6333 | Val Acc: 0.8214\n",
+     "Epoch 50/100 | Time: 14.14s | Train Loss: 0.0345 | Val Loss: 0.6566 | Val Acc: 0.8266\n",
+     "Epoch 51/100 | Time: 14.12s | Train Loss: 0.0373 | Val Loss: 0.6504 | Val Acc: 0.8243\n",
+     "Epoch 52/100 | Time: 14.12s | Train Loss: 0.0319 | Val Loss: 0.6640 | Val Acc: 0.8272\n",
+     "Epoch 53/100 | Time: 14.13s | Train Loss: 0.0286 | Val Loss: 0.6896 | Val Acc: 0.8249\n",
+     "Epoch 54/100 | Time: 14.14s | Train Loss: 0.0274 | Val Loss: 0.7036 | Val Acc: 0.8213\n",
+     "Epoch 55/100 | Time: 14.23s | Train Loss: 0.0268 | Val Loss: 0.8750 | Val Acc: 0.7955\n",
+     "Epoch 56/100 | Time: 14.05s | Train Loss: 0.0274 | Val Loss: 0.7306 | Val Acc: 0.8194\n",
+     "Epoch 57/100 | Time: 14.06s | Train Loss: 0.0224 | Val Loss: 0.7345 | Val Acc: 0.8196\n",
+     "Epoch 58/100 | Time: 14.06s | Train Loss: 0.0234 | Val Loss: 0.7029 | Val Acc: 0.8238\n",
+     "Epoch 59/100 | Time: 14.04s | Train Loss: 0.0218 | Val Loss: 0.7278 | Val Acc: 0.8253\n",
+     "Epoch 60/100 | Time: 14.15s | Train Loss: 0.0193 | Val Loss: 0.7509 | Val Acc: 0.8217\n",
+     "Epoch 61/100 | Time: 14.27s | Train Loss: 0.0169 | Val Loss: 0.7706 | Val Acc: 0.8229\n",
+     "Epoch 62/100 | Time: 14.12s | Train Loss: 0.0177 | Val Loss: 0.7659 | Val Acc: 0.8229\n",
+     "Epoch 63/100 | Time: 14.35s | Train Loss: 0.0159 | Val Loss: 0.7892 | Val Acc: 0.8178\n",
+     "Epoch 64/100 | Time: 14.17s | Train Loss: 0.0153 | Val Loss: 0.7721 | Val Acc: 0.8262\n",
+     "Epoch 65/100 | Time: 14.13s | Train Loss: 0.0161 | Val Loss: 0.7746 | Val Acc: 0.8218\n",
+     "Epoch 66/100 | Time: 14.14s | Train Loss: 0.0151 | Val Loss: 0.7781 | Val Acc: 0.8227\n",
+     "Epoch 67/100 | Time: 14.25s | Train Loss: 0.0131 | Val Loss: 0.8032 | Val Acc: 0.8198\n",
+     "Epoch 68/100 | Time: 14.10s | Train Loss: 0.0156 | Val Loss: 0.7780 | Val Acc: 0.8274\n",
+     "Epoch 69/100 | Time: 14.04s | Train Loss: 0.0147 | Val Loss: 0.7967 | Val Acc: 0.8237\n",
+     "Epoch 70/100 | Time: 14.05s | Train Loss: 0.0152 | Val Loss: 0.7833 | Val Acc: 0.8240\n",
+     "Epoch 71/100 | Time: 14.35s | Train Loss: 0.0136 | Val Loss: 0.8180 | Val Acc: 0.8212\n",
+     "Epoch 72/100 | Time: 14.24s | Train Loss: 0.0120 | Val Loss: 0.8000 | Val Acc: 0.8235\n",
+     "Epoch 73/100 | Time: 14.18s | Train Loss: 0.0120 | Val Loss: 0.7985 | Val Acc: 0.8226\n",
+     "Epoch 74/100 | Time: 14.21s | Train Loss: 0.0106 | Val Loss: 0.7959 | Val Acc: 0.8259\n",
+     "Epoch 75/100 | Time: 14.12s | Train Loss: 0.0112 | Val Loss: 0.7925 | Val Acc: 0.8238\n",
+     "Epoch 76/100 | Time: 14.09s | Train Loss: 0.0133 | Val Loss: 0.8455 | Val Acc: 0.8138\n",
+     "Epoch 77/100 | Time: 14.12s | Train Loss: 0.0099 | Val Loss: 0.8086 | Val Acc: 0.8243\n",
+     "Epoch 78/100 | Time: 14.19s | Train Loss: 0.0086 | Val Loss: 0.8051 | Val Acc: 0.8271\n",
+     "Epoch 79/100 | Time: 14.11s | Train Loss: 0.0091 | Val Loss: 0.8212 | Val Acc: 0.8242\n",
+     "Epoch 80/100 | Time: 14.25s | Train Loss: 0.0105 | Val Loss: 0.8192 | Val Acc: 0.8244\n",
+     "Epoch 81/100 | Time: 14.20s | Train Loss: 0.0111 | Val Loss: 0.7825 | Val Acc: 0.8250\n",
+     "Epoch 82/100 | Time: 14.20s | Train Loss: 0.0105 | Val Loss: 0.7885 | Val Acc: 0.8259\n",
+     "Epoch 83/100 | Time: 14.16s | Train Loss: 0.0091 | Val Loss: 0.7950 | Val Acc: 0.8280\n",
+     "Epoch 84/100 | Time: 14.27s | Train Loss: 0.0092 | Val Loss: 0.8490 | Val Acc: 0.8217\n",
+     "Epoch 85/100 | Time: 14.63s | Train Loss: 0.0068 | Val Loss: 0.8464 | Val Acc: 0.8239\n",
+     "Epoch 86/100 | Time: 14.43s | Train Loss: 0.0084 | Val Loss: 0.8344 | Val Acc: 0.8250\n",
+     "Epoch 87/100 | Time: 14.27s | Train Loss: 0.0080 | Val Loss: 0.8242 | Val Acc: 0.8266\n",
+     "Epoch 88/100 | Time: 14.22s | Train Loss: 0.0102 | Val Loss: 0.8427 | Val Acc: 0.8230\n",
+     "Epoch 89/100 | Time: 14.19s | Train Loss: 0.0080 | Val Loss: 0.8097 | Val Acc: 0.8241\n",
+     "Epoch 90/100 | Time: 14.24s | Train Loss: 0.0079 | Val Loss: 0.8986 | Val Acc: 0.8161\n",
+     "Epoch 91/100 | Time: 14.18s | Train Loss: 0.0083 | Val Loss: 0.9104 | Val Acc: 0.8162\n",
+     "Epoch 92/100 | Time: 14.25s | Train Loss: 0.0073 | Val Loss: 0.8569 | Val Acc: 0.8258\n",
+     "Epoch 93/100 | Time: 14.17s | Train Loss: 0.0078 | Val Loss: 0.9992 | Val Acc: 0.8039\n",
+     "Epoch 94/100 | Time: 14.21s | Train Loss: 0.0066 | Val Loss: 0.8613 | Val Acc: 0.8224\n",
+     "Epoch 95/100 | Time: 14.37s | Train Loss: 0.0067 | Val Loss: 0.8378 | Val Acc: 0.8284\n",
+     "Epoch 96/100 | Time: 14.24s | Train Loss: 0.0059 | Val Loss: 0.8703 | Val Acc: 0.8203\n",
+     "Epoch 97/100 | Time: 14.45s | Train Loss: 0.0089 | Val Loss: 0.8341 | Val Acc: 0.8256\n",
+     "Epoch 98/100 | Time: 14.49s | Train Loss: 0.0131 | Val Loss: 0.8256 | Val Acc: 0.8217\n",
+     "Epoch 99/100 | Time: 14.46s | Train Loss: 0.0056 | Val Loss: 0.8518 | Val Acc: 0.8202\n",
+     "Epoch 100/100 | Time: 14.57s | Train Loss: 0.0065 | Val Loss: 0.8770 | Val Acc: 0.8216\n"
     ]
    }
   ],
@@ -121,7 +237,7 @@
     "criterion = nn.CrossEntropyLoss()\n",
     "\n",
     "best_val_loss = float('inf')\n",
-    "print(f\"--- Starting
+    "print(f\"--- Starting model training ---\")\n",
     "for epoch in range(NUM_EPOCHS):\n",
     "    model.train()\n",
     "    start_time = time.time()\n",
@@ -160,10 +276,143 @@
     "        print(f\" -> Model saved, new best Val Loss: {best_val_loss:.4f}\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Measure the quality on the test data from the original dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Metrics of the final model on the test split (from the training dataset)\n",
+      "{'loss': 0.8565363820110049, 'accuracy': 0.8276, 'precision': 0.851743686651778, 'recall': 0.79328, 'f1_score': 0.8214729517024273}\n"
+     ]
+    }
+   ],
+   "source": [
+    "def evaluate_on_test(model, test_loader, device, criterion):\n",
+    "    model.eval()\n",
+    "    total_test_loss = 0\n",
+    "    all_preds = []\n",
+    "    all_labels = []\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        for batch_X, batch_y in test_loader:\n",
+    "            batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+    "            outputs = model(batch_X)\n",
+    "            loss = criterion(outputs, batch_y)\n",
+    "            total_test_loss += loss.item()\n",
+    "\n",
+    "            _, predicted = torch.max(outputs.data, 1)\n",
+    "            all_preds.extend(predicted.cpu().numpy())\n",
+    "            all_labels.extend(batch_y.cpu().numpy())\n",
+    "\n",
+    "    avg_test_loss = total_test_loss / len(test_loader)\n",
+    "\n",
+    "    accuracy = accuracy_score(all_labels, all_preds)\n",
+    "    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
+    "\n",
+    "    return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}\n",
+    "\n",
+    "\n",
+    "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)\n",
+    "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
+    "print(f\"Metrics of the final model on the test split (from the training dataset)\\n{test_metrics}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Measure the quality on the test data of the new dataset. Load the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_processor = generator.get_text_processor()\n",
+    "config_polarity = DatasetConfig(\n",
+    "    load_from_disk=True,\n",
+    "    path_to_data=\"../datasets\",\n",
+    "    train_size=25000,  # took the whole dataset\n",
+    "    val_size=12500,\n",
+    "    test_size=12500,\n",
+    "    build_vocab=False\n",
+    ")\n",
+    "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
+    "generator_polarity.vocab = generator.vocab\n",
+    "generator_polarity.id2word = generator.id2word\n",
+    "generator_polarity.text_processor = text_processor\n",
+    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator_polarity.generate_dataset()\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Look at the metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Metrics of the final model on the test split (from an unseen dataset)\n",
+      "{'loss': 0.6786724476485836, 'accuracy': 0.73816, 'precision': 0.7227414330218068, 'recall': 0.7762906309751434, 'f1_score': 0.7485595759391565}\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
+    "print(f\"Metrics of the final model on the test split (from an unseen dataset)\\n{test_metrics}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Overall it is clear the model did learn something. Hypotheses for improvement:\n",
+    " - More, and more diverse, training data\n",
+    " - The bigger the vocabulary, the better\n",
+    " - The test dataset should be more similar to the training one"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Saving the final model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
    "outputs": [
    {
     "name": "stdout",
@@ -184,20 +433,13 @@
     "}\n",
     "with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
     "    json.dump(config, f, ensure_ascii=False, indent=4)\n",
-    "print(f\"Model configuration saved to: {CONFIG_SAVE_PATH}\")
+    "print(f\"Model configuration saved to: {CONFIG_SAVE_PATH}\")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
  "kernelspec": {
-  "display_name": "monkey-coding-dl-project-
+  "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
   "language": "python",
   "name": "python3"
  },
@@ -211,7 +453,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.12.
+  "version": "3.12.11"
  }
 },
 "nbformat": 4,
pretrained/best_model.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:0709ee349a04faf4ff0e19d0a91967953dd8aa0ca2fb0863f866fcc3d219debc
+size 29525749
pretrained/config.json
CHANGED
@@ -2,10 +2,10 @@
  "model_type": "Transformer",
  "max_seq_len": 300,
  "model_params": {
-  "vocab_size":
-  "embed_dim":
-  "num_heads":
-  "num_layers":
+  "vocab_size": 111829,
+  "embed_dim": 64,
+  "num_heads": 8,
+  "num_layers": 4,
   "num_classes": 2,
   "max_seq_len": 300
  }
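The three artifacts in pretrained/ (best_model.pth, vocab.json, config.json) are enough to rebuild the classifier for inference. A minimal sketch, assuming TransformerClassifier accepts exactly the keys stored under "model_params" and that best_model.pth holds a plain state_dict, which matches how train.ipynb saves and reloads it; paths are relative to the repo root:

import json
import torch
from src.models.models import TransformerClassifier

with open("pretrained/config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Rebuild the architecture from the saved hyperparameters
# (vocab_size=111829, embed_dim=64, num_heads=8, num_layers=4, ...).
model = TransformerClassifier(**cfg["model_params"])
state = torch.load("pretrained/best_model.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()  # inference mode

with open("pretrained/vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)  # assumed to be the token -> id mapping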
pretrained/vocab.json
CHANGED
The diff for this file is too large to render. See raw diff.
src/data_utils/config.py
CHANGED
@@ -15,6 +15,7 @@ class DatasetConfig:
     min_word_freq: Minimum word frequency to include in vocabulary
     load_from_disk: Load dataset from local dir. If false, download from Hugging Face
     path_to_data: Path to local dataset data
+    build_vocab: Whether building the vocabulary is necessary
     max_seq_len: Maximum sequence length (will be padded/truncated to this)
     lowercase: Whether to convert text to lowercase
     remove_punct: Whether to remove punctuation
@@ -30,6 +31,7 @@ class DatasetConfig:
     min_word_freq: int = 1
     load_from_disk: bool = False
     path_to_data: str = "./datasets"
+    build_vocab: bool = True
 
     max_seq_len: int = 300
     lowercase: bool = True
src/data_utils/dataset_generator.py
CHANGED
@@ -128,7 +128,8 @@ class DatasetGenerator:
         train_texts = train_df[self.dataset_params.content_col_name].tolist()
         train_tokens = [self.text_processor.preprocess_text(text) for text in train_texts]
 
-        self.vocab, self.id2word = self.build_vocabulary(train_tokens)
+        if self.config.build_vocab:
+            self.vocab, self.id2word = self.build_vocabulary(train_tokens)
 
         X_train = torch.stack([self.text_processor.text_to_tensor(text) for text in train_texts])
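Taken together, the new build_vocab flag and the guarded build_vocabulary call let a generator run with a transplanted vocabulary instead of building its own. A sketch of the intended usage, mirroring how the notebooks wire the POLARITY generator to the IMDB vocabulary:

from src.data_utils.config import DatasetConfig
from src.data_utils.dataset_params import DatasetName
from src.data_utils.dataset_generator import DatasetGenerator

# Build the vocabulary once, on the training corpus (build_vocab defaults to True).
imdb_gen = DatasetGenerator(
    DatasetName.IMDB,
    config=DatasetConfig(load_from_disk=True, path_to_data="./datasets"),
)
imdb_gen.generate_dataset()

# Reuse it for a second corpus: build_vocab=False skips build_vocabulary(),
# so the transplanted vocab/id2word/text_processor stay in effect and both
# datasets share one token id space.
polarity_gen = DatasetGenerator(
    DatasetName.POLARITY,
    config=DatasetConfig(load_from_disk=True, path_to_data="./datasets", build_vocab=False),
)
polarity_gen.vocab = imdb_gen.vocab
polarity_gen.id2word = imdb_gen.id2word
polarity_gen.text_processor = imdb_gen.get_text_processor()
(_, _), (_, _), (X_test, y_test) = polarity_gen.generate_dataset()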