litvinovmitch11 committed
Commit
2a591a9
·
verified ·
1 Parent(s): 9e6c6b7

Synced repo using 'sync_with_huggingface' Github Action

notebooks/mamba_vs_transformerts.ipynb ADDED
@@ -0,0 +1,361 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import time\n",
+ "import torch\n",
+ "import warnings\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
+ "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category=warn)\n",
+ "\n",
+ "from src.data_utils.config import DatasetConfig\n",
+ "from src.data_utils.dataset_params import DatasetName\n",
+ "from src.data_utils.dataset_generator import DatasetGenerator\n",
+ "from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier\n",
+ "\n",
+ "MAX_SEQ_LEN = 300\n",
+ "EMBEDDING_DIM = 128\n",
+ "BATCH_SIZE = 32\n",
+ "LEARNING_RATE = 7e-5 # lowered lr: 1e-4 -> 7e-5\n",
+ "NUM_EPOCHS = 20 # raised the number of epochs: 5 -> 20\n",
+ "NUM_CLASSES = 2\n",
+ "\n",
+ "SAVE_DIR = \"../best_models/\"\n",
+ "os.makedirs(SAVE_DIR, exist_ok=True)\n",
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "config = DatasetConfig(\n",
+ " load_from_disk=True,\n",
+ " path_to_data=\"../datasets\",\n",
+ " train_size=25000, # increased the number of samples\n",
+ " val_size=12500,\n",
+ " test_size=12500\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
+ "VOCAB_SIZE = len(generator.vocab)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We also create a generator for the test data of the second dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text_processor = generator.get_text_processor()\n",
+ "config_polarity = DatasetConfig(\n",
+ " load_from_disk=True,\n",
+ " path_to_data=\"../datasets\",\n",
+ " test_size=10000,\n",
+ " build_vocab=False\n",
+ ")\n",
+ "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
+ "generator_polarity.vocab = generator.vocab\n",
+ "generator_polarity.id2word = generator.id2word\n",
+ "generator_polarity.text_processor = text_processor\n",
+ "(_, _), (_, _), (X_test_polarity, y_test_polarity) = generator_polarity.generate_dataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):\n",
+ " best_val_f1 = 0.0\n",
+ " history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}\n",
+ "\n",
+ " print(f\"--- Starting training of model: {model_name} on device {device} ---\")\n",
+ "\n",
+ " for epoch in range(num_epochs):\n",
+ " model.train()\n",
+ " start_time = time.time()\n",
+ " total_train_loss = 0\n",
+ "\n",
+ " for batch_X, batch_y in train_loader:\n",
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+ " optimizer.zero_grad()\n",
+ " outputs = model(batch_X)\n",
+ " loss = criterion(outputs, batch_y)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " total_train_loss += loss.item()\n",
+ "\n",
+ " avg_train_loss = total_train_loss / len(train_loader)\n",
+ " history['train_loss'].append(avg_train_loss)\n",
+ "\n",
+ " model.eval()\n",
+ " total_val_loss = 0\n",
+ " all_preds = []\n",
+ " all_labels = []\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch_X, batch_y in val_loader:\n",
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+ " outputs = model(batch_X)\n",
+ " loss = criterion(outputs, batch_y)\n",
+ " total_val_loss += loss.item()\n",
+ "\n",
+ " _, predicted = torch.max(outputs.data, 1)\n",
+ " all_preds.extend(predicted.cpu().numpy())\n",
+ " all_labels.extend(batch_y.cpu().numpy())\n",
+ "\n",
+ " avg_val_loss = total_val_loss / len(val_loader)\n",
+ "\n",
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
+ "\n",
+ " history['val_loss'].append(avg_val_loss)\n",
+ " history['val_accuracy'].append(accuracy)\n",
+ " history['val_f1'].append(f1)\n",
+ "\n",
+ " epoch_time = time.time() - start_time\n",
+ " print(f\"Epoch {epoch+1}/{num_epochs} | Time: {epoch_time:.2f}s | Train Loss: {avg_train_loss:.4f} | \"\n",
+ " f\"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}\")\n",
+ "\n",
+ " if f1 > best_val_f1:\n",
+ " best_val_f1 = f1\n",
+ " torch.save(model.state_dict(), save_path)\n",
+ " print(f\" -> Model saved, new best Val F1: {best_val_f1:.4f}\")\n",
+ "\n",
+ " print(f\"--- Training of model {model_name} finished ---\")\n",
+ " return history\n",
+ "\n",
+ "def evaluate_on_test(model, test_loader, device, criterion):\n",
+ " model.eval()\n",
+ " total_test_loss = 0\n",
+ " all_preds = []\n",
+ " all_labels = []\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch_X, batch_y in test_loader:\n",
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+ " outputs = model(batch_X)\n",
+ " loss = criterion(outputs, batch_y)\n",
+ " total_test_loss += loss.item()\n",
+ "\n",
+ " _, predicted = torch.max(outputs.data, 1)\n",
+ " all_preds.extend(predicted.cpu().numpy())\n",
+ " all_labels.extend(batch_y.cpu().numpy())\n",
+ "\n",
+ " avg_test_loss = total_test_loss / len(test_loader)\n",
+ "\n",
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
+ "\n",
+ " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_dataloader(X, y, batch_size, shuffle=True):\n",
+ " X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
+ " y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
+ " dataset = TensorDataset(X_tensor, y_tensor)\n",
+ " return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
+ "\n",
+ "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
+ "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE, shuffle=False)\n",
+ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)\n",
+ "test_loader_polarity = create_dataloader(X_test_polarity, y_test_polarity, BATCH_SIZE, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--- Starting training of model: CustomMamba on device cuda ---\n",
+ "Epoch 1/20 | Time: 897.35s | Train Loss: 0.6261 | Val Loss: 0.5313 | Val Acc: 0.7388 | Val F1: 0.7503\n",
+ " -> Model saved, new best Val F1: 0.7503\n",
+ "Epoch 2/20 | Time: 903.31s | Train Loss: 0.4748 | Val Loss: 0.4559 | Val Acc: 0.7889 | Val F1: 0.7901\n",
+ " -> Model saved, new best Val F1: 0.7901\n",
+ "Epoch 3/20 | Time: 930.96s | Train Loss: 0.3955 | Val Loss: 0.4176 | Val Acc: 0.8090 | Val F1: 0.8048\n",
+ " -> Model saved, new best Val F1: 0.8048\n",
+ "Epoch 4/20 | Time: 952.79s | Train Loss: 0.3429 | Val Loss: 0.3998 | Val Acc: 0.8230 | Val F1: 0.8303\n",
+ " -> Model saved, new best Val F1: 0.8303\n",
+ "Epoch 5/20 | Time: 904.67s | Train Loss: 0.2984 | Val Loss: 0.4387 | Val Acc: 0.8165 | Val F1: 0.8337\n",
+ " -> Model saved, new best Val F1: 0.8337\n",
+ "Epoch 6/20 | Time: 935.58s | Train Loss: 0.2609 | Val Loss: 0.4219 | Val Acc: 0.8255 | Val F1: 0.8386\n",
+ " -> Model saved, new best Val F1: 0.8386\n",
+ "Epoch 7/20 | Time: 902.73s | Train Loss: 0.2266 | Val Loss: 0.4342 | Val Acc: 0.8334 | Val F1: 0.8420\n",
+ " -> Model saved, new best Val F1: 0.8420\n",
+ "Epoch 8/20 | Time: 921.49s | Train Loss: 0.1953 | Val Loss: 0.4675 | Val Acc: 0.8299 | Val F1: 0.8411\n",
+ "Epoch 9/20 | Time: 892.54s | Train Loss: 0.1590 | Val Loss: 0.5205 | Val Acc: 0.8359 | Val F1: 0.8394\n",
+ "Epoch 10/20 | Time: 893.22s | Train Loss: 0.1263 | Val Loss: 0.6014 | Val Acc: 0.8303 | Val F1: 0.8398\n",
+ "Epoch 11/20 | Time: 953.69s | Train Loss: 0.0945 | Val Loss: 0.8025 | Val Acc: 0.8183 | Val F1: 0.8352\n",
+ "Epoch 12/20 | Time: 924.98s | Train Loss: 0.0644 | Val Loss: 0.8539 | Val Acc: 0.8290 | Val F1: 0.8369\n",
+ "Epoch 13/20 | Time: 916.92s | Train Loss: 0.0428 | Val Loss: 1.0646 | Val Acc: 0.8286 | Val F1: 0.8266\n",
+ "Epoch 14/20 | Time: 904.39s | Train Loss: 0.0266 | Val Loss: 1.5225 | Val Acc: 0.8149 | Val F1: 0.8305\n",
+ "Epoch 15/20 | Time: 923.80s | Train Loss: 0.0199 | Val Loss: 1.6176 | Val Acc: 0.8242 | Val F1: 0.8351\n",
+ "Epoch 16/20 | Time: 922.03s | Train Loss: 0.0134 | Val Loss: 1.8983 | Val Acc: 0.8258 | Val F1: 0.8354\n",
+ "Epoch 17/20 | Time: 914.58s | Train Loss: 0.0069 | Val Loss: 1.7992 | Val Acc: 0.8260 | Val F1: 0.8252\n",
+ "Epoch 18/20 | Time: 937.09s | Train Loss: 0.0134 | Val Loss: 2.2935 | Val Acc: 0.8120 | Val F1: 0.8287\n",
+ "Epoch 19/20 | Time: 898.26s | Train Loss: 0.0061 | Val Loss: 2.3201 | Val Acc: 0.8294 | Val F1: 0.8338\n",
+ "Epoch 20/20 | Time: 910.46s | Train Loss: 0.0065 | Val Loss: 1.7146 | Val Acc: 0.8282 | Val F1: 0.8311\n",
+ "--- Training of model CustomMamba finished ---\n",
+ "--- Evaluating the best CustomMamba model on the test data ---\n",
+ "Results for CustomMamba: {'loss': 0.4370202890042301, 'accuracy': 0.8328, 'precision': 0.8017116333043226, 'recall': 0.88432, 'f1_score': 0.8409920876445527}\n",
+ "------------------------------------------------------------\n",
+ "--- Starting training of model: Lib_Transformer on device cuda ---\n",
+ "Epoch 1/20 | Time: 21.90s | Train Loss: 0.5894 | Val Loss: 0.5326 | Val Acc: 0.7487 | Val F1: 0.7689\n",
+ " -> Model saved, new best Val F1: 0.7689\n",
+ "Epoch 2/20 | Time: 21.56s | Train Loss: 0.4505 | Val Loss: 0.4653 | Val Acc: 0.7953 | Val F1: 0.7894\n",
+ " -> Model saved, new best Val F1: 0.7894\n",
+ "Epoch 3/20 | Time: 21.64s | Train Loss: 0.3925 | Val Loss: 0.4553 | Val Acc: 0.8001 | Val F1: 0.7818\n",
+ "Epoch 4/20 | Time: 22.27s | Train Loss: 0.3562 | Val Loss: 0.4642 | Val Acc: 0.7836 | Val F1: 0.7480\n",
+ "Epoch 5/20 | Time: 21.73s | Train Loss: 0.3294 | Val Loss: 0.4035 | Val Acc: 0.8305 | Val F1: 0.8337\n",
+ " -> Model saved, new best Val F1: 0.8337\n",
+ "Epoch 6/20 | Time: 21.58s | Train Loss: 0.3022 | Val Loss: 0.3936 | Val Acc: 0.8364 | Val F1: 0.8356\n",
+ " -> Model saved, new best Val F1: 0.8356\n",
+ "Epoch 7/20 | Time: 21.67s | Train Loss: 0.2752 | Val Loss: 0.3853 | Val Acc: 0.8380 | Val F1: 0.8345\n",
+ "Epoch 8/20 | Time: 21.82s | Train Loss: 0.2507 | Val Loss: 0.3882 | Val Acc: 0.8377 | Val F1: 0.8329\n",
+ "Epoch 9/20 | Time: 21.81s | Train Loss: 0.2286 | Val Loss: 0.4488 | Val Acc: 0.8118 | Val F1: 0.8333\n",
+ "Epoch 10/20 | Time: 21.65s | Train Loss: 0.2056 | Val Loss: 0.3876 | Val Acc: 0.8402 | Val F1: 0.8350\n",
+ "Epoch 11/20 | Time: 21.60s | Train Loss: 0.1803 | Val Loss: 0.3949 | Val Acc: 0.8358 | Val F1: 0.8385\n",
+ " -> Model saved, new best Val F1: 0.8385\n",
+ "Epoch 12/20 | Time: 21.57s | Train Loss: 0.1605 | Val Loss: 0.4024 | Val Acc: 0.8360 | Val F1: 0.8414\n",
+ " -> Model saved, new best Val F1: 0.8414\n",
+ "Epoch 13/20 | Time: 21.65s | Train Loss: 0.1392 | Val Loss: 0.4087 | Val Acc: 0.8356 | Val F1: 0.8340\n",
+ "Epoch 14/20 | Time: 21.57s | Train Loss: 0.1172 | Val Loss: 0.4315 | Val Acc: 0.8323 | Val F1: 0.8297\n",
+ "Epoch 15/20 | Time: 21.56s | Train Loss: 0.1005 | Val Loss: 0.4626 | Val Acc: 0.8317 | Val F1: 0.8284\n",
+ "Epoch 16/20 | Time: 21.65s | Train Loss: 0.0876 | Val Loss: 0.4680 | Val Acc: 0.8318 | Val F1: 0.8335\n",
+ "Epoch 17/20 | Time: 21.73s | Train Loss: 0.0728 | Val Loss: 0.4823 | Val Acc: 0.8317 | Val F1: 0.8326\n",
+ "Epoch 18/20 | Time: 21.55s | Train Loss: 0.0656 | Val Loss: 0.5540 | Val Acc: 0.8206 | Val F1: 0.8068\n",
+ "Epoch 19/20 | Time: 21.56s | Train Loss: 0.0491 | Val Loss: 0.6002 | Val Acc: 0.8235 | Val F1: 0.8178\n",
+ "Epoch 20/20 | Time: 21.60s | Train Loss: 0.0445 | Val Loss: 0.5776 | Val Acc: 0.8314 | Val F1: 0.8318\n",
+ "--- Training of model Lib_Transformer finished ---\n",
+ "--- Evaluating the best Lib_Transformer model on the test data ---\n",
+ "Results for Lib_Transformer: {'loss': 0.3889380347202806, 'accuracy': 0.84488, 'precision': 0.8214765100671141, 'recall': 0.88128, 'f1_score': 0.8503280586646083}\n",
+ "------------------------------------------------------------\n",
+ "\n",
+ "\n",
+ "--- Final comparison table of models on the test data ---\n",
+ " loss accuracy precision recall f1_score\n",
+ "CustomMamba 0.437020 0.83280 0.801712 0.884320 0.840992\n",
+ "CustomMamba_polarity 0.567920 0.73850 0.688808 0.869522 0.768686\n",
+ "Lib_Transformer 0.388938 0.84488 0.821477 0.881280 0.850328\n",
+ "Lib_Transformer_polarity 0.543388 0.73980 0.690897 0.867320 0.769122\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_configs = {\n",
+ " \"CustomMamba\": {\n",
+ " \"class\": CustomMambaClassifier,\n",
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8,\n",
+ " 'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
+ " },\n",
+ "\n",
+ " \"Lib_Transformer\": {\n",
+ " \"class\": TransformerClassifier,\n",
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 8,\n",
+ " 'num_layers': 4, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
+ " # num_layers: 2 -> 4\n",
+ " # num_heads: 4 -> 8\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "results = {}\n",
+ "for model_name, config in model_configs.items():\n",
+ " model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
+ "\n",
+ " model = config['class'](**config['params']).to(DEVICE)\n",
+ " optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
+ " criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ " train_and_evaluate(\n",
+ " model=model, train_loader=train_loader, val_loader=val_loader,\n",
+ " optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,\n",
+ " device=DEVICE, model_name=model_name, save_path=model_path\n",
+ " )\n",
+ "\n",
+ " print(f\"--- Evaluating the best {model_name} model on the test data ---\")\n",
+ " if os.path.exists(model_path):\n",
+ " best_model = config['class'](**config['params']).to(DEVICE)\n",
+ " best_model.load_state_dict(torch.load(model_path))\n",
+ " test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)\n",
+ " results[model_name] = test_metrics\n",
+ " results[model_name + \"_polarity\"] = evaluate_on_test(best_model, test_loader_polarity, DEVICE, criterion)\n",
+ " print(f\"Results for {model_name}: {test_metrics}\")\n",
+ " else:\n",
+ " print(f\"Best-model file for {model_name} not found. Skipping evaluation.\")\n",
+ "\n",
+ " print(\"-\" * 60)\n",
+ "\n",
+ "if results:\n",
+ " results_df = pd.DataFrame(results).T\n",
+ " print(\"\\n\\n--- Final comparison table of models on the test data ---\")\n",
+ " print(results_df.to_string())\n",
+ "else:\n",
+ " print(\"Could not obtain results for any model.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can see that the Transformer clearly beats Mamba in both training time and quality. Next we will look at how they cope with data from other datasets"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
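
Two things stand out in the log above: the library transformer reaches slightly better F1, and each of its epochs takes about 22 s versus roughly 900 s for the custom Mamba. That gap is what one would expect if the custom layer runs its recurrence one token at a time in Python while the transformer covers the whole sequence with a few batched matmuls. Below is a minimal, self-contained timing sketch of that difference on toy tensors; it is an illustration only, not the repo's CustomMambaClassifier or TransformerClassifier.

import time
import torch

B, L, D = 32, 300, 128                  # batch, sequence length, model dim (as above)
x = torch.randn(B, L, D)
A = torch.randn(D, D) * 0.01            # toy state-transition matrix

def sequential_scan(x):
    # One Python-level step per token, like a naive (unfused) SSM-style recurrence
    h = torch.zeros(B, D)
    outs = []
    for t in range(L):                  # 300 dependent steps; kernels cannot be batched
        h = torch.tanh(h @ A + x[:, t])
        outs.append(h)
    return torch.stack(outs, dim=1)

def parallel_pass(x):
    # Attention-style pass: the whole sequence in a few large batched matmuls
    attn = torch.softmax(x @ x.transpose(1, 2) / D ** 0.5, dim=-1)
    return attn @ x

for fn in (sequential_scan, parallel_pass):
    start = time.perf_counter()
    fn(x)
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")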
notebooks/models_comparations_second_dataset.ipynb ADDED
@@ -0,0 +1,240 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Restore the constants, vocabulary and models from the previous notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "from src.models.models import TransformerClassifier, LSTMClassifier, CustomMambaClassifier, SimpleMambaBlock\n",
+ "from src.data_utils.config import DatasetConfig\n",
+ "from src.data_utils.dataset_params import DatasetName\n",
+ "from src.data_utils.dataset_generator import DatasetGenerator\n",
+ "\n",
+ "MAX_SEQ_LEN = 300\n",
+ "EMBEDDING_DIM = 128\n",
+ "BATCH_SIZE = 32\n",
+ "NUM_CLASSES = 2\n",
+ "SAVE_DIR = \"../pretrained_comparison\"\n",
+ "DATA_DIR = \"../datasets\"\n",
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "config = DatasetConfig(load_from_disk=True, path_to_data=DATA_DIR)\n",
+ "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
+ "\n",
+ "_, _, _ = generator.generate_dataset()\n",
+ "vocab = generator.vocab\n",
+ "VOCAB_SIZE = len(vocab)\n",
+ "text_processor = generator.get_text_processor()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Take the helper function from the previous notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
+ "\n",
+ "def evaluate_on_test(model, test_loader, device, criterion):\n",
+ " model.eval()\n",
+ " total_test_loss = 0\n",
+ " all_preds = []\n",
+ " all_labels = []\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch_X, batch_y in test_loader:\n",
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+ " outputs = model(batch_X)\n",
+ " loss = criterion(outputs, batch_y)\n",
+ " total_test_loss += loss.item()\n",
+ "\n",
+ " _, predicted = torch.max(outputs.data, 1)\n",
+ " all_preds.extend(predicted.cpu().numpy())\n",
+ " all_labels.extend(batch_y.cpu().numpy())\n",
+ "\n",
+ " avg_test_loss = total_test_loss / len(test_loader)\n",
+ "\n",
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
+ "\n",
+ " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a dataset generator, pass it the already-built text processor, and take a dataset from a different distribution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_dataloader(X, y, batch_size, shuffle=True):\n",
+ " X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
+ " y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
+ " dataset = TensorDataset(X_tensor, y_tensor)\n",
+ " return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
+ "\n",
+ "text_processor = generator.get_text_processor()\n",
+ "config_polarity = DatasetConfig(\n",
+ " load_from_disk=True,\n",
+ " path_to_data=\"../datasets\",\n",
+ " train_size=25000, # took the whole dataset\n",
+ " val_size=12500,\n",
+ " test_size=12500,\n",
+ " build_vocab=False\n",
+ ")\n",
+ "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
+ "generator_polarity.vocab = generator.vocab\n",
+ "generator_polarity.id2word = generator.id2word\n",
+ "generator_polarity.text_processor = text_processor\n",
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator_polarity.generate_dataset()\n",
+ "\n",
+ "\n",
+ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Restore the model configurations from the previous notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_configs = {\n",
+ " \"CustomMamba\": {\n",
+ " \"class\": CustomMambaClassifier,\n",
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8,\n",
+ " 'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
+ " },\n",
+ " \"Lib_LSTM\": {\n",
+ " \"class\": LSTMClassifier,\n",
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 128,\n",
+ " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'dropout': 0.5},\n",
+ " },\n",
+ " \"Lib_Transformer\": {\n",
+ " \"class\": TransformerClassifier,\n",
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 4,\n",
+ " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
+ " },\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now let's look at the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/gab1k/.cache/pypoetry/virtualenvs/monkey-coding-dl-project-F4QJzkF_-py3.12/lib/python3.12/site-packages/torch/nn/modules/transformer.py:505: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)\n",
+ " output = torch._nested_tensor_from_mask(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "--- Final comparison table of models on the test data ---\n",
+ " loss accuracy precision recall f1_score\n",
+ "CustomMamba 0.583675 0.70344 0.653410 0.871734 0.746945\n",
+ "Lib_LSTM 0.675894 0.59520 0.574803 0.744423 0.648709\n",
+ "Lib_Transformer 0.618924 0.66432 0.612190 0.904238 0.730091\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = {}\n",
+ "for model_name, config in model_configs.items():\n",
+ " model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
+ " model = config['class'](**config['params']).to(DEVICE)\n",
+ "\n",
+ " model.load_state_dict(torch.load(model_path, map_location=DEVICE))\n",
+ " criterion = nn.CrossEntropyLoss()\n",
+ " test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
+ " results[model_name] = test_metrics\n",
+ "\n",
+ "results_df = pd.DataFrame(results).T\n",
+ "print(\"\\n\\n--- Final comparison table of models on the test data ---\")\n",
+ "print(results_df.to_string())\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "These numbers were taken on \"toy data\", but even on it we can see that:\n",
+ " - accuracy is highest for Mamba\n",
+ " - the Transformer also did quite well\n",
+ " - the LSTM lost again\n",
+ "\n",
+ "In the next notebook we will train Mamba and Transformer on the full dataset and measure quality on the second one. The model that does better \"goes to production\""
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
notebooks/train.ipynb CHANGED
@@ -1,68 +1,73 @@
  {
  "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Initialize the global variables and fetch the dataset"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
- "import warnings\n",
- "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category=warn)\n",
- "\n",
  "import os\n",
  "import time\n",
  "import json\n",
  "import torch\n",
+ "import warnings\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
  "import torch.nn as nn\n",
  "import torch.optim as optim\n",
- "\n",
  "from torch.utils.data import DataLoader, TensorDataset\n",
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
+ "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category=warn)\n",
  "\n",
- "# Import the model classes from our file\n",
- "from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "MODEL_TO_TRAIN = 'Transformer'\n",
+ "from src.data_utils.config import DatasetConfig\n",
+ "from src.data_utils.dataset_params import DatasetName\n",
+ "from src.data_utils.dataset_generator import DatasetGenerator\n",
+ "from src.models.models import TransformerClassifier\n",
  "\n",
- "# Data and model hyperparameters\n",
  "MAX_SEQ_LEN = 300\n",
- "EMBEDDING_DIM = 128\n",
- "BATCH_SIZE = 32\n",
- "LEARNING_RATE = 1e-4\n",
- "NUM_EPOCHS = 5 # Increase for a better result\n",
+ "EMBEDDING_DIM = 64 # reduced: 128 -> 64, to fit into git\n",
+ "BATCH_SIZE = 64 # raised batch_size: 32 -> 64\n",
+ "LEARNING_RATE = 7e-5\n",
+ "NUM_EPOCHS = 100 # raised the number of epochs: 20 -> 100\n",
+ "NUM_CLASSES = 2\n",
  "\n",
- "# Paths for saving artifacts\n",
  "SAVE_DIR = \"../pretrained\"\n",
  "os.makedirs(SAVE_DIR, exist_ok=True)\n",
  "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n",
  "VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n",
  "CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n",
- "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "MODEL_TO_TRAIN = 'Transformer'\n",
+ "\n",
+ "config = DatasetConfig(\n",
+ " load_from_disk=True,\n",
+ " path_to_data=\"../datasets\",\n",
+ " train_size=25000, # took the whole dataset\n",
+ " val_size=12500,\n",
+ " test_size=12500\n",
+ ")\n",
+ "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
+ "VOCAB_SIZE = len(generator.vocab)"
  ]
  },
  {
- "cell_type": "code",
- "execution_count": 3,
+ "cell_type": "markdown",
  "metadata": {},
- "outputs": [],
  "source": [
- "from src.data_utils.dataset_generator import DatasetGenerator\n",
- "from src.data_utils.dataset_params import DatasetName\n",
- "\n",
- "generator = DatasetGenerator(DatasetName.IMDB)\n",
- "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
- "X_train, y_train, X_val, y_val, X_test, y_test = X_train[:1000], y_train[:1000], X_val[:100], y_val[:100], X_test[:100], y_test[:100]"
+ "Create the dataloaders"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -73,45 +78,156 @@
  "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Model initialization"
+ ]
+ },
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
- "model_params = {}\n",
- "if MODEL_TO_TRAIN == 'Transformer':\n",
- " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, 'num_layers': 2, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
- " model = TransformerClassifier(**model_params)\n",
- "elif MODEL_TO_TRAIN == 'Mamba':\n",
- " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'mamba_d_state': 16, 'mamba_d_conv': 4, 'mamba_expand': 2, 'num_classes': 2}\n",
- " model = MambaClassifier(**model_params)\n",
- "elif MODEL_TO_TRAIN == 'LSTM':\n",
- " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 256, 'num_layers': 2, 'num_classes': 2, 'dropout': 0.5}\n",
- " model = LSTMClassifier(**model_params)\n",
- "else:\n",
- " raise ValueError(\"Unknown model type. Choose 'Transformer', 'Mamba' or 'LSTM'\")"
+ "model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 8, 'num_layers': 4, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
+ "model = TransformerClassifier(**model_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Training"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "--- Starting model training: Transformer ---\n",
- "Epoch 1/5 | Time: 17.06s | Train Loss: 0.7023 | Val Loss: 0.7095 | Val Acc: 0.4000\n",
- " -> Model saved, new best Val Loss: 0.7095\n",
- "Epoch 2/5 | Time: 16.40s | Train Loss: 0.6682 | Val Loss: 0.6937 | Val Acc: 0.4800\n",
- " -> Model saved, new best Val Loss: 0.6937\n",
- "Epoch 3/5 | Time: 16.13s | Train Loss: 0.6471 | Val Loss: 0.7075 | Val Acc: 0.4100\n",
- "Epoch 4/5 | Time: 16.36s | Train Loss: 0.6283 | Val Loss: 0.6917 | Val Acc: 0.5300\n",
- " -> Model saved, new best Val Loss: 0.6917\n",
- "Epoch 5/5 | Time: 16.39s | Train Loss: 0.6050 | Val Loss: 0.6871 | Val Acc: 0.5300\n",
- " -> Model saved, new best Val Loss: 0.6871\n"
+ "--- Starting model training ---\n",
+ "Epoch 1/100 | Time: 14.62s | Train Loss: 0.6563 | Val Loss: 0.6460 | Val Acc: 0.6544\n",
+ " -> Model saved, new best Val Loss: 0.6460\n",
+ "Epoch 2/100 | Time: 14.10s | Train Loss: 0.5749 | Val Loss: 0.5673 | Val Acc: 0.7217\n",
+ " -> Model saved, new best Val Loss: 0.5673\n",
+ "Epoch 3/100 | Time: 14.13s | Train Loss: 0.5058 | Val Loss: 0.5285 | Val Acc: 0.7533\n",
+ " -> Model saved, new best Val Loss: 0.5285\n",
+ "Epoch 4/100 | Time: 14.09s | Train Loss: 0.4664 | Val Loss: 0.4980 | Val Acc: 0.7724\n",
+ " -> Model saved, new best Val Loss: 0.4980\n",
+ "Epoch 5/100 | Time: 14.22s | Train Loss: 0.4382 | Val Loss: 0.4785 | Val Acc: 0.7851\n",
+ " -> Model saved, new best Val Loss: 0.4785\n",
+ "Epoch 6/100 | Time: 14.29s | Train Loss: 0.4166 | Val Loss: 0.4775 | Val Acc: 0.7814\n",
+ " -> Model saved, new best Val Loss: 0.4775\n",
+ "Epoch 7/100 | Time: 14.14s | Train Loss: 0.3974 | Val Loss: 0.4636 | Val Acc: 0.7893\n",
+ " -> Model saved, new best Val Loss: 0.4636\n",
+ "Epoch 8/100 | Time: 14.11s | Train Loss: 0.3778 | Val Loss: 0.4689 | Val Acc: 0.7874\n",
+ "Epoch 9/100 | Time: 14.30s | Train Loss: 0.3595 | Val Loss: 0.4491 | Val Acc: 0.7973\n",
+ " -> Model saved, new best Val Loss: 0.4491\n",
+ "Epoch 10/100 | Time: 14.23s | Train Loss: 0.3438 | Val Loss: 0.4236 | Val Acc: 0.8148\n",
+ " -> Model saved, new best Val Loss: 0.4236\n",
+ "Epoch 11/100 | Time: 14.45s | Train Loss: 0.3301 | Val Loss: 0.4173 | Val Acc: 0.8174\n",
+ " -> Model saved, new best Val Loss: 0.4173\n",
+ "Epoch 12/100 | Time: 14.21s | Train Loss: 0.3202 | Val Loss: 0.4140 | Val Acc: 0.8206\n",
+ " -> Model saved, new best Val Loss: 0.4140\n",
+ "Epoch 13/100 | Time: 14.10s | Train Loss: 0.3076 | Val Loss: 0.4079 | Val Acc: 0.8243\n",
+ " -> Model saved, new best Val Loss: 0.4079\n",
+ "Epoch 14/100 | Time: 14.07s | Train Loss: 0.2959 | Val Loss: 0.4091 | Val Acc: 0.8220\n",
+ "Epoch 15/100 | Time: 14.06s | Train Loss: 0.2875 | Val Loss: 0.4074 | Val Acc: 0.8256\n",
+ " -> Model saved, new best Val Loss: 0.4074\n",
+ "Epoch 16/100 | Time: 14.25s | Train Loss: 0.2758 | Val Loss: 0.4021 | Val Acc: 0.8285\n",
+ " -> Model saved, new best Val Loss: 0.4021\n",
+ "Epoch 17/100 | Time: 14.17s | Train Loss: 0.2658 | Val Loss: 0.3933 | Val Acc: 0.8314\n",
+ " -> Model saved, new best Val Loss: 0.3933\n",
+ "Epoch 18/100 | Time: 14.21s | Train Loss: 0.2558 | Val Loss: 0.4100 | Val Acc: 0.8232\n",
+ "Epoch 19/100 | Time: 14.14s | Train Loss: 0.2518 | Val Loss: 0.3940 | Val Acc: 0.8324\n",
+ "Epoch 20/100 | Time: 14.14s | Train Loss: 0.2365 | Val Loss: 0.3934 | Val Acc: 0.8304\n",
+ "Epoch 21/100 | Time: 14.08s | Train Loss: 0.2283 | Val Loss: 0.3913 | Val Acc: 0.8336\n",
+ " -> Model saved, new best Val Loss: 0.3913\n",
+ "Epoch 22/100 | Time: 14.06s | Train Loss: 0.2215 | Val Loss: 0.4161 | Val Acc: 0.8250\n",
+ "Epoch 23/100 | Time: 14.25s | Train Loss: 0.2100 | Val Loss: 0.3956 | Val Acc: 0.8334\n",
+ "Epoch 24/100 | Time: 14.25s | Train Loss: 0.2018 | Val Loss: 0.3957 | Val Acc: 0.8333\n",
+ "Epoch 25/100 | Time: 14.18s | Train Loss: 0.1941 | Val Loss: 0.3942 | Val Acc: 0.8352\n",
+ "Epoch 26/100 | Time: 14.34s | Train Loss: 0.1811 | Val Loss: 0.3998 | Val Acc: 0.8349\n",
+ "Epoch 27/100 | Time: 14.18s | Train Loss: 0.1797 | Val Loss: 0.4078 | Val Acc: 0.8318\n",
+ "Epoch 28/100 | Time: 14.12s | Train Loss: 0.1667 | Val Loss: 0.4101 | Val Acc: 0.8339\n",
+ "Epoch 29/100 | Time: 14.20s | Train Loss: 0.1610 | Val Loss: 0.4119 | Val Acc: 0.8335\n",
+ "Epoch 30/100 | Time: 14.25s | Train Loss: 0.1507 | Val Loss: 0.4397 | Val Acc: 0.8294\n",
+ "Epoch 31/100 | Time: 14.19s | Train Loss: 0.1400 | Val Loss: 0.4245 | Val Acc: 0.8330\n",
+ "Epoch 32/100 | Time: 14.11s | Train Loss: 0.1361 | Val Loss: 0.4271 | Val Acc: 0.8338\n",
+ "Epoch 33/100 | Time: 14.13s | Train Loss: 0.1254 | Val Loss: 0.4434 | Val Acc: 0.8311\n",
+ "Epoch 34/100 | Time: 14.15s | Train Loss: 0.1190 | Val Loss: 0.4446 | Val Acc: 0.8306\n",
+ "Epoch 35/100 | Time: 14.16s | Train Loss: 0.1125 | Val Loss: 0.4728 | Val Acc: 0.8271\n",
+ "Epoch 36/100 | Time: 14.23s | Train Loss: 0.1068 | Val Loss: 0.4670 | Val Acc: 0.8297\n",
+ "Epoch 37/100 | Time: 14.19s | Train Loss: 0.0976 | Val Loss: 0.5572 | Val Acc: 0.8121\n",
+ "Epoch 38/100 | Time: 14.18s | Train Loss: 0.0923 | Val Loss: 0.4865 | Val Acc: 0.8260\n",
+ "Epoch 39/100 | Time: 14.16s | Train Loss: 0.0865 | Val Loss: 0.5013 | Val Acc: 0.8250\n",
+ "Epoch 40/100 | Time: 14.11s | Train Loss: 0.0822 | Val Loss: 0.5205 | Val Acc: 0.8233\n",
+ "Epoch 41/100 | Time: 14.11s | Train Loss: 0.0731 | Val Loss: 0.5203 | Val Acc: 0.8255\n",
+ "Epoch 42/100 | Time: 14.17s | Train Loss: 0.0693 | Val Loss: 0.5313 | Val Acc: 0.8276\n",
+ "Epoch 43/100 | Time: 14.27s | Train Loss: 0.0645 | Val Loss: 0.5518 | Val Acc: 0.8270\n",
+ "Epoch 44/100 | Time: 14.15s | Train Loss: 0.0577 | Val Loss: 0.5554 | Val Acc: 0.8270\n",
+ "Epoch 45/100 | Time: 14.23s | Train Loss: 0.0561 | Val Loss: 0.5659 | Val Acc: 0.8258\n",
+ "Epoch 46/100 | Time: 14.17s | Train Loss: 0.0526 | Val Loss: 0.5840 | Val Acc: 0.8219\n",
+ "Epoch 47/100 | Time: 14.14s | Train Loss: 0.0457 | Val Loss: 0.6217 | Val Acc: 0.8224\n",
+ "Epoch 48/100 | Time: 14.27s | Train Loss: 0.0420 | Val Loss: 0.6294 | Val Acc: 0.8237\n",
+ "Epoch 49/100 | Time: 14.21s | Train Loss: 0.0411 | Val Loss: 0.6333 | Val Acc: 0.8214\n",
+ "Epoch 50/100 | Time: 14.14s | Train Loss: 0.0345 | Val Loss: 0.6566 | Val Acc: 0.8266\n",
+ "Epoch 51/100 | Time: 14.12s | Train Loss: 0.0373 | Val Loss: 0.6504 | Val Acc: 0.8243\n",
+ "Epoch 52/100 | Time: 14.12s | Train Loss: 0.0319 | Val Loss: 0.6640 | Val Acc: 0.8272\n",
+ "Epoch 53/100 | Time: 14.13s | Train Loss: 0.0286 | Val Loss: 0.6896 | Val Acc: 0.8249\n",
+ "Epoch 54/100 | Time: 14.14s | Train Loss: 0.0274 | Val Loss: 0.7036 | Val Acc: 0.8213\n",
+ "Epoch 55/100 | Time: 14.23s | Train Loss: 0.0268 | Val Loss: 0.8750 | Val Acc: 0.7955\n",
+ "Epoch 56/100 | Time: 14.05s | Train Loss: 0.0274 | Val Loss: 0.7306 | Val Acc: 0.8194\n",
+ "Epoch 57/100 | Time: 14.06s | Train Loss: 0.0224 | Val Loss: 0.7345 | Val Acc: 0.8196\n",
+ "Epoch 58/100 | Time: 14.06s | Train Loss: 0.0234 | Val Loss: 0.7029 | Val Acc: 0.8238\n",
+ "Epoch 59/100 | Time: 14.04s | Train Loss: 0.0218 | Val Loss: 0.7278 | Val Acc: 0.8253\n",
+ "Epoch 60/100 | Time: 14.15s | Train Loss: 0.0193 | Val Loss: 0.7509 | Val Acc: 0.8217\n",
+ "Epoch 61/100 | Time: 14.27s | Train Loss: 0.0169 | Val Loss: 0.7706 | Val Acc: 0.8229\n",
+ "Epoch 62/100 | Time: 14.12s | Train Loss: 0.0177 | Val Loss: 0.7659 | Val Acc: 0.8229\n",
+ "Epoch 63/100 | Time: 14.35s | Train Loss: 0.0159 | Val Loss: 0.7892 | Val Acc: 0.8178\n",
+ "Epoch 64/100 | Time: 14.17s | Train Loss: 0.0153 | Val Loss: 0.7721 | Val Acc: 0.8262\n",
+ "Epoch 65/100 | Time: 14.13s | Train Loss: 0.0161 | Val Loss: 0.7746 | Val Acc: 0.8218\n",
+ "Epoch 66/100 | Time: 14.14s | Train Loss: 0.0151 | Val Loss: 0.7781 | Val Acc: 0.8227\n",
+ "Epoch 67/100 | Time: 14.25s | Train Loss: 0.0131 | Val Loss: 0.8032 | Val Acc: 0.8198\n",
+ "Epoch 68/100 | Time: 14.10s | Train Loss: 0.0156 | Val Loss: 0.7780 | Val Acc: 0.8274\n",
+ "Epoch 69/100 | Time: 14.04s | Train Loss: 0.0147 | Val Loss: 0.7967 | Val Acc: 0.8237\n",
+ "Epoch 70/100 | Time: 14.05s | Train Loss: 0.0152 | Val Loss: 0.7833 | Val Acc: 0.8240\n",
+ "Epoch 71/100 | Time: 14.35s | Train Loss: 0.0136 | Val Loss: 0.8180 | Val Acc: 0.8212\n",
+ "Epoch 72/100 | Time: 14.24s | Train Loss: 0.0120 | Val Loss: 0.8000 | Val Acc: 0.8235\n",
+ "Epoch 73/100 | Time: 14.18s | Train Loss: 0.0120 | Val Loss: 0.7985 | Val Acc: 0.8226\n",
+ "Epoch 74/100 | Time: 14.21s | Train Loss: 0.0106 | Val Loss: 0.7959 | Val Acc: 0.8259\n",
+ "Epoch 75/100 | Time: 14.12s | Train Loss: 0.0112 | Val Loss: 0.7925 | Val Acc: 0.8238\n",
+ "Epoch 76/100 | Time: 14.09s | Train Loss: 0.0133 | Val Loss: 0.8455 | Val Acc: 0.8138\n",
+ "Epoch 77/100 | Time: 14.12s | Train Loss: 0.0099 | Val Loss: 0.8086 | Val Acc: 0.8243\n",
+ "Epoch 78/100 | Time: 14.19s | Train Loss: 0.0086 | Val Loss: 0.8051 | Val Acc: 0.8271\n",
+ "Epoch 79/100 | Time: 14.11s | Train Loss: 0.0091 | Val Loss: 0.8212 | Val Acc: 0.8242\n",
+ "Epoch 80/100 | Time: 14.25s | Train Loss: 0.0105 | Val Loss: 0.8192 | Val Acc: 0.8244\n",
+ "Epoch 81/100 | Time: 14.20s | Train Loss: 0.0111 | Val Loss: 0.7825 | Val Acc: 0.8250\n",
+ "Epoch 82/100 | Time: 14.20s | Train Loss: 0.0105 | Val Loss: 0.7885 | Val Acc: 0.8259\n",
+ "Epoch 83/100 | Time: 14.16s | Train Loss: 0.0091 | Val Loss: 0.7950 | Val Acc: 0.8280\n",
+ "Epoch 84/100 | Time: 14.27s | Train Loss: 0.0092 | Val Loss: 0.8490 | Val Acc: 0.8217\n",
+ "Epoch 85/100 | Time: 14.63s | Train Loss: 0.0068 | Val Loss: 0.8464 | Val Acc: 0.8239\n",
+ "Epoch 86/100 | Time: 14.43s | Train Loss: 0.0084 | Val Loss: 0.8344 | Val Acc: 0.8250\n",
+ "Epoch 87/100 | Time: 14.27s | Train Loss: 0.0080 | Val Loss: 0.8242 | Val Acc: 0.8266\n",
+ "Epoch 88/100 | Time: 14.22s | Train Loss: 0.0102 | Val Loss: 0.8427 | Val Acc: 0.8230\n",
+ "Epoch 89/100 | Time: 14.19s | Train Loss: 0.0080 | Val Loss: 0.8097 | Val Acc: 0.8241\n",
+ "Epoch 90/100 | Time: 14.24s | Train Loss: 0.0079 | Val Loss: 0.8986 | Val Acc: 0.8161\n",
+ "Epoch 91/100 | Time: 14.18s | Train Loss: 0.0083 | Val Loss: 0.9104 | Val Acc: 0.8162\n",
+ "Epoch 92/100 | Time: 14.25s | Train Loss: 0.0073 | Val Loss: 0.8569 | Val Acc: 0.8258\n",
+ "Epoch 93/100 | Time: 14.17s | Train Loss: 0.0078 | Val Loss: 0.9992 | Val Acc: 0.8039\n",
+ "Epoch 94/100 | Time: 14.21s | Train Loss: 0.0066 | Val Loss: 0.8613 | Val Acc: 0.8224\n",
+ "Epoch 95/100 | Time: 14.37s | Train Loss: 0.0067 | Val Loss: 0.8378 | Val Acc: 0.8284\n",
+ "Epoch 96/100 | Time: 14.24s | Train Loss: 0.0059 | Val Loss: 0.8703 | Val Acc: 0.8203\n",
+ "Epoch 97/100 | Time: 14.45s | Train Loss: 0.0089 | Val Loss: 0.8341 | Val Acc: 0.8256\n",
+ "Epoch 98/100 | Time: 14.49s | Train Loss: 0.0131 | Val Loss: 0.8256 | Val Acc: 0.8217\n",
+ "Epoch 99/100 | Time: 14.46s | Train Loss: 0.0056 | Val Loss: 0.8518 | Val Acc: 0.8202\n",
+ "Epoch 100/100 | Time: 14.57s | Train Loss: 0.0065 | Val Loss: 0.8770 | Val Acc: 0.8216\n"
  ]
  }
  ],
@@ -121,7 +237,7 @@
  "criterion = nn.CrossEntropyLoss()\n",
  "\n",
  "best_val_loss = float('inf')\n",
- "print(f\"--- Starting model training: {MODEL_TO_TRAIN} ---\")\n",
+ "print(f\"--- Starting model training ---\")\n",
  "for epoch in range(NUM_EPOCHS):\n",
  " model.train()\n",
  " start_time = time.time()\n",
@@ -160,10 +276,143 @@
  " print(f\" -> Model saved, new best Val Loss: {best_val_loss:.4f}\")"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Measure quality on the test data from the original dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Final model metrics on the test set (from the training dataset)\n",
+ "{'loss': 0.8565363820110049, 'accuracy': 0.8276, 'precision': 0.851743686651778, 'recall': 0.79328, 'f1_score': 0.8214729517024273}\n"
+ ]
+ }
+ ],
+ "source": [
+ "def evaluate_on_test(model, test_loader, device, criterion):\n",
+ " model.eval()\n",
+ " total_test_loss = 0\n",
+ " all_preds = []\n",
+ " all_labels = []\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch_X, batch_y in test_loader:\n",
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
+ " outputs = model(batch_X)\n",
+ " loss = criterion(outputs, batch_y)\n",
+ " total_test_loss += loss.item()\n",
+ "\n",
+ " _, predicted = torch.max(outputs.data, 1)\n",
+ " all_preds.extend(predicted.cpu().numpy())\n",
+ " all_labels.extend(batch_y.cpu().numpy())\n",
+ "\n",
+ " avg_test_loss = total_test_loss / len(test_loader)\n",
+ "\n",
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
+ "\n",
+ " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}\n",
+ "\n",
+ "\n",
+ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)\n",
+ "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
+ "print(f\"Final model metrics on the test set (from the training dataset)\\n{test_metrics}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Measure quality on the test data of the new dataset. Load the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text_processor = generator.get_text_processor()\n",
+ "config_polarity = DatasetConfig(\n",
+ " load_from_disk=True,\n",
+ " path_to_data=\"../datasets\",\n",
+ " train_size=25000, # took the whole dataset\n",
+ " val_size=12500,\n",
+ " test_size=12500,\n",
+ " build_vocab=False\n",
+ ")\n",
+ "generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)\n",
+ "generator_polarity.vocab = generator.vocab\n",
+ "generator_polarity.id2word = generator.id2word\n",
+ "generator_polarity.text_processor = text_processor\n",
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator_polarity.generate_dataset()\n"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": 7,
  "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's look at the metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Final model metrics on the test set (from the unseen dataset)\n",
+ "{'loss': 0.6786724476485836, 'accuracy': 0.73816, 'precision': 0.7227414330218068, 'recall': 0.7762906309751434, 'f1_score': 0.7485595759391565}\n"
+ ]
+ }
+ ],
+ "source": [
+ "test_metrics = evaluate_on_test(model, test_loader, DEVICE, criterion)\n",
+ "print(f\"Final model metrics on the test set (from the unseen dataset)\\n{test_metrics}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Overall the model has clearly learned something. Hypotheses for improvement:\n",
+ " - more and more diverse training data\n",
+ " - the larger the vocabulary, the better\n",
+ " - the test dataset should be more similar to the training one"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Saving the final model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
  "outputs": [
  {
  "name": "stdout",
@@ -184,20 +433,13 @@
  "}\n",
  "with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
  " json.dump(config, f, ensure_ascii=False, indent=4)\n",
- "print(f\"Model configuration saved to: {CONFIG_SAVE_PATH}\")\n"
+ "print(f\"Model configuration saved to: {CONFIG_SAVE_PATH}\")"
  ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
  }
  ],
  "metadata": {
  "kernelspec": {
- "display_name": "monkey-coding-dl-project-OWiM8ypK-py3.12",
+ "display_name": "monkey-coding-dl-project-F4QJzkF_-py3.12",
  "language": "python",
  "name": "python3"
  },
@@ -211,7 +453,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.12.3"
+ "version": "3.12.11"
  }
  },
  "nbformat": 4,
pretrained/best_model.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:7636cf4c7205b64df4b91b2a23620d443de468a91211e074760e64adb24751ba
- size 37445685
+ oid sha256:0709ee349a04faf4ff0e19d0a91967953dd8aa0ca2fb0863f866fcc3d219debc
+ size 29525749
pretrained/config.json CHANGED
@@ -2,10 +2,10 @@
  "model_type": "Transformer",
  "max_seq_len": 300,
  "model_params": {
- "vocab_size": 69715,
- "embed_dim": 128,
- "num_heads": 4,
- "num_layers": 2,
+ "vocab_size": 111829,
+ "embed_dim": 64,
+ "num_heads": 8,
+ "num_layers": 4,
  "num_classes": 2,
  "max_seq_len": 300
  }
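
The saved artifacts fit together: config.json records the exact constructor arguments, so the checkpoint can be rebuilt without re-running the training notebook. A hedged sketch of how they might be consumed, assuming the LFS weights behind best_model.pth have been fetched and the TransformerClassifier signature shown in the notebooks:

import json
import torch
from src.models.models import TransformerClassifier

# Rebuild the model exactly as it was trained, from the recorded hyperparameters
with open("pretrained/config.json", encoding="utf-8") as f:
    cfg = json.load(f)

model = TransformerClassifier(**cfg["model_params"])
model.load_state_dict(torch.load("pretrained/best_model.pth", map_location="cpu"))
model.eval()  # switch to inference mode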
pretrained/vocab.json CHANGED
The diff for this file is too large to render. See raw diff
 
src/data_utils/config.py CHANGED
@@ -15,6 +15,7 @@ class DatasetConfig:
  min_word_freq: Minimum word frequency to include in vocabulary
  load_from_disk: Load dataset from local dir. If false download from huggin face
  path_to_data: Path to local dataset data
+ build_vocab: Whether to build the vocabulary
  max_seq_len: Maximum sequence length (will be padded/truncated to this)
  lowercase: Whether to convert text to lowercase
  remove_punct: Whether to remove punctuation
@@ -30,6 +31,7 @@ class DatasetConfig:
  min_word_freq: int = 1
  load_from_disk: bool = False
  path_to_data: str = "./datasets"
+ build_vocab: bool = True

  max_seq_len: int = 300
  lowercase: bool = True
src/data_utils/dataset_generator.py CHANGED
@@ -128,7 +128,8 @@ class DatasetGenerator:
  train_texts = train_df[self.dataset_params.content_col_name].tolist()
  train_tokens = [self.text_processor.preprocess_text(text) for text in train_texts]

- self.vocab, self.id2word = self.build_vocabulary(train_tokens)
+ if self.config.build_vocab:
+ self.vocab, self.id2word = self.build_vocabulary(train_tokens)

  X_train = torch.stack([self.text_processor.text_to_tensor(text) for text in train_texts])
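
Together, the new `build_vocab` flag and the conditional above enable the cross-dataset pattern used throughout this commit: build the vocabulary once on IMDB, then reuse the same token ids for POLARITY by disabling vocabulary building. A condensed sketch of that usage, with paths as in the notebooks (assumed to run from the notebooks/ directory):

from src.data_utils.config import DatasetConfig
from src.data_utils.dataset_params import DatasetName
from src.data_utils.dataset_generator import DatasetGenerator

# Build the vocabulary once, on the training corpus (IMDB)
config = DatasetConfig(load_from_disk=True, path_to_data="../datasets")
generator = DatasetGenerator(DatasetName.IMDB, config=config)
generator.generate_dataset()  # fills generator.vocab and generator.id2word

# Reuse it for the second corpus: skip vocabulary building and copy the mappings,
# so POLARITY texts are encoded with the token ids the model was trained on
config_polarity = DatasetConfig(load_from_disk=True, path_to_data="../datasets",
                                build_vocab=False)
generator_polarity = DatasetGenerator(DatasetName.POLARITY, config=config_polarity)
generator_polarity.vocab = generator.vocab
generator_polarity.id2word = generator.id2word
generator_polarity.text_processor = generator.get_text_processor()
(_, _), (_, _), (X_test, y_test) = generator_polarity.generate_dataset()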