litvinovmitch11 committed on
Commit cd123bf · verified · 1 Parent(s): 023f3c5

Synced repo using 'sync_with_huggingface' Github Action

config.yaml ADDED
@@ -0,0 +1,8 @@
+ model_path: "./pretrained/best_model.pth"
+ vocab_path: "./pretrained/vocab.json"
+ config_path: "./pretrained/config.json"
+ max_seq_len: 300
+ server:
+   local: true
+   host: "0.0.0.0"
+   port: 7860
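Note: the server settings form a nested block, which is what AppConfig.from_yaml in src/app/config.py (added further down in this commit) expects. A minimal sketch of how the file is read, assuming PyYAML is installed:

import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Top-level keys map onto AppConfig fields; server options sit under "server".
print(cfg["model_path"])       # ./pretrained/best_model.pth
print(cfg["server"]["port"])   # 7860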
datasets/load_datasets.ipynb ADDED
@@ -0,0 +1,111 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "461b224aa295437b8ddd80ccb5b5e683",
12
+ "version_major": 2,
13
+ "version_minor": 0
14
+ },
15
+ "text/plain": [
16
+ "Saving the dataset (0/4 shards): 0%| | 0/3600000 [00:00<?, ? examples/s]"
17
+ ]
18
+ },
19
+ "metadata": {},
20
+ "output_type": "display_data"
21
+ },
22
+ {
23
+ "data": {
24
+ "application/vnd.jupyter.widget-view+json": {
25
+ "model_id": "1628189e590d446588f57357d7e7a035",
26
+ "version_major": 2,
27
+ "version_minor": 0
28
+ },
29
+ "text/plain": [
30
+ "Saving the dataset (0/1 shards): 0%| | 0/400000 [00:00<?, ? examples/s]"
31
+ ]
32
+ },
33
+ "metadata": {},
34
+ "output_type": "display_data"
35
+ },
36
+ {
37
+ "data": {
38
+ "application/vnd.jupyter.widget-view+json": {
39
+ "model_id": "a728f34f72f64abaa31627af611e7ad3",
40
+ "version_major": 2,
41
+ "version_minor": 0
42
+ },
43
+ "text/plain": [
44
+ "Saving the dataset (0/1 shards): 0%| | 0/25000 [00:00<?, ? examples/s]"
45
+ ]
46
+ },
47
+ "metadata": {},
48
+ "output_type": "display_data"
49
+ },
50
+ {
51
+ "data": {
52
+ "application/vnd.jupyter.widget-view+json": {
53
+ "model_id": "cdef5994359e4c2696bc3bb18b1086d3",
54
+ "version_major": 2,
55
+ "version_minor": 0
56
+ },
57
+ "text/plain": [
58
+ "Saving the dataset (0/1 shards): 0%| | 0/25000 [00:00<?, ? examples/s]"
59
+ ]
60
+ },
61
+ "metadata": {},
62
+ "output_type": "display_data"
63
+ },
64
+ {
65
+ "data": {
66
+ "application/vnd.jupyter.widget-view+json": {
67
+ "model_id": "aaf29ef86d3c4d4fb54effb14129b039",
68
+ "version_major": 2,
69
+ "version_minor": 0
70
+ },
71
+ "text/plain": [
72
+ "Saving the dataset (0/1 shards): 0%| | 0/50000 [00:00<?, ? examples/s]"
73
+ ]
74
+ },
75
+ "metadata": {},
76
+ "output_type": "display_data"
77
+ }
78
+ ],
79
+ "source": [
80
+ "from datasets import load_dataset\n",
81
+ "\n",
82
+ "dataset_polarity = load_dataset(\"fancyzhx/amazon_polarity\")\n",
83
+ "dataset_polarity.save_to_disk(\"polarity\")\n",
84
+ "\n",
85
+ "dataset_imdb = load_dataset(\"stanfordnlp/imdb\")\n",
86
+ "dataset_imdb.save_to_disk(\"imdb\")"
87
+ ]
88
+ }
89
+ ],
90
+ "metadata": {
91
+ "kernelspec": {
92
+ "display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12",
93
+ "language": "python",
94
+ "name": "python3"
95
+ },
96
+ "language_info": {
97
+ "codemirror_mode": {
98
+ "name": "ipython",
99
+ "version": 3
100
+ },
101
+ "file_extension": ".py",
102
+ "mimetype": "text/x-python",
103
+ "name": "python",
104
+ "nbconvert_exporter": "python",
105
+ "pygments_lexer": "ipython3",
106
+ "version": "3.12.5"
107
+ }
108
+ },
109
+ "nbformat": 4,
110
+ "nbformat_minor": 2
111
+ }
main.py ADDED
@@ -0,0 +1,16 @@
+ import warnings
+ for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)
+
+ from src.app.app import App
+ from src.app.config import AppConfig
+
+
+ def main():
+     config = AppConfig.from_yaml("config.yaml")
+
+     app = App(config)
+     app.launch()
+
+
+ if __name__ == "__main__":
+     main()
notebooks/datasets_stats.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/models.ipynb ADDED
@@ -0,0 +1,356 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bfb61e1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Сравниваем модели и сохраняем в `src/models/pretrained`"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "f0574ac3",
14
+ "metadata": {},
15
+ "source": [
16
+ "- Импорты\n",
17
+ "- Константы\n",
18
+ "- Считывание датасетов"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "5a237c5c",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import os\n",
29
+ "import time\n",
30
+ "import torch\n",
31
+ "import warnings\n",
32
+ "import numpy as np\n",
33
+ "import pandas as pd\n",
34
+ "import torch.nn as nn\n",
35
+ "import torch.optim as optim\n",
36
+ "from torch.utils.data import DataLoader, TensorDataset\n",
37
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
38
+ "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
39
+ "\n",
40
+ "from src.data_utils.config import DatasetConfig\n",
41
+ "from src.data_utils.dataset_params import DatasetName\n",
42
+ "from src.data_utils.dataset_generator import DatasetGenerator\n",
43
+ "from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier\n",
44
+ "\n",
45
+ "MAX_SEQ_LEN = 300\n",
46
+ "EMBEDDING_DIM = 128\n",
47
+ "BATCH_SIZE = 32\n",
48
+ "LEARNING_RATE = 1e-4\n",
49
+ "NUM_EPOCHS = 5 # для быстрого сравнения моделей\n",
50
+ "NUM_CLASSES = 2\n",
51
+ "\n",
52
+ "SAVE_DIR = \"../pretrained_comparison\"\n",
53
+ "os.makedirs(SAVE_DIR, exist_ok=True)\n",
54
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
55
+ "\n",
56
+ "config = DatasetConfig(\n",
57
+ " load_from_disk=True,\n",
58
+ " path_to_data=\"../datasets\"\n",
59
+ ")\n",
60
+ "\n",
61
+ "generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
62
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
63
+ "VOCAB_SIZE = len(generator.vocab)"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "id": "5b95192d",
69
+ "metadata": {},
70
+ "source": [
71
+ "Вспомогательные функции для трейна/валидации/теста"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 2,
77
+ "id": "b2a4534c",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "\n",
82
+ "def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):\n",
83
+ " best_val_f1 = 0.0\n",
84
+ " history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}\n",
85
+ " \n",
86
+ " print(f\"--- Начало обучения модели: {model_name} на устройстве {device} ---\")\n",
87
+ "\n",
88
+ " for epoch in range(num_epochs):\n",
89
+ " model.train()\n",
90
+ " start_time = time.time()\n",
91
+ " total_train_loss = 0\n",
92
+ "\n",
93
+ " for batch_X, batch_y in train_loader:\n",
94
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
95
+ " optimizer.zero_grad()\n",
96
+ " outputs = model(batch_X)\n",
97
+ " loss = criterion(outputs, batch_y)\n",
98
+ " loss.backward()\n",
99
+ " optimizer.step()\n",
100
+ " total_train_loss += loss.item()\n",
101
+ " \n",
102
+ " avg_train_loss = total_train_loss / len(train_loader)\n",
103
+ " history['train_loss'].append(avg_train_loss)\n",
104
+ "\n",
105
+ " model.eval()\n",
106
+ " total_val_loss = 0\n",
107
+ " all_preds = []\n",
108
+ " all_labels = []\n",
109
+ "\n",
110
+ " with torch.no_grad():\n",
111
+ " for batch_X, batch_y in val_loader:\n",
112
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
113
+ " outputs = model(batch_X)\n",
114
+ " loss = criterion(outputs, batch_y)\n",
115
+ " total_val_loss += loss.item()\n",
116
+ " \n",
117
+ " _, predicted = torch.max(outputs.data, 1)\n",
118
+ " all_preds.extend(predicted.cpu().numpy())\n",
119
+ " all_labels.extend(batch_y.cpu().numpy())\n",
120
+ " \n",
121
+ " avg_val_loss = total_val_loss / len(val_loader)\n",
122
+ " \n",
123
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
124
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
125
+ " \n",
126
+ " history['val_loss'].append(avg_val_loss)\n",
127
+ " history['val_accuracy'].append(accuracy)\n",
128
+ " history['val_f1'].append(f1)\n",
129
+ "\n",
130
+ " epoch_time = time.time() - start_time\n",
131
+ " print(f\"Эпоха {epoch+1}/{num_epochs} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | \"\n",
132
+ " f\"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}\")\n",
133
+ "\n",
134
+ " if f1 > best_val_f1:\n",
135
+ " best_val_f1 = f1\n",
136
+ " torch.save(model.state_dict(), save_path)\n",
137
+ " print(f\" -> Модель сохранена, новый лучший Val F1: {best_val_f1:.4f}\")\n",
138
+ " \n",
139
+ " print(f\"--- Обучение модели {model_name} завершено ---\")\n",
140
+ " return history\n",
141
+ "\n",
142
+ "def evaluate_on_test(model, test_loader, device, criterion):\n",
143
+ " model.eval()\n",
144
+ " total_test_loss = 0\n",
145
+ " all_preds = []\n",
146
+ " all_labels = []\n",
147
+ "\n",
148
+ " with torch.no_grad():\n",
149
+ " for batch_X, batch_y in test_loader:\n",
150
+ " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
151
+ " outputs = model(batch_X)\n",
152
+ " loss = criterion(outputs, batch_y)\n",
153
+ " total_test_loss += loss.item()\n",
154
+ " \n",
155
+ " _, predicted = torch.max(outputs.data, 1)\n",
156
+ " all_preds.extend(predicted.cpu().numpy())\n",
157
+ " all_labels.extend(batch_y.cpu().numpy())\n",
158
+ " \n",
159
+ " avg_test_loss = total_test_loss / len(test_loader)\n",
160
+ " \n",
161
+ " accuracy = accuracy_score(all_labels, all_preds)\n",
162
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
163
+ " \n",
164
+ " return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "id": "1be50523",
170
+ "metadata": {},
171
+ "source": [
172
+ "Создание даталоадера"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 3,
178
+ "id": "cccc5bea",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "def create_dataloader(X, y, batch_size, shuffle=True):\n",
183
+ " X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
184
+ " y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
185
+ " dataset = TensorDataset(X_tensor, y_tensor)\n",
186
+ " return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
187
+ "\n",
188
+ "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
189
+ "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE, shuffle=False)\n",
190
+ "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "4938b9f3",
196
+ "metadata": {},
197
+ "source": [
198
+ "Сравнения моделей\n",
199
+ "\n",
200
+ "Смотрим первые 5 эпох чтобы выбрать лучшую модель, с которой будем играться дальше"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 4,
206
+ "id": "0244aafa",
207
+ "metadata": {},
208
+ "outputs": [
209
+ {
210
+ "name": "stdout",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "--- Начало обучения модели: CustomMamba на устройстве cuda ---\n",
214
+ "Эпоха 1/5 | Время: 337.85с | Train Loss: 0.6768 | Val Loss: 0.6168 | Val Acc: 0.6592 | Val F1: 0.5937\n",
215
+ " -> Модель сохранена, новый лучший Val F1: 0.5937\n",
216
+ "Эпоха 2/5 | Время: 345.54с | Train Loss: 0.5266 | Val Loss: 0.4964 | Val Acc: 0.7580 | Val F1: 0.7552\n",
217
+ " -> Модель сохранена, новый лучший Val F1: 0.7552\n",
218
+ "Эпоха 3/5 | Время: 343.23с | Train Loss: 0.4329 | Val Loss: 0.4586 | Val Acc: 0.7812 | Val F1: 0.7830\n",
219
+ " -> Модель сохранена, новый лучший Val F1: 0.7830\n",
220
+ "Эпоха 4/5 | Время: 342.62с | Train Loss: 0.3730 | Val Loss: 0.4596 | Val Acc: 0.7928 | Val F1: 0.8056\n",
221
+ " -> Модель сохранена, новый лучший Val F1: 0.8056\n",
222
+ "Эпоха 5/5 | Время: 340.21с | Train Loss: 0.3127 | Val Loss: 0.4469 | Val Acc: 0.7996 | Val F1: 0.8124\n",
223
+ " -> Модель сохранена, новый лучший Val F1: 0.8124\n",
224
+ "--- Обучение модели CustomMamba завершено ---\n",
225
+ "--- Оценка лучшей модели CustomMamba на тестовых данных ---\n",
226
+ "Результаты для CustomMamba: {'loss': 0.44949763529239944, 'accuracy': 0.8062, 'precision': 0.778874269005848, 'recall': 0.8541082164328657, 'f1_score': 0.8147581724335691}\n",
227
+ "------------------------------------------------------------\n",
228
+ "--- Начало обучения модели: Lib_LSTM на устройстве cuda ---\n",
229
+ "Эпоха 1/5 | Время: 5.09с | Train Loss: 0.6930 | Val Loss: 0.6922 | Val Acc: 0.5170 | Val F1: 0.4221\n",
230
+ " -> Модель сохранена, новый лучший Val F1: 0.4221\n",
231
+ "Эпоха 2/5 | Время: 5.03с | Train Loss: 0.6911 | Val Loss: 0.6899 | Val Acc: 0.5324 | Val F1: 0.4880\n",
232
+ " -> Модель сохранена, новый лучший Val F1: 0.4880\n",
233
+ "Эпоха 3/5 | Время: 5.03с | Train Loss: 0.6864 | Val Loss: 0.6837 | Val Acc: 0.5530 | Val F1: 0.5605\n",
234
+ " -> Модель сохранена, новый лучший Val F1: 0.5605\n",
235
+ "Эпоха 4/5 | Время: 5.03с | Train Loss: 0.6740 | Val Loss: 0.6589 | Val Acc: 0.6096 | Val F1: 0.6208\n",
236
+ " -> Модель сохранена, новый лучший Val F1: 0.6208\n",
237
+ "Эпоха 5/5 | Время: 5.04с | Train Loss: 0.6489 | Val Loss: 0.6501 | Val Acc: 0.6498 | Val F1: 0.6460\n",
238
+ " -> Модель сохранена, новый лучший Val F1: 0.6460\n",
239
+ "--- Обучение модели Lib_LSTM завершено ---\n",
240
+ "--- Оценка лучшей модели Lib_LSTM на тестовых данных ---\n",
241
+ "Результаты для Lib_LSTM: {'loss': 0.6330309821541902, 'accuracy': 0.6644, 'precision': 0.6724356268467708, 'recall': 0.6384769539078157, 'f1_score': 0.655016447368421}\n",
242
+ "------------------------------------------------------------\n",
243
+ "--- Начало обучения модели: Lib_Transformer на устройстве cuda ---\n",
244
+ "Эпоха 1/5 | Время: 4.28с | Train Loss: 0.6712 | Val Loss: 0.6773 | Val Acc: 0.5292 | Val F1: 0.1729\n",
245
+ " -> Модель сохранена, новый лучший Val F1: 0.1729\n",
246
+ "Эпоха 2/5 | Время: 4.14с | Train Loss: 0.5753 | Val Loss: 0.5631 | Val Acc: 0.7308 | Val F1: 0.7701\n",
247
+ " -> Модель сохранена, новый лучший Val F1: 0.7701\n",
248
+ "Эпоха 3/5 | Время: 4.17с | Train Loss: 0.4836 | Val Loss: 0.5106 | Val Acc: 0.7622 | Val F1: 0.7830\n",
249
+ " -> Модель сохранена, новый лучший Val F1: 0.7830\n",
250
+ "Эпоха 4/5 | Время: 4.16с | Train Loss: 0.4399 | Val Loss: 0.4880 | Val Acc: 0.7814 | Val F1: 0.7763\n",
251
+ "Эпоха 5/5 | Время: 4.13с | Train Loss: 0.4014 | Val Loss: 0.4611 | Val Acc: 0.7946 | Val F1: 0.8078\n",
252
+ " -> Модель сохранена, новый лучший Val F1: 0.8078\n",
253
+ "--- Обучение модели Lib_Transformer завершено ---\n",
254
+ "--- Оценка лучшей модели Lib_Transformer на тестовых данных ---\n",
255
+ "Результаты для Lib_Transformer: {'loss': 0.4671077333438169, 'accuracy': 0.7938, 'precision': 0.7661818181818182, 'recall': 0.8444889779559118, 'f1_score': 0.8034318398474738}\n",
256
+ "------------------------------------------------------------\n",
257
+ "\n",
258
+ "\n",
259
+ "--- Итоговая таблица сравнения моделей на тестовых данных ---\n",
260
+ " loss accuracy precision recall f1_score\n",
261
+ "CustomMamba 0.449498 0.8062 0.778874 0.854108 0.814758\n",
262
+ "Lib_LSTM 0.633031 0.6644 0.672436 0.638477 0.655016\n",
263
+ "Lib_Transformer 0.467108 0.7938 0.766182 0.844489 0.803432\n"
264
+ ]
265
+ }
266
+ ],
267
+ "source": [
268
+ "model_configs = {\n",
269
+ " \"CustomMamba\": {\n",
270
+ " \"class\": CustomMambaClassifier,\n",
271
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8, \n",
272
+ " 'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
273
+ " },\n",
274
+ "\n",
275
+ " \"Lib_LSTM\": {\n",
276
+ " \"class\": LSTMClassifier,\n",
277
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 128, \n",
278
+ " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'dropout': 0.5},\n",
279
+ " },\n",
280
+ " \"Lib_Transformer\": {\n",
281
+ " \"class\": TransformerClassifier,\n",
282
+ " \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, \n",
283
+ " 'num_layers': 2, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
284
+ " },\n",
285
+ "}\n",
286
+ "\n",
287
+ "results = {}\n",
288
+ "for model_name, config in model_configs.items():\n",
289
+ "\n",
290
+ " model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
291
+ " \n",
292
+ " model = config['class'](**config['params']).to(DEVICE)\n",
293
+ " optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
294
+ " criterion = nn.CrossEntropyLoss()\n",
295
+ " \n",
296
+ " train_and_evaluate(\n",
297
+ " model=model, train_loader=train_loader, val_loader=val_loader,\n",
298
+ " optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,\n",
299
+ " device=DEVICE, model_name=model_name, save_path=model_path\n",
300
+ " )\n",
301
+ " \n",
302
+ " print(f\"--- Оценка лучшей модели {model_name} на тестовых данных ---\")\n",
303
+ " if os.path.exists(model_path):\n",
304
+ " best_model = config['class'](**config['params']).to(DEVICE)\n",
305
+ " best_model.load_state_dict(torch.load(model_path))\n",
306
+ " test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)\n",
307
+ " results[model_name] = test_metrics\n",
308
+ " print(f\"Результаты для {model_name}: {test_metrics}\")\n",
309
+ " else:\n",
310
+ " print(f\"Файл лучшей модели для {model_name} не найден. Пропускаем оценку.\")\n",
311
+ "\n",
312
+ " print(\"-\" * 60)\n",
313
+ " \n",
314
+ "if results:\n",
315
+ " results_df = pd.DataFrame(results).T\n",
316
+ " print(\"\\n\\n--- Итоговая таблица сравнения моделей на тестовых данных ---\")\n",
317
+ " print(results_df.to_string())\n",
318
+ "else:\n",
319
+ " print(\"Не удалось получить результаты ни для одной модели.\")\n"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "markdown",
324
+ "id": "404db766",
325
+ "metadata": {},
326
+ "source": [
327
+ "По результатам видно, что LSTM и Transformer обучаются быстро, но Mamba обучается хорошо. Дальнейшие шаги следующие \n",
328
+ " - Пробуем сравнить Transformer и Mamba более детально, играем с гиперпараметрами\n",
329
+ " - LSTM проигрывает Transformer и по времени, и по качеству, поэтому в следующий этап сравнения не пойдет\n",
330
+ " \n",
331
+ "Цель следующего иследования: найти идеальный баланс между временем и качеством. Поставим больше эпох, меньший lr для обоих моделей, увеличим датасет (в текущем сетапе было 10'000 сэмплов на трейн и по 5'000 на валидацию/тест)"
332
+ ]
333
+ }
334
+ ],
335
+ "metadata": {
336
+ "kernelspec": {
337
+ "display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12",
338
+ "language": "python",
339
+ "name": "python3"
340
+ },
341
+ "language_info": {
342
+ "codemirror_mode": {
343
+ "name": "ipython",
344
+ "version": 3
345
+ },
346
+ "file_extension": ".py",
347
+ "mimetype": "text/x-python",
348
+ "name": "python",
349
+ "nbconvert_exporter": "python",
350
+ "pygments_lexer": "ipython3",
351
+ "version": "3.12.5"
352
+ }
353
+ },
354
+ "nbformat": 4,
355
+ "nbformat_minor": 5
356
+ }
notebooks/train.ipynb ADDED
@@ -0,0 +1,219 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import warnings\n",
10
+ "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
11
+ "\n",
12
+ "import os\n",
13
+ "import time\n",
14
+ "import json\n",
15
+ "import torch\n",
16
+ "import torch.nn as nn\n",
17
+ "import torch.optim as optim\n",
18
+ "\n",
19
+ "from torch.utils.data import DataLoader, TensorDataset\n",
20
+ "\n",
21
+ "# Импортируем классы моделей из нашего файла\n",
22
+ "from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier\n"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "MODEL_TO_TRAIN = 'Transformer' \n",
32
+ "\n",
33
+ "# Гиперпараметры данных и модели\n",
34
+ "MAX_SEQ_LEN = 300\n",
35
+ "EMBEDDING_DIM = 128\n",
36
+ "BATCH_SIZE = 32\n",
37
+ "LEARNING_RATE = 1e-4\n",
38
+ "NUM_EPOCHS = 5 # Увеличим для лучшего результата\n",
39
+ "\n",
40
+ "# Пути для сохранения артефактов\n",
41
+ "SAVE_DIR = \"../pretrained\"\n",
42
+ "os.makedirs(SAVE_DIR, exist_ok=True)\n",
43
+ "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n",
44
+ "VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n",
45
+ "CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n",
46
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "from src.data_utils.dataset_generator import DatasetGenerator\n",
56
+ "from src.data_utils.dataset_params import DatasetName\n",
57
+ "\n",
58
+ "generator = DatasetGenerator(DatasetName.IMDB)\n",
59
+ "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
60
+ "X_train, y_train, X_val, y_val, X_test, y_test = X_train[:1000], y_train[:1000], X_val[:100], y_val[:100], X_test[:100], y_test[:100]"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 4,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "def create_dataloader(X, y, batch_size):\n",
70
+ " dataset = TensorDataset(torch.tensor(X, dtype=torch.long), torch.tensor(y, dtype=torch.long))\n",
71
+ " return DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
72
+ "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
73
+ "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "model_params = {}\n",
83
+ "if MODEL_TO_TRAIN == 'Transformer':\n",
84
+ " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, 'num_layers': 2, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
85
+ " model = TransformerClassifier(**model_params)\n",
86
+ "elif MODEL_TO_TRAIN == 'Mamba':\n",
87
+ " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'mamba_d_state': 16, 'mamba_d_conv': 4, 'mamba_expand': 2, 'num_classes': 2}\n",
88
+ " model = MambaClassifier(**model_params)\n",
89
+ "elif MODEL_TO_TRAIN == 'LSTM':\n",
90
+ " model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 256, 'num_layers': 2, 'num_classes': 2, 'dropout': 0.5}\n",
91
+ " model = LSTMClassifier(**model_params)\n",
92
+ "else:\n",
93
+ " raise ValueError(\"Неизвестный тип модели. Выберите 'Transformer', 'Mamba' или 'LSTM'\")"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 6,
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "name": "stdout",
103
+ "output_type": "stream",
104
+ "text": [
105
+ "--- Начало обучения модели: Transformer ---\n",
106
+ "Эпоха 1/5 | Время: 17.06с | Train Loss: 0.7023 | Val Loss: 0.7095 | Val Acc: 0.4000\n",
107
+ " -> Модель сохранена, новая лучшая Val Loss: 0.7095\n",
108
+ "Эпоха 2/5 | Время: 16.40с | Train Loss: 0.6682 | Val Loss: 0.6937 | Val Acc: 0.4800\n",
109
+ " -> Модель сохранена, новая лучшая Val Loss: 0.6937\n",
110
+ "Эпоха 3/5 | Время: 16.13с | Train Loss: 0.6471 | Val Loss: 0.7075 | Val Acc: 0.4100\n",
111
+ "Эпоха 4/5 | Время: 16.36с | Train Loss: 0.6283 | Val Loss: 0.6917 | Val Acc: 0.5300\n",
112
+ " -> Модель сохранена, новая лучшая Val Loss: 0.6917\n",
113
+ "Эпоха 5/5 | Время: 16.39с | Train Loss: 0.6050 | Val Loss: 0.6871 | Val Acc: 0.5300\n",
114
+ " -> Модель сохранена, новая лучшая Val Loss: 0.6871\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "model.to(DEVICE)\n",
120
+ "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
121
+ "criterion = nn.CrossEntropyLoss()\n",
122
+ "\n",
123
+ "best_val_loss = float('inf')\n",
124
+ "print(f\"--- Начало обучения модели: {MODEL_TO_TRAIN} ---\")\n",
125
+ "for epoch in range(NUM_EPOCHS):\n",
126
+ " model.train()\n",
127
+ " start_time = time.time()\n",
128
+ " total_train_loss = 0\n",
129
+ "\n",
130
+ " for batch_X, batch_y in train_loader:\n",
131
+ " batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
132
+ " optimizer.zero_grad()\n",
133
+ " outputs = model(batch_X)\n",
134
+ " loss = criterion(outputs, batch_y)\n",
135
+ " loss.backward()\n",
136
+ " optimizer.step()\n",
137
+ " total_train_loss += loss.item()\n",
138
+ " avg_train_loss = total_train_loss / len(train_loader)\n",
139
+ " \n",
140
+ " model.eval()\n",
141
+ " total_val_loss, correct_val, total_val = 0, 0, 0\n",
142
+ " with torch.no_grad():\n",
143
+ " for batch_X, batch_y in val_loader:\n",
144
+ " batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
145
+ " outputs = model(batch_X)\n",
146
+ " loss = criterion(outputs, batch_y)\n",
147
+ " total_val_loss += loss.item()\n",
148
+ " _, predicted = torch.max(outputs.data, 1)\n",
149
+ " total_val += batch_y.size(0)\n",
150
+ " correct_val += (predicted == batch_y).sum().item()\n",
151
+ " avg_val_loss = total_val_loss / len(val_loader)\n",
152
+ " val_accuracy = correct_val / total_val\n",
153
+ "\n",
154
+ " epoch_time = time.time() - start_time\n",
155
+ " print(f\"Эпоха {epoch+1}/{NUM_EPOCHS} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}\")\n",
156
+ " \n",
157
+ " if avg_val_loss < best_val_loss:\n",
158
+ " best_val_loss = avg_val_loss\n",
159
+ " torch.save(model.state_dict(), MODEL_SAVE_PATH)\n",
160
+ " print(f\" -> Модель сохранена, новая лучшая Val Loss: {best_val_loss:.4f}\")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 7,
166
+ "metadata": {},
167
+ "outputs": [
168
+ {
169
+ "name": "stdout",
170
+ "output_type": "stream",
171
+ "text": [
172
+ "Конфигурация модели сохранена в: ../pretrained/config.json\n"
173
+ ]
174
+ }
175
+ ],
176
+ "source": [
177
+ "with open(VOCAB_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
178
+ " json.dump(generator.vocab, f, ensure_ascii=False, indent=4)\n",
179
+ "\n",
180
+ "config = {\n",
181
+ " \"model_type\": MODEL_TO_TRAIN,\n",
182
+ " \"max_seq_len\": MAX_SEQ_LEN,\n",
183
+ " \"model_params\": model_params,\n",
184
+ "}\n",
185
+ "with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
186
+ " json.dump(config, f, ensure_ascii=False, indent=4)\n",
187
+ "print(f\"Конфигурация модели сохранена в: {CONFIG_SAVE_PATH}\")\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": []
196
+ }
197
+ ],
198
+ "metadata": {
199
+ "kernelspec": {
200
+ "display_name": "monkey-coding-dl-project-OWiM8ypK-py3.12",
201
+ "language": "python",
202
+ "name": "python3"
203
+ },
204
+ "language_info": {
205
+ "codemirror_mode": {
206
+ "name": "ipython",
207
+ "version": 3
208
+ },
209
+ "file_extension": ".py",
210
+ "mimetype": "text/x-python",
211
+ "name": "python",
212
+ "nbconvert_exporter": "python",
213
+ "pygments_lexer": "ipython3",
214
+ "version": "3.12.3"
215
+ }
216
+ },
217
+ "nbformat": 4,
218
+ "nbformat_minor": 2
219
+ }
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pretrained/best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7636cf4c7205b64df4b91b2a23620d443de468a91211e074760e64adb24751ba
+ size 37445685
pretrained/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "model_type": "Transformer",
+     "max_seq_len": 300,
+     "model_params": {
+         "vocab_size": 69715,
+         "embed_dim": 128,
+         "num_heads": 4,
+         "num_layers": 2,
+         "num_classes": 2,
+         "max_seq_len": 300
+     }
+ }
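A minimal sketch of how these artifacts could be reassembled for inference (this mirrors what src/app/app.py and src/app/model_utils/factory.py below do; paths are assumed relative to the repo root):

import json
import torch

from src.models.models import TransformerClassifier

with open("pretrained/config.json") as f:
    cfg = json.load(f)

# config.json stores the constructor kwargs, so the classifier can be rebuilt directly.
model = TransformerClassifier(**cfg["model_params"])
model.load_state_dict(torch.load("pretrained/best_model.pth", map_location="cpu"))
model.eval()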
pretrained/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,32 @@
+ [tool.poetry]
+ name = "monkey-coding-dl-project"
+ version = "0.1.0"
+ description = "HSE DL spring 2025 project"
+ authors = [
+     "Michael Litvinov",
+     "Kamil Gabidullin",
+ ]
+ readme = "README.md"
+ packages = [
+     { include = "src" },
+ ]
+
+ [tool.poetry.dependencies]
+ python = "^3.12"
+ datasets = "3.6.0"
+ matplotlib = "3.10.3"
+ nltk = "3.9.1"
+ numpy = "2.3.0"
+ pandas = "2.3.0"
+ scikit-learn = "1.7.0"
+ seaborn = "0.13.2"
+ torch = "2.7.1"
+ transformers = "4.52.4"
+ jupyter = "^1.1.1"
+ ipykernel = "^6.29.5"
+ gradio = "^5.33.2"
+
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,123 @@
1
+ argcomplete==3.1.4
2
+ attrs==23.2.0
3
+ Automat==22.10.0
4
+ Babel==2.10.3
5
+ bcc==0.29.1
6
+ bcrypt==3.2.2
7
+ blinker==1.7.0
8
+ boto3==1.34.46
9
+ botocore==1.34.46
10
+ build==1.0.3
11
+ CacheControl==0.14.0
12
+ certifi==2023.11.17
13
+ cffi==1.17.1
14
+ chardet==5.2.0
15
+ cleo==2.1.0
16
+ click==8.1.6
17
+ cloud-init==24.4.1
18
+ colorama==0.4.6
19
+ command-not-found==0.3
20
+ configobj==5.0.8
21
+ constantly==23.10.4
22
+ cpplint==2.0.0
23
+ crashtest==0.4.1
24
+ cryptography==41.0.7
25
+ dbus-python==1.3.2
26
+ distlib==0.3.9
27
+ distro==1.9.0
28
+ distro-info==1.7+build1
29
+ docopt==0.6.2
30
+ dulwich==0.21.6
31
+ fastimport==0.9.14
32
+ fastjsonschema==2.19.0
33
+ filelock==3.17.0
34
+ gyp==0.1
35
+ h11==0.16.0
36
+ httplib2==0.20.4
37
+ hyperlink==21.0.0
38
+ idna==3.6
39
+ importlib-metadata==4.12.0
40
+ incremental==22.10.0
41
+ iniconfig==2.0.0
42
+ installer==0.7.0
43
+ jaraco.classes==3.2.1
44
+ jeepney==0.8.0
45
+ Jinja2==3.1.2
46
+ jmespath==1.0.1
47
+ jsonpatch==1.32
48
+ jsonpointer==2.0
49
+ jsonschema==4.10.3
50
+ keyring==24.3.1
51
+ launchpadlib==1.11.0
52
+ lazr.restfulclient==0.14.6
53
+ lazr.uri==1.0.6
54
+ lockfile==0.12.2
55
+ markdown-it-py==3.0.0
56
+ MarkupSafe==2.1.5
57
+ mdurl==0.1.2
58
+ more-itertools==10.2.0
59
+ msgpack==1.0.3
60
+ netaddr==0.8.0
61
+ netifaces==0.11.0
62
+ numpy==2.2.2
63
+ oauthlib==3.2.2
64
+ packaging==24.0
65
+ pexpect==4.9.0
66
+ pipenv==2024.4.1
67
+ pipx==1.4.3
68
+ pkginfo==1.9.6
69
+ platformdirs==4.3.6
70
+ pluggy==1.5.0
71
+ poetry-core==1.9.0
72
+ psutil==5.9.8
73
+ ptyprocess==0.7.0
74
+ pyasn1==0.4.8
75
+ pyasn1-modules==0.2.8
76
+ pycparser==2.22
77
+ Pygments==2.17.2
78
+ PyGObject==3.48.2
79
+ PyHamcrest==2.1.0
80
+ PyJWT==2.7.0
81
+ pylev==1.4.0
82
+ pyOpenSSL==23.2.0
83
+ pyparsing==3.1.1
84
+ pyproject_hooks==1.0.0
85
+ pyrsistent==0.20.0
86
+ pyserial==3.5
87
+ pytest==8.3.4
88
+ python-apt==2.7.7+ubuntu4
89
+ python-dateutil==2.8.2
90
+ python-debian==0.1.49+ubuntu2
91
+ python-magic==0.4.27
92
+ pytz==2024.1
93
+ PyYAML==6.0.1
94
+ requests==2.31.0
95
+ requests-toolbelt==1.0.0
96
+ rich==13.7.1
97
+ s3transfer==0.10.1
98
+ SecretStorage==3.3.3
99
+ service-identity==24.1.0
100
+ setuptools==68.1.2
101
+ shellingham==1.5.4
102
+ six==1.16.0
103
+ sniffio==1.3.1
104
+ sos==4.8.2
105
+ ssh-import-id==5.11
106
+ systemd-python==235
107
+ toml==0.10.2
108
+ tomlkit==0.12.4
109
+ trove-classifiers==2024.1.31
110
+ Twisted==24.3.0
111
+ typing_extensions==4.14.0
112
+ ubuntu-drivers-common==0.0.0
113
+ ubuntu-pro-client==8001
114
+ ufw==0.36.2
115
+ unattended-upgrades==0.1
116
+ urllib3==2.0.7
117
+ userpath==1.9.1
118
+ virtualenv==20.29.2
119
+ wadllib==1.3.6
120
+ wheel==0.42.0
121
+ xkit==0.0.0
122
+ zipp==1.0.0
123
+ zope.interface==6.1
src/app/__init__.py ADDED
File without changes
src/app/app.py ADDED
@@ -0,0 +1,103 @@
1
+ import gradio as gr
2
+ import json
3
+ import torch
4
+
5
+ from typing import Optional
6
+
7
+ from src.app.config import AppConfig
8
+ from src.data_utils.config import TextProcessorConfig
9
+ from src.data_utils.text_processor import TextProcessor
10
+
11
+
12
+ class App:
13
+ def __init__(self, config: AppConfig):
14
+ self.config = config
15
+ self.model: Optional[torch.nn.Module] = None
16
+ self.text_processor: Optional[TextProcessor] = None
17
+
18
+ self._load_model()
19
+ self._load_text_processor()
20
+
21
+
22
+ def _load_model(self):
23
+ """
24
+ Load model with params from config
25
+ """
26
+
27
+ with open(self.config.config_path, 'r') as f:
28
+ config = json.load(f)
29
+
30
+ model_type = config['model_type']
31
+ model_classes = {
32
+ 'Transformer': 'TransformerClassifier',
33
+ 'LSTM': 'LSTMClassifier',
34
+ 'Mamba': 'MambaClassifier'
35
+ }
36
+
37
+ if model_type not in model_classes:
38
+ raise ValueError(f"Unknown model type: {model_type}")
39
+
40
+ module = __import__(f'src.models.models', fromlist=[model_classes[model_type]])
41
+ model_class = getattr(module, model_classes[model_type])
42
+
43
+ self.model = model_class(**config['model_params'])
44
+ self.model.load_state_dict(torch.load(self.config.model_path))
45
+ self.model.eval()
46
+
47
+
48
+ def _load_text_processor(self):
49
+ with open(self.config.vocab_path, 'r') as f:
50
+ vocab = json.load(f)
51
+
52
+ processor_config = TextProcessorConfig(
53
+ max_seq_len=self.config.max_seq_len,
54
+ lowercase=True,
55
+ remove_punct=False
56
+ )
57
+
58
+ self.text_processor = TextProcessor(
59
+ vocab=vocab,
60
+ config=processor_config
61
+ )
62
+
63
+
64
+ def predict(self, text: str) -> dict:
65
+ """
66
+ Evaluating the tone of the text
67
+ """
68
+
69
+ if not text.strip():
70
+ return {"Negative": 0.5, "Positive": 0.5}
71
+
72
+ input_tensor = self.text_processor.text_to_tensor(text).unsqueeze(0)
73
+
74
+ with torch.no_grad():
75
+ output = self.model(input_tensor)
76
+ proba = torch.softmax(output, dim=1)[0].tolist()
77
+
78
+ return {"Negative": proba[0], "Positive": proba[1]}
79
+
80
+
81
+ def launch(self):
82
+ """
83
+ Launch interface
84
+ """
85
+
86
+ interface = gr.Interface(
87
+ fn=self.predict,
88
+ inputs=gr.Textbox(label="Enter your text"),
89
+ outputs=gr.Label(label="Result"),
90
+ title="Evaluating the tone of the text",
91
+ examples=["Very good! Incredible! So fantastic",
92
+ "The worst thing in the world!"]
93
+ )
94
+
95
+ if self.config.local:
96
+ interface.launch(
97
+ server_name=self.config.host,
98
+ server_port=self.config.port
99
+ )
100
+ else:
101
+ interface.launch(
102
+ share=True
103
+ )
src/app/config.py ADDED
@@ -0,0 +1,39 @@
1
+ import yaml
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass
7
+ class AppConfig:
8
+ model_path: str
9
+ vocab_path: str
10
+ config_path: str
11
+ max_seq_len: int = 300
12
+ local: bool = True
13
+ host: str = "0.0.0.0"
14
+ port: int = 7860
15
+
16
+ @classmethod
17
+ def from_yaml(cls, config_path: str) -> 'AppConfig':
18
+ """
19
+ AppConfig from path string
20
+
21
+ Args:
22
+ config_path: path string
23
+
24
+ Returns:
25
+ AppConfig object
26
+ """
27
+
28
+ with open(config_path, 'r') as f:
29
+ config_data = yaml.safe_load(f)
30
+
31
+ return cls(
32
+ model_path=config_data['model_path'],
33
+ vocab_path=config_data['vocab_path'],
34
+ config_path=config_data['config_path'],
35
+ max_seq_len=int(config_data['max_seq_len']),
36
+ local=config_data.get('server', {}).get('local', True),
37
+ host=config_data.get('server', {}).get('host', "0.0.0.0"),
38
+ port=config_data.get('server', {}).get('port', 7860)
39
+ )
src/app/model_utils/factory.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Optional
5
+
6
+ from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
7
+
8
+
9
+ class ModelFactory:
10
+ """
11
+ Factory class for creating and loading models
12
+ """
13
+
14
+ @staticmethod
15
+ def create_model(
16
+ model_type: str,
17
+ model_params: Dict[str, Any],
18
+ state_dict_path: Optional[Path] = None
19
+ ) -> torch.nn.Module:
20
+ """
21
+ Create and load a model from configuration
22
+
23
+ Args:
24
+ model_type: Type of model ('Transformer', 'Mamba', 'LSTM')
25
+ model_params: Dictionary of model parameters
26
+ state_dict_path: Path to saved state dictionary
27
+
28
+ Returns:
29
+ Initialized PyTorch model
30
+
31
+ Raises:
32
+ ValueError: If model_type is unknown
33
+ """
34
+
35
+ model_classes = {
36
+ "Transformer": TransformerClassifier,
37
+ "Mamba": MambaClassifier,
38
+ "LSTM": LSTMClassifier
39
+ }
40
+
41
+ if model_type not in model_classes:
42
+ raise ValueError(f"Unknown model type: {model_type}")
43
+
44
+ model = model_classes[model_type](**model_params)
45
+
46
+ if state_dict_path:
47
+ state_dict = torch.load(state_dict_path, map_location="cpu")
48
+ model.load_state_dict(state_dict)
49
+
50
+ model.eval()
51
+ return model
src/app/model_utils/manager.py ADDED
@@ -0,0 +1,72 @@
1
+ import json
2
+ import torch
3
+
4
+ from pathlib import Path
5
+ from typing import Dict, Any
6
+
7
+ from src.app.model_utils.factory import ModelFactory
8
+
9
+
10
+ class ModelManager:
11
+ """
12
+ Manages model loading and inference operations
13
+
14
+ Args:
15
+ model_dir: Directory containing model artifacts
16
+ """
17
+
18
+ def __init__(self, model_dir: str = "../pretrained") -> None:
19
+ self.model_dir = Path(model_dir)
20
+ self.loaded_models: Dict[str, Any] = {}
21
+ self._load_model_artifacts()
22
+
23
+
24
+ def _load_model_artifacts(self) -> None:
25
+ """
26
+ Load model configuration and vocabulary
27
+ """
28
+
29
+ with open(self.model_dir / "config.json", "r") as f:
30
+ self.config = json.load(f)
31
+
32
+ with open(self.model_dir / "vocab.json", "r") as f:
33
+ self.vocab = json.load(f)
34
+
35
+ self.idx_to_label = {0: "Negative", 1: "Positive"}
36
+
37
+
38
+ def get_model(self) -> torch.nn.Module:
39
+ """
40
+ Get the loaded model (cached for performance)
41
+
42
+ Returns:
43
+ Loaded PyTorch model in evaluation mode
44
+ """
45
+
46
+ model_type = self.config["model_type"]
47
+
48
+ if model_type not in self.loaded_models:
49
+ model = ModelFactory.create_model(
50
+ model_type=model_type,
51
+ model_params=self.config["model_params"],
52
+ state_dict_path=self.model_dir / "best_model.pth"
53
+ )
54
+ self.loaded_models[model_type] = model
55
+
56
+ return self.loaded_models[model_type]
57
+
58
+
59
+ def get_vocab(self) -> Dict[str, int]:
60
+ """
61
+ Get vocabulary mapping
62
+ """
63
+
64
+ return self.vocab
65
+
66
+
67
+ def get_config(self) -> Dict[str, Any]:
68
+ """
69
+ Get model configuration
70
+ """
71
+
72
+ return self.config
src/data_utils/__init__.py ADDED
File without changes
src/data_utils/config.py ADDED
@@ -0,0 +1,58 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class DatasetConfig:
6
+ """
7
+ Configuration class for dataset generation parameters
8
+
9
+ Attributes:
10
+ embedding_dim: Dimension for embedding layer output
11
+ train_size: Number of samples in training set
12
+ val_size: Number of samples in validation set
13
+ test_size: Number of samples in test set
14
+ random_state: Random seed for reproducibility
15
+ min_word_freq: Minimum word frequency to include in vocabulary
16
+ load_from_disk: Load the dataset from a local dir. If False, download it from Hugging Face
17
+ path_to_data: Path to local dataset data
18
+ max_seq_len: Maximum sequence length (will be padded/truncated to this)
19
+ lowercase: Whether to convert text to lowercase
20
+ remove_punct: Whether to remove punctuation
21
+ pad_token: Padding token
22
+ unk_token: Unknown token
23
+ """
24
+
25
+ embedding_dim: int = 64
26
+ train_size: int = 10000
27
+ val_size: int = 5000
28
+ test_size: int = 5000
29
+ random_state: int = 42
30
+ min_word_freq: int = 1
31
+ load_from_disk: bool = False
32
+ path_to_data: str = "./datasets"
33
+
34
+ max_seq_len: int = 300
35
+ lowercase: bool = True
36
+ remove_punct: bool = False
37
+ pad_token: str = "<PAD>"
38
+ unk_token: str = "<UNK>"
39
+
40
+
41
+ @dataclass
42
+ class TextProcessorConfig:
43
+ """
44
+ Configuration class for text processor parameters (params should match the dataset config)
45
+
46
+ Attributes:
47
+ max_seq_len: Maximum sequence length (will be padded/truncated to this)
48
+ lowercase: Whether to convert text to lowercase
49
+ remove_punct: Whether to remove punctuation
50
+ pad_token: Padding token
51
+ unk_token: Unknown token
52
+ """
53
+
54
+ max_seq_len: int = 300
55
+ lowercase: bool = True
56
+ remove_punct: bool = False
57
+ pad_token: str = "<PAD>"
58
+ unk_token: str = "<UNK>"
src/data_utils/dataset_generator.py ADDED
@@ -0,0 +1,177 @@
1
+ from collections import Counter
2
+ from typing import Dict, Tuple, List
3
+
4
+ import pandas as pd
5
+ import torch
6
+
7
+ from datasets import load_dataset, load_from_disk
8
+ from sklearn.model_selection import train_test_split
9
+
10
+ import src.data_utils.dataset_params as dataset_params
11
+
12
+ from src.data_utils.config import DatasetConfig, TextProcessorConfig
13
+ from src.data_utils.text_processor import TextProcessor
14
+
15
+
16
+ class DatasetGenerator:
17
+ """
18
+ Main dataset generator class
19
+
20
+ Provides methods to load, build vocabulary, convert text datasets
21
+ into tensor format suitable for deep learning models.
22
+
23
+ Args:
24
+ dataset_name: Name of dataset from DatasetName enum
25
+ config: Configuration object with preprocessing parameters
26
+ device: Torch device to place tensors on (cpu/cuda)
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ dataset_name: dataset_params.DatasetName,
32
+ config: DatasetConfig = DatasetConfig(),
33
+ device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ ):
35
+ self.dataset_params = dataset_params.get_dataset_params_by_name(dataset_name=dataset_name)
36
+ self.config = config
37
+ self.device = device
38
+ self.text_processor = TextProcessor(
39
+ vocab=None,
40
+ config=TextProcessorConfig(
41
+ max_seq_len=self.config.max_seq_len,
42
+ lowercase=self.config.lowercase,
43
+ remove_punct=self.config.remove_punct,
44
+ pad_token=self.config.pad_token,
45
+ unk_token=self.config.unk_token,
46
+ )
47
+ )
48
+ self.vocab = None
49
+ self.id2word = None
50
+ self.embedding_layer = None
51
+
52
+
53
+ def load_raw_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
54
+ """
55
+ Load raw dataset from source
56
+
57
+ Returns:
58
+ Tuple of (train_df, val_df, test_df) DataFrames
59
+ """
60
+ if self.config.load_from_disk:
61
+ dataset = load_from_disk(f"{self.config.path_to_data}/{self.dataset_params.local_path}")
62
+ else:
63
+ dataset = load_dataset(self.dataset_params.hugging_face_name)
64
+ train_df = pd.DataFrame(dataset["train"])
65
+ test_df = pd.DataFrame(dataset["test"])
66
+ val_df, test_df = train_test_split(
67
+ test_df,
68
+ test_size=0.5,
69
+ random_state=self.config.random_state,
70
+ stratify=test_df[self.dataset_params.label_col_name]
71
+ )
72
+
73
+ # Sample configured sizes
74
+ train_df = train_df.sample(n=self.config.train_size, random_state=self.config.random_state)
75
+ val_df = val_df.sample(n=self.config.val_size, random_state=self.config.random_state)
76
+ test_df = test_df.sample(n=self.config.test_size, random_state=self.config.random_state)
77
+
78
+ return train_df, val_df, test_df
79
+
80
+
81
+ def build_vocabulary(self, tokenized_texts: List[List[str]]) -> Tuple[Dict[str, int], Dict[int, str]]:
82
+ """
83
+ Build vocabulary from tokenized texts
84
+
85
+ Args:
86
+ tokenized_texts: List of tokenized texts
87
+
88
+ Returns:
89
+ Tuple of (word_to_id, id_to_word) mappings
90
+ """
91
+
92
+ all_tokens = [token for tokens in tokenized_texts for token in tokens]
93
+ word_counts = Counter(all_tokens)
94
+
95
+ filtered_words = [word for word, count in word_counts.items()
96
+ if count >= self.config.min_word_freq]
97
+
98
+ word_to_id = {self.config.pad_token: 0, self.config.unk_token: 1}
99
+ id_to_word = {0: self.config.pad_token, 1: self.config.unk_token}
100
+
101
+ for idx, word in enumerate(filtered_words, start=2):
102
+ word_to_id[word] = idx
103
+ id_to_word[idx] = word
104
+
105
+ self.text_processor.vocab = word_to_id
106
+
107
+ return word_to_id, id_to_word
108
+
109
+
110
+ def generate_dataset(self) -> Tuple[
111
+ Tuple[torch.Tensor, torch.Tensor],
112
+ Tuple[torch.Tensor, torch.Tensor],
113
+ Tuple[torch.Tensor, torch.Tensor]
114
+ ]:
115
+ """
116
+ Main method to generate the full dataset
117
+
118
+ Returns:
119
+ Tuple containing:
120
+ - (train_features, train_labels)
121
+ - (val_features, val_labels)
122
+ - (test_features, test_labels)
123
+ - embedding_layer
124
+ """
125
+
126
+ train_df, val_df, test_df = self.load_raw_data()
127
+
128
+ train_texts = train_df[self.dataset_params.content_col_name].tolist()
129
+ train_tokens = [self.text_processor.preprocess_text(text) for text in train_texts]
130
+
131
+ self.vocab, self.id2word = self.build_vocabulary(train_tokens)
132
+
133
+ X_train = torch.stack([self.text_processor.text_to_tensor(text) for text in train_texts])
134
+
135
+ val_texts = val_df[self.dataset_params.content_col_name].tolist()
136
+ X_val = torch.stack([self.text_processor.text_to_tensor(text) for text in val_texts])
137
+
138
+ test_texts = test_df[self.dataset_params.content_col_name].tolist()
139
+ X_test = torch.stack([self.text_processor.text_to_tensor(text) for text in test_texts])
140
+
141
+ y_train = torch.tensor(train_df[self.dataset_params.label_col_name].values, dtype=torch.long)
142
+ y_val = torch.tensor(val_df[self.dataset_params.label_col_name].values, dtype=torch.long)
143
+ y_test = torch.tensor(test_df[self.dataset_params.label_col_name].values, dtype=torch.long)
144
+
145
+ return (X_train, y_train), (X_val, y_val), (X_test, y_test)
146
+
147
+
148
+ def get_vocabulary(self) -> Tuple[Dict[str, int], Dict[int, str]]:
149
+ """
150
+ Get vocabulary mappings
151
+
152
+ Returns:
153
+ Tuple of (word_to_id, id_to_word) dictionaries
154
+ """
155
+
156
+ return self.vocab, self.id2word
157
+
158
+
159
+ def get_config(self) -> DatasetConfig:
160
+ """
161
+ Get current configuration
162
+
163
+ Returns:
164
+ DatasetConfig object
165
+ """
166
+
167
+ return self.config
168
+
169
+
170
+ def get_text_processor(self) -> TextProcessor:
171
+ """
172
+ Get the text processor for inference usage
173
+
174
+ Returns:
175
+ TextProcessor object
176
+ """
177
+ return self.text_processor
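A usage sketch matching how the notebooks call the generator (it assumes the datasets were saved to ./datasets by datasets/load_datasets.ipynb above):

from src.data_utils.config import DatasetConfig
from src.data_utils.dataset_generator import DatasetGenerator
from src.data_utils.dataset_params import DatasetName

config = DatasetConfig(load_from_disk=True, path_to_data="./datasets")
generator = DatasetGenerator(DatasetName.IMDB, config=config)

# Tensors come back already tokenized, indexed and padded to max_seq_len.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()
print(X_train.shape)          # (train_size, max_seq_len), i.e. (10000, 300) with the defaults
print(len(generator.vocab))   # vocabulary built from the training split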
src/data_utils/dataset_params.py ADDED
@@ -0,0 +1,53 @@
+ import enum
+
+
+ class DatasetName(enum.Enum):
+     """
+     Supported dataset names enumeration
+     """
+
+     IMDB = "imdb"
+     POLARITY = "polarity"
+
+
+ class DatasetParams:
+     """
+     Abstract class for a dataset
+     """
+
+     hugging_face_name = ""
+     content_col_name = ""
+     label_col_name = ""
+     local_path = ""
+
+
+
+ def get_dataset_params_by_name(dataset_name: DatasetName) -> DatasetParams:
+     if dataset_name == DatasetName.IMDB:
+         return ImbdParams()
+     if dataset_name == DatasetName.POLARITY:
+         return PolarityParams()
+
+     raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+
+ class ImbdParams(DatasetParams):
+     """
+     IMDB dataset params class
+     """
+
+     hugging_face_name = "stanfordnlp/imdb"
+     content_col_name = "text"
+     label_col_name = "label"
+     local_path = "imdb"
+
+
+ class PolarityParams(DatasetParams):
+     """
+     POLARITY dataset params class
+     """
+
+     hugging_face_name = "fancyzhx/amazon_polarity"
+     content_col_name = "content"
+     label_col_name = "label"
+     local_path = "polarity"
src/data_utils/text_processor.py ADDED
@@ -0,0 +1,75 @@
1
+ from typing import List, Dict
2
+
3
+ import nltk
4
+ import torch
5
+
6
+ from nltk.tokenize import word_tokenize
7
+
8
+ from src.data_utils.config import TextProcessorConfig
9
+
10
+
11
+ class TextProcessor:
12
+ """
13
+ Main text preprocessor class
14
+
15
+ Args:
16
+ vocab: Vocabulary dictionary
17
+ config: Configuration object
18
+ """
19
+
20
+ def __init__(self, vocab: Dict[str, int], config: TextProcessorConfig):
21
+ self.vocab = vocab
22
+ self.config = config
23
+ self._ensure_nltk_downloaded()
24
+
25
+
26
+ def _ensure_nltk_downloaded(self):
27
+ try:
28
+ word_tokenize("test")
29
+ except LookupError:
30
+ nltk.download("punkt")
31
+
32
+
33
+ def preprocess_text(self, text: str) -> List[str]:
34
+ """
35
+ Tokenize and preprocess single text string
36
+
37
+ Args:
38
+ text: Your text
39
+
40
+ Returns:
41
+ List of preprocessed tokens
42
+ """
43
+
44
+ if self.config.lowercase:
45
+ text = text.lower()
46
+
47
+ tokens = word_tokenize(text)
48
+
49
+ if self.config.remove_punct:
50
+ tokens = [t for t in tokens if t.isalpha()]
51
+
52
+ return tokens
53
+
54
+
55
+ def text_to_tensor(self, text: str) -> torch.Tensor:
56
+ """
57
+ Convert raw text to tensor
58
+
59
+ Args:
60
+ text: Your text
61
+
62
+ Returns:
63
+ Tensor of your text
64
+ """
65
+
66
+ tokens = self.preprocess_text(text)
67
+ ids = [self.vocab.get(token, self.vocab[self.config.unk_token]) for token in tokens]
68
+
69
+ # Pad or truncate
70
+ if len(ids) < self.config.max_seq_len:
71
+ ids = ids + [self.vocab[self.config.pad_token]] * (self.config.max_seq_len - len(ids))
72
+ else:
73
+ ids = ids[:self.config.max_seq_len]
74
+
75
+ return torch.tensor(ids, dtype=torch.long)
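A small usage sketch with a toy vocabulary (hypothetical values; the real vocab is built by DatasetGenerator or loaded from pretrained/vocab.json):

from src.data_utils.config import TextProcessorConfig
from src.data_utils.text_processor import TextProcessor

vocab = {"<PAD>": 0, "<UNK>": 1, "good": 2, "movie": 3}
processor = TextProcessor(vocab=vocab, config=TextProcessorConfig(max_seq_len=8))

tensor = processor.text_to_tensor("A good movie!")
print(tensor)        # unknown tokens map to 1, the tail is padded with 0
print(tensor.shape)  # torch.Size([8])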
src/models/__init__.py ADDED
File without changes
src/models/models.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class TransformerClassifier(nn.Module):
6
+ def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, max_seq_len):
7
+ super(TransformerClassifier, self).__init__()
8
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
9
+ self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
10
+ encoder_layers = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True, dim_feedforward=embed_dim*4)
11
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
12
+ self.fc = nn.Linear(embed_dim, num_classes)
13
+
14
+ def forward(self, x):
15
+ padding_mask = (x == 0)
16
+ x = self.embedding(x) + self.pos_encoder
17
+ x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
18
+ x = x.mean(dim=1)
19
+ x = self.fc(x)
20
+ return x
21
+
22
+ class LSTMClassifier(nn.Module):
23
+ def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes, dropout):
24
+ super(LSTMClassifier, self).__init__()
25
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
26
+ self.lstm = nn.LSTM(
27
+ input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_layers,
28
+ batch_first=True, bidirectional=True, dropout=dropout if num_layers > 1 else 0
29
+ )
30
+ self.fc = nn.Linear(hidden_dim * 2, num_classes)
31
+ self.dropout = nn.Dropout(dropout)
32
+
33
+ def forward(self, x):
34
+ embedded = self.dropout(self.embedding(x))
35
+ _, (hidden, cell) = self.lstm(embedded)
36
+ hidden_cat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
37
+ output = self.fc(self.dropout(hidden_cat))
38
+ return output
39
+
40
+
41
+ class SimpleMambaBlock(nn.Module):
42
+ """
43
+ Logic: projection -> 1D convolution -> activation -> selective SSM -> output projection
44
+ """
45
+ def __init__(self, d_model, d_state, d_conv, expand=2):
46
+ super().__init__()
47
+ self.d_model = d_model
48
+ self.d_state = d_state
49
+ self.d_conv = d_conv
50
+ self.expand = expand
51
+
52
+ d_inner = int(self.expand * self.d_model)
53
+
54
+ self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
55
+
56
+ self.conv1d = nn.Conv1d(
57
+ in_channels=d_inner, out_channels=d_inner,
58
+ kernel_size=d_conv, padding=d_conv - 1,
59
+ groups=d_inner, bias=True
60
+ )
61
+
62
+ self.x_proj = nn.Linear(d_inner, self.d_state + self.d_state + 1, bias=False)
63
+ self.dt_proj = nn.Linear(1, d_inner, bias=True)
64
+
65
+ A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
66
+ self.A_log = nn.Parameter(torch.log(A))
67
+ self.D = nn.Parameter(torch.ones(d_inner))
68
+
69
+ self.out_proj = nn.Linear(d_inner, d_model, bias=False)
70
+
71
+ def forward(self, x):
72
+ B, L, D = x.shape
73
+
74
+ xz = self.in_proj(x)
75
+ x, z = xz.chunk(2, dim=-1)
76
+
77
+ x = x.transpose(1, 2)
78
+ x = self.conv1d(x)[:, :, :L]
79
+ x = x.transpose(1, 2)
80
+
81
+ x = F.silu(x)
82
+
83
+ y = self.ssm(x)
84
+
85
+ y = y * F.silu(z)
86
+ y = self.out_proj(y)
87
+
88
+ return y
89
+
90
+ def ssm(self, x):
91
+ batch_size, seq_len, d_inner = x.shape
92
+
93
+ A = -torch.exp(self.A_log.float())
94
+ D = self.D.float()
95
+
96
+ x_dbl = self.x_proj(x)
97
+ delta, B_param, C_param = torch.split(x_dbl, [1, self.d_state, self.d_state], dim=-1)
98
+
99
+ delta = F.softplus(self.dt_proj(delta))
100
+
101
+ h = torch.zeros(batch_size, d_inner, self.d_state, device=x.device)
102
+ ys = []
103
+
104
+ for i in range(seq_len):
105
+ delta_i = delta[:, i, :]
106
+ A_i = torch.exp(delta_i.unsqueeze(-1) * A)
107
+ B_i = delta_i.unsqueeze(-1) * B_param[:, i, :].unsqueeze(1)
108
+
109
+ h = A_i * h + B_i * x[:, i, :].unsqueeze(-1)
110
+
111
+ y_i = (h @ C_param[:, i, :].unsqueeze(-1)).squeeze(-1)
112
+ ys.append(y_i)
113
+
114
+ y = torch.stack(ys, dim=1)
115
+ y = y + x * D
116
+
117
+ return y
118
+
119
+
120
+ class CustomMambaClassifier(nn.Module):
121
+ def __init__(self, vocab_size, d_model, d_state, d_conv, num_layers, num_classes):
122
+ super().__init__()
123
+ self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
124
+
125
+ self.layers = nn.ModuleList(
126
+ [SimpleMambaBlock(d_model, d_state, d_conv) for _ in range(num_layers)]
127
+ )
128
+
129
+ self.fc = nn.Linear(d_model, num_classes)
130
+
131
+ def forward(self, x):
132
+ x = self.embedding(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+
136
+ pooled_output = x.mean(dim=1)
137
+ return self.fc(pooled_output)
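A quick shape-check sketch for the custom Mamba classifier above (toy sizes, CPU only):

import torch

from src.models.models import CustomMambaClassifier

model = CustomMambaClassifier(vocab_size=100, d_model=32, d_state=8,
                              d_conv=4, num_layers=2, num_classes=2)
x = torch.randint(0, 100, (4, 20))   # batch of 4 token-id sequences of length 20
logits = model(x)
print(logits.shape)                  # torch.Size([4, 2])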
src/models/predict.py ADDED
@@ -0,0 +1,83 @@
1
+ import json
2
+ import torch
3
+ from nltk.tokenize import word_tokenize
4
+ import argparse
5
+
6
+ from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
7
+
8
+ SAVE_DIR = "pretrained"
9
+ MODEL_PATH = f"{SAVE_DIR}/best_model.pth"
10
+ CONFIG_PATH = f"{SAVE_DIR}/config.json"
11
+ VOCAB_PATH = f"{SAVE_DIR}/vocab.json"
12
+
13
+ ID_TO_LABEL = {0: "Negative", 1: "Positive"}
14
+
15
+ def load_artifacts():
16
+ with open(CONFIG_PATH, 'r') as f:
17
+ config = json.load(f)
18
+ with open(VOCAB_PATH, 'r') as f:
19
+ vocab = json.load(f)
20
+
21
+ model_type = config['model_type']
22
+ model_params = config['model_params']
23
+
24
+ if model_type == 'Transformer':
25
+ model = TransformerClassifier(**model_params)
26
+ elif model_type == 'Mamba':
27
+ model = MambaClassifier(**model_params)
28
+ elif model_type == 'LSTM':
29
+ model = LSTMClassifier(**model_params)
30
+ else:
31
+ raise ValueError("Unknown model type in the config file.")
32
+
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
35
+ model.to(device)
36
+ model.eval()
37
+
38
+ return model, vocab, config, device
39
+
40
+ def preprocess_text(text, vocab, max_len):
41
+ tokens = word_tokenize(text.lower())
42
+ ids = [vocab.get(token, vocab['<UNK>']) for token in tokens]
43
+ if len(ids) < max_len:
44
+ ids.extend([vocab['<PAD>']] * (max_len - len(ids)))
45
+ else:
46
+ ids = ids[:max_len]
47
+ return torch.tensor(ids).unsqueeze(0)
48
+
49
+ def predict(text, model, vocab, config, device):
50
+ input_tensor = preprocess_text(text, vocab, config['max_seq_len'])
51
+ input_tensor = input_tensor.to(device)
52
+
53
+ with torch.no_grad():
54
+ outputs = model(input_tensor)
55
+ probabilities = torch.softmax(outputs, dim=1)
56
+ prediction_id = torch.argmax(probabilities, dim=1).item()
57
+
58
+ predicted_label = ID_TO_LABEL[prediction_id]
59
+ confidence = probabilities[0][prediction_id].item()
60
+
61
+ return predicted_label, confidence
62
+
63
+ if __name__ == "__main__":
64
+ parser = argparse.ArgumentParser(description="Predict the sentiment of a text with the trained model.")
65
+ parser.add_argument("text", type=str, help="Text to analyze (in quotes).")
66
+ args = parser.parse_args()
67
+
68
+ print("Loading the model and artifacts...")
69
+ try:
70
+ loaded_model, loaded_vocab, loaded_config, device = load_artifacts()
71
+ print(f"Model '{loaded_config['model_type']}' loaded successfully on device {device}.")
72
+ except FileNotFoundError:
73
+ print("\nERROR: model files not found!")
74
+ print("Run the train.py script first to train and save the model.")
75
+ exit()
76
+
77
+ label, conf = predict(args.text, loaded_model, loaded_vocab, loaded_config, device)
78
+
79
+ print("\n--- Prediction result ---")
80
+ print(f"Text: '{args.text}'")
81
+ print(f"Sentiment: {label}")
82
+ print(f"Confidence: {conf:.2%}")
83
+