Synced repo using 'sync_with_huggingface' Github Action
Browse files- config.yaml +8 -0
- datasets/load_datasets.ipynb +111 -0
- main.py +16 -0
- notebooks/datasets_stats.ipynb +0 -0
- notebooks/models.ipynb +356 -0
- notebooks/train.ipynb +219 -0
- poetry.lock +0 -0
- pretrained/best_model.pth +3 -0
- pretrained/config.json +12 -0
- pretrained/vocab.json +0 -0
- pyproject.toml +32 -0
- requirements.txt +123 -0
- src/app/__init__.py +0 -0
- src/app/app.py +103 -0
- src/app/config.py +39 -0
- src/app/model_utils/factory.py +51 -0
- src/app/model_utils/manager.py +72 -0
- src/data_utils/__init__.py +0 -0
- src/data_utils/config.py +58 -0
- src/data_utils/dataset_generator.py +177 -0
- src/data_utils/dataset_params.py +53 -0
- src/data_utils/text_processor.py +75 -0
- src/models/__init__.py +0 -0
- src/models/models.py +137 -0
- src/models/predict.py +83 -0
config.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_path: "./pretrained/best_model.pth"
|
2 |
+
vocab_path: "./pretrained/vocab.json"
|
3 |
+
config_path: "./pretrained/config.json"
|
4 |
+
max_seq_len: 300
|
5 |
+
server:
|
6 |
+
local: true
|
7 |
+
host: "0.0.0.0"
|
8 |
+
port: 7860
|
datasets/load_datasets.ipynb
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"data": {
|
10 |
+
"application/vnd.jupyter.widget-view+json": {
|
11 |
+
"model_id": "461b224aa295437b8ddd80ccb5b5e683",
|
12 |
+
"version_major": 2,
|
13 |
+
"version_minor": 0
|
14 |
+
},
|
15 |
+
"text/plain": [
|
16 |
+
"Saving the dataset (0/4 shards): 0%| | 0/3600000 [00:00<?, ? examples/s]"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"metadata": {},
|
20 |
+
"output_type": "display_data"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"data": {
|
24 |
+
"application/vnd.jupyter.widget-view+json": {
|
25 |
+
"model_id": "1628189e590d446588f57357d7e7a035",
|
26 |
+
"version_major": 2,
|
27 |
+
"version_minor": 0
|
28 |
+
},
|
29 |
+
"text/plain": [
|
30 |
+
"Saving the dataset (0/1 shards): 0%| | 0/400000 [00:00<?, ? examples/s]"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
"metadata": {},
|
34 |
+
"output_type": "display_data"
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"data": {
|
38 |
+
"application/vnd.jupyter.widget-view+json": {
|
39 |
+
"model_id": "a728f34f72f64abaa31627af611e7ad3",
|
40 |
+
"version_major": 2,
|
41 |
+
"version_minor": 0
|
42 |
+
},
|
43 |
+
"text/plain": [
|
44 |
+
"Saving the dataset (0/1 shards): 0%| | 0/25000 [00:00<?, ? examples/s]"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
"metadata": {},
|
48 |
+
"output_type": "display_data"
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"data": {
|
52 |
+
"application/vnd.jupyter.widget-view+json": {
|
53 |
+
"model_id": "cdef5994359e4c2696bc3bb18b1086d3",
|
54 |
+
"version_major": 2,
|
55 |
+
"version_minor": 0
|
56 |
+
},
|
57 |
+
"text/plain": [
|
58 |
+
"Saving the dataset (0/1 shards): 0%| | 0/25000 [00:00<?, ? examples/s]"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
"metadata": {},
|
62 |
+
"output_type": "display_data"
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"data": {
|
66 |
+
"application/vnd.jupyter.widget-view+json": {
|
67 |
+
"model_id": "aaf29ef86d3c4d4fb54effb14129b039",
|
68 |
+
"version_major": 2,
|
69 |
+
"version_minor": 0
|
70 |
+
},
|
71 |
+
"text/plain": [
|
72 |
+
"Saving the dataset (0/1 shards): 0%| | 0/50000 [00:00<?, ? examples/s]"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
"metadata": {},
|
76 |
+
"output_type": "display_data"
|
77 |
+
}
|
78 |
+
],
|
79 |
+
"source": [
|
80 |
+
"from datasets import load_dataset\n",
|
81 |
+
"\n",
|
82 |
+
"dataset_polarity = load_dataset(\"fancyzhx/amazon_polarity\")\n",
|
83 |
+
"dataset_polarity.save_to_disk(\"polarity\")\n",
|
84 |
+
"\n",
|
85 |
+
"dataset_imdb = load_dataset(\"stanfordnlp/imdb\")\n",
|
86 |
+
"dataset_imdb.save_to_disk(\"imdb\")"
|
87 |
+
]
|
88 |
+
}
|
89 |
+
],
|
90 |
+
"metadata": {
|
91 |
+
"kernelspec": {
|
92 |
+
"display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12",
|
93 |
+
"language": "python",
|
94 |
+
"name": "python3"
|
95 |
+
},
|
96 |
+
"language_info": {
|
97 |
+
"codemirror_mode": {
|
98 |
+
"name": "ipython",
|
99 |
+
"version": 3
|
100 |
+
},
|
101 |
+
"file_extension": ".py",
|
102 |
+
"mimetype": "text/x-python",
|
103 |
+
"name": "python",
|
104 |
+
"nbconvert_exporter": "python",
|
105 |
+
"pygments_lexer": "ipython3",
|
106 |
+
"version": "3.12.5"
|
107 |
+
}
|
108 |
+
},
|
109 |
+
"nbformat": 4,
|
110 |
+
"nbformat_minor": 2
|
111 |
+
}
|
main.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)
|
3 |
+
|
4 |
+
from src.app.app import App
|
5 |
+
from src.app.config import AppConfig
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
config = AppConfig.from_yaml("config.yaml")
|
10 |
+
|
11 |
+
app = App(config)
|
12 |
+
app.launch()
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
main()
|
notebooks/datasets_stats.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/models.ipynb
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "9bfb61e1",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Сравниваем модели и сохраняем в `pretrained_comparison`"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "f0574ac3",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"- Импорты\n",
|
17 |
+
"- Константы\n",
|
18 |
+
"- Считывание датасетов"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 1,
|
24 |
+
"id": "5a237c5c",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"import os\n",
|
29 |
+
"import time\n",
|
30 |
+
"import torch\n",
|
31 |
+
"import warnings\n",
|
32 |
+
"import numpy as np\n",
|
33 |
+
"import pandas as pd\n",
|
34 |
+
"import torch.nn as nn\n",
|
35 |
+
"import torch.optim as optim\n",
|
36 |
+
"from torch.utils.data import DataLoader, TensorDataset\n",
|
37 |
+
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
|
38 |
+
"for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
|
39 |
+
"\n",
|
40 |
+
"from src.data_utils.config import DatasetConfig\n",
|
41 |
+
"from src.data_utils.dataset_params import DatasetName\n",
|
42 |
+
"from src.data_utils.dataset_generator import DatasetGenerator\n",
|
43 |
+
"from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier\n",
|
44 |
+
"\n",
|
45 |
+
"MAX_SEQ_LEN = 300\n",
|
46 |
+
"EMBEDDING_DIM = 128\n",
|
47 |
+
"BATCH_SIZE = 32\n",
|
48 |
+
"LEARNING_RATE = 1e-4\n",
|
49 |
+
"NUM_EPOCHS = 5 # для быстрого сравнения моделей\n",
|
50 |
+
"NUM_CLASSES = 2\n",
|
51 |
+
"\n",
|
52 |
+
"SAVE_DIR = \"../pretrained_comparison\"\n",
|
53 |
+
"os.makedirs(SAVE_DIR, exist_ok=True)\n",
|
54 |
+
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
55 |
+
"\n",
|
56 |
+
"config = DatasetConfig(\n",
|
57 |
+
" load_from_disk=True,\n",
|
58 |
+
" path_to_data=\"../datasets\"\n",
|
59 |
+
")\n",
|
60 |
+
"\n",
|
61 |
+
"generator = DatasetGenerator(DatasetName.IMDB, config=config)\n",
|
62 |
+
"(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
|
63 |
+
"VOCAB_SIZE = len(generator.vocab)"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "markdown",
|
68 |
+
"id": "5b95192d",
|
69 |
+
"metadata": {},
|
70 |
+
"source": [
|
71 |
+
"Вспомогательные функции для трейна/валидации/теста"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 2,
|
77 |
+
"id": "b2a4534c",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
+
"\n",
|
82 |
+
"def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):\n",
|
83 |
+
" best_val_f1 = 0.0\n",
|
84 |
+
" history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}\n",
|
85 |
+
" \n",
|
86 |
+
" print(f\"--- Начало обучения модели: {model_name} на устройстве {device} ---\")\n",
|
87 |
+
"\n",
|
88 |
+
" for epoch in range(num_epochs):\n",
|
89 |
+
" model.train()\n",
|
90 |
+
" start_time = time.time()\n",
|
91 |
+
" total_train_loss = 0\n",
|
92 |
+
"\n",
|
93 |
+
" for batch_X, batch_y in train_loader:\n",
|
94 |
+
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
95 |
+
" optimizer.zero_grad()\n",
|
96 |
+
" outputs = model(batch_X)\n",
|
97 |
+
" loss = criterion(outputs, batch_y)\n",
|
98 |
+
" loss.backward()\n",
|
99 |
+
" optimizer.step()\n",
|
100 |
+
" total_train_loss += loss.item()\n",
|
101 |
+
" \n",
|
102 |
+
" avg_train_loss = total_train_loss / len(train_loader)\n",
|
103 |
+
" history['train_loss'].append(avg_train_loss)\n",
|
104 |
+
"\n",
|
105 |
+
" model.eval()\n",
|
106 |
+
" total_val_loss = 0\n",
|
107 |
+
" all_preds = []\n",
|
108 |
+
" all_labels = []\n",
|
109 |
+
"\n",
|
110 |
+
" with torch.no_grad():\n",
|
111 |
+
" for batch_X, batch_y in val_loader:\n",
|
112 |
+
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
113 |
+
" outputs = model(batch_X)\n",
|
114 |
+
" loss = criterion(outputs, batch_y)\n",
|
115 |
+
" total_val_loss += loss.item()\n",
|
116 |
+
" \n",
|
117 |
+
" _, predicted = torch.max(outputs.data, 1)\n",
|
118 |
+
" all_preds.extend(predicted.cpu().numpy())\n",
|
119 |
+
" all_labels.extend(batch_y.cpu().numpy())\n",
|
120 |
+
" \n",
|
121 |
+
" avg_val_loss = total_val_loss / len(val_loader)\n",
|
122 |
+
" \n",
|
123 |
+
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
124 |
+
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
125 |
+
" \n",
|
126 |
+
" history['val_loss'].append(avg_val_loss)\n",
|
127 |
+
" history['val_accuracy'].append(accuracy)\n",
|
128 |
+
" history['val_f1'].append(f1)\n",
|
129 |
+
"\n",
|
130 |
+
" epoch_time = time.time() - start_time\n",
|
131 |
+
" print(f\"Эпоха {epoch+1}/{num_epochs} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | \"\n",
|
132 |
+
" f\"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}\")\n",
|
133 |
+
"\n",
|
134 |
+
" if f1 > best_val_f1:\n",
|
135 |
+
" best_val_f1 = f1\n",
|
136 |
+
" torch.save(model.state_dict(), save_path)\n",
|
137 |
+
" print(f\" -> Модель сохранена, новый лучший Val F1: {best_val_f1:.4f}\")\n",
|
138 |
+
" \n",
|
139 |
+
" print(f\"--- Обучение модели {model_name} завершено ---\")\n",
|
140 |
+
" return history\n",
|
141 |
+
"\n",
|
142 |
+
"def evaluate_on_test(model, test_loader, device, criterion):\n",
|
143 |
+
" model.eval()\n",
|
144 |
+
" total_test_loss = 0\n",
|
145 |
+
" all_preds = []\n",
|
146 |
+
" all_labels = []\n",
|
147 |
+
"\n",
|
148 |
+
" with torch.no_grad():\n",
|
149 |
+
" for batch_X, batch_y in test_loader:\n",
|
150 |
+
" batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n",
|
151 |
+
" outputs = model(batch_X)\n",
|
152 |
+
" loss = criterion(outputs, batch_y)\n",
|
153 |
+
" total_test_loss += loss.item()\n",
|
154 |
+
" \n",
|
155 |
+
" _, predicted = torch.max(outputs.data, 1)\n",
|
156 |
+
" all_preds.extend(predicted.cpu().numpy())\n",
|
157 |
+
" all_labels.extend(batch_y.cpu().numpy())\n",
|
158 |
+
" \n",
|
159 |
+
" avg_test_loss = total_test_loss / len(test_loader)\n",
|
160 |
+
" \n",
|
161 |
+
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
162 |
+
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
163 |
+
" \n",
|
164 |
+
" return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "markdown",
|
169 |
+
"id": "1be50523",
|
170 |
+
"metadata": {},
|
171 |
+
"source": [
|
172 |
+
"Создание даталоадера"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 3,
|
178 |
+
"id": "cccc5bea",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"def create_dataloader(X, y, batch_size, shuffle=True):\n",
|
183 |
+
" X_tensor = torch.as_tensor(X, dtype=torch.long)\n",
|
184 |
+
" y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
|
185 |
+
" dataset = TensorDataset(X_tensor, y_tensor)\n",
|
186 |
+
" return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
|
187 |
+
"\n",
|
188 |
+
"train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
|
189 |
+
"val_loader = create_dataloader(X_val, y_val, BATCH_SIZE, shuffle=False)\n",
|
190 |
+
"test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "markdown",
|
195 |
+
"id": "4938b9f3",
|
196 |
+
"metadata": {},
|
197 |
+
"source": [
|
198 |
+
"Сравнения моделей\n",
|
199 |
+
"\n",
|
200 |
+
"Смотрим первые 5 эпох, чтобы выбрать лучшую модель, с которой будем играться дальше"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 4,
|
206 |
+
"id": "0244aafa",
|
207 |
+
"metadata": {},
|
208 |
+
"outputs": [
|
209 |
+
{
|
210 |
+
"name": "stdout",
|
211 |
+
"output_type": "stream",
|
212 |
+
"text": [
|
213 |
+
"--- Начало обучения модели: CustomMamba на устройстве cuda ---\n",
|
214 |
+
"Эпоха 1/5 | Время: 337.85с | Train Loss: 0.6768 | Val Loss: 0.6168 | Val Acc: 0.6592 | Val F1: 0.5937\n",
|
215 |
+
" -> Модель сохранена, новый лучший Val F1: 0.5937\n",
|
216 |
+
"Эпоха 2/5 | Время: 345.54с | Train Loss: 0.5266 | Val Loss: 0.4964 | Val Acc: 0.7580 | Val F1: 0.7552\n",
|
217 |
+
" -> Модель сохранена, новый лучший Val F1: 0.7552\n",
|
218 |
+
"Эпоха 3/5 | Время: 343.23с | Train Loss: 0.4329 | Val Loss: 0.4586 | Val Acc: 0.7812 | Val F1: 0.7830\n",
|
219 |
+
" -> Модель сохранена, новый лучший Val F1: 0.7830\n",
|
220 |
+
"Эпоха 4/5 | Время: 342.62с | Train Loss: 0.3730 | Val Loss: 0.4596 | Val Acc: 0.7928 | Val F1: 0.8056\n",
|
221 |
+
" -> Модель сохранена, новый лучший Val F1: 0.8056\n",
|
222 |
+
"Эпоха 5/5 | Время: 340.21с | Train Loss: 0.3127 | Val Loss: 0.4469 | Val Acc: 0.7996 | Val F1: 0.8124\n",
|
223 |
+
" -> Модель сохранена, новый лучший Val F1: 0.8124\n",
|
224 |
+
"--- Обучение модели CustomMamba завершено ---\n",
|
225 |
+
"--- Оценка лучшей модели CustomMamba на тестовых данных ---\n",
|
226 |
+
"Результаты для CustomMamba: {'loss': 0.44949763529239944, 'accuracy': 0.8062, 'precision': 0.778874269005848, 'recall': 0.8541082164328657, 'f1_score': 0.8147581724335691}\n",
|
227 |
+
"------------------------------------------------------------\n",
|
228 |
+
"--- Начало обучения модели: Lib_LSTM на устройстве cuda ---\n",
|
229 |
+
"Эпоха 1/5 | Время: 5.09с | Train Loss: 0.6930 | Val Loss: 0.6922 | Val Acc: 0.5170 | Val F1: 0.4221\n",
|
230 |
+
" -> Модель сохранена, новый лучший Val F1: 0.4221\n",
|
231 |
+
"Эпоха 2/5 | Время: 5.03с | Train Loss: 0.6911 | Val Loss: 0.6899 | Val Acc: 0.5324 | Val F1: 0.4880\n",
|
232 |
+
" -> Модель сохранена, новый лучший Val F1: 0.4880\n",
|
233 |
+
"Эпоха 3/5 | Время: 5.03с | Train Loss: 0.6864 | Val Loss: 0.6837 | Val Acc: 0.5530 | Val F1: 0.5605\n",
|
234 |
+
" -> Модель сохранена, новый лучший Val F1: 0.5605\n",
|
235 |
+
"Эпоха 4/5 | Время: 5.03с | Train Loss: 0.6740 | Val Loss: 0.6589 | Val Acc: 0.6096 | Val F1: 0.6208\n",
|
236 |
+
" -> Модель сохранена, новый лучший Val F1: 0.6208\n",
|
237 |
+
"Эпоха 5/5 | Время: 5.04с | Train Loss: 0.6489 | Val Loss: 0.6501 | Val Acc: 0.6498 | Val F1: 0.6460\n",
|
238 |
+
" -> Модель сохранена, новый лучший Val F1: 0.6460\n",
|
239 |
+
"--- Обучение модели Lib_LSTM завершено ---\n",
|
240 |
+
"--- Оценка лучшей модели Lib_LSTM на тестовых данных ---\n",
|
241 |
+
"Результаты для Lib_LSTM: {'loss': 0.6330309821541902, 'accuracy': 0.6644, 'precision': 0.6724356268467708, 'recall': 0.6384769539078157, 'f1_score': 0.655016447368421}\n",
|
242 |
+
"------------------------------------------------------------\n",
|
243 |
+
"--- Начало обучения модели: Lib_Transformer на устройстве cuda ---\n",
|
244 |
+
"Эпоха 1/5 | Время: 4.28с | Train Loss: 0.6712 | Val Loss: 0.6773 | Val Acc: 0.5292 | Val F1: 0.1729\n",
|
245 |
+
" -> Модель сохранена, новый лучший Val F1: 0.1729\n",
|
246 |
+
"Эпоха 2/5 | Время: 4.14с | Train Loss: 0.5753 | Val Loss: 0.5631 | Val Acc: 0.7308 | Val F1: 0.7701\n",
|
247 |
+
" -> Модель сохранена, новый лучший Val F1: 0.7701\n",
|
248 |
+
"Эпоха 3/5 | Время: 4.17с | Train Loss: 0.4836 | Val Loss: 0.5106 | Val Acc: 0.7622 | Val F1: 0.7830\n",
|
249 |
+
" -> Модель сохранена, новый лучший Val F1: 0.7830\n",
|
250 |
+
"Эпоха 4/5 | Время: 4.16с | Train Loss: 0.4399 | Val Loss: 0.4880 | Val Acc: 0.7814 | Val F1: 0.7763\n",
|
251 |
+
"Эпоха 5/5 | Время: 4.13с | Train Loss: 0.4014 | Val Loss: 0.4611 | Val Acc: 0.7946 | Val F1: 0.8078\n",
|
252 |
+
" -> Модель сохранена, новый лучший Val F1: 0.8078\n",
|
253 |
+
"--- Обучение модели Lib_Transformer завершено ---\n",
|
254 |
+
"--- Оценка лучшей модели Lib_Transformer на тестовых данных ---\n",
|
255 |
+
"Результаты для Lib_Transformer: {'loss': 0.4671077333438169, 'accuracy': 0.7938, 'precision': 0.7661818181818182, 'recall': 0.8444889779559118, 'f1_score': 0.8034318398474738}\n",
|
256 |
+
"------------------------------------------------------------\n",
|
257 |
+
"\n",
|
258 |
+
"\n",
|
259 |
+
"--- Итоговая таблица сравнения моделей на тестовых данных ---\n",
|
260 |
+
" loss accuracy precision recall f1_score\n",
|
261 |
+
"CustomMamba 0.449498 0.8062 0.778874 0.854108 0.814758\n",
|
262 |
+
"Lib_LSTM 0.633031 0.6644 0.672436 0.638477 0.655016\n",
|
263 |
+
"Lib_Transformer 0.467108 0.7938 0.766182 0.844489 0.803432\n"
|
264 |
+
]
|
265 |
+
}
|
266 |
+
],
|
267 |
+
"source": [
|
268 |
+
"model_configs = {\n",
|
269 |
+
" \"CustomMamba\": {\n",
|
270 |
+
" \"class\": CustomMambaClassifier,\n",
|
271 |
+
" \"params\": {'vocab_size': VOCAB_SIZE, 'd_model': EMBEDDING_DIM, 'd_state': 8, \n",
|
272 |
+
" 'd_conv': 4, 'num_layers': 2, 'num_classes': NUM_CLASSES},\n",
|
273 |
+
" },\n",
|
274 |
+
"\n",
|
275 |
+
" \"Lib_LSTM\": {\n",
|
276 |
+
" \"class\": LSTMClassifier,\n",
|
277 |
+
" \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 128, \n",
|
278 |
+
" 'num_layers': 2, 'num_classes': NUM_CLASSES, 'dropout': 0.5},\n",
|
279 |
+
" },\n",
|
280 |
+
" \"Lib_Transformer\": {\n",
|
281 |
+
" \"class\": TransformerClassifier,\n",
|
282 |
+
" \"params\": {'vocab_size': VOCAB_SIZE, 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, \n",
|
283 |
+
" 'num_layers': 2, 'num_classes': NUM_CLASSES, 'max_seq_len': MAX_SEQ_LEN},\n",
|
284 |
+
" },\n",
|
285 |
+
"}\n",
|
286 |
+
"\n",
|
287 |
+
"results = {}\n",
|
288 |
+
"for model_name, config in model_configs.items():\n",
|
289 |
+
"\n",
|
290 |
+
" model_path = os.path.join(SAVE_DIR, f\"best_model_{model_name.lower()}.pth\")\n",
|
291 |
+
" \n",
|
292 |
+
" model = config['class'](**config['params']).to(DEVICE)\n",
|
293 |
+
" optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
|
294 |
+
" criterion = nn.CrossEntropyLoss()\n",
|
295 |
+
" \n",
|
296 |
+
" train_and_evaluate(\n",
|
297 |
+
" model=model, train_loader=train_loader, val_loader=val_loader,\n",
|
298 |
+
" optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,\n",
|
299 |
+
" device=DEVICE, model_name=model_name, save_path=model_path\n",
|
300 |
+
" )\n",
|
301 |
+
" \n",
|
302 |
+
" print(f\"--- Оценка лучшей модели {model_name} на тестовых данных ---\")\n",
|
303 |
+
" if os.path.exists(model_path):\n",
|
304 |
+
" best_model = config['class'](**config['params']).to(DEVICE)\n",
|
305 |
+
" best_model.load_state_dict(torch.load(model_path))\n",
|
306 |
+
" test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)\n",
|
307 |
+
" results[model_name] = test_metrics\n",
|
308 |
+
" print(f\"Результаты для {model_name}: {test_metrics}\")\n",
|
309 |
+
" else:\n",
|
310 |
+
" print(f\"Файл лучшей модели для {model_name} не найден. Пропускаем оценку.\")\n",
|
311 |
+
"\n",
|
312 |
+
" print(\"-\" * 60)\n",
|
313 |
+
" \n",
|
314 |
+
"if results:\n",
|
315 |
+
" results_df = pd.DataFrame(results).T\n",
|
316 |
+
" print(\"\\n\\n--- Итоговая таблица сравнения моделей на тестовых данных ---\")\n",
|
317 |
+
" print(results_df.to_string())\n",
|
318 |
+
"else:\n",
|
319 |
+
" print(\"Не удалось получить результаты ни для одной модели.\")\n"
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"cell_type": "markdown",
|
324 |
+
"id": "404db766",
|
325 |
+
"metadata": {},
|
326 |
+
"source": [
|
327 |
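+
"По результатам видно, что LSTM и Transformer обучаются намного быстрее (≈4–5 с на эпоху против ≈340 с у Mamba), но Mamba показывает лучшее качество на тесте (F1 0.815 против 0.803 у Transformer и 0.655 у LSTM). Дальнейшие шаги следующие:\n",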
|
328 |
+
" - Пробуем сравнить Transformer и Mamba более детально, играем с гиперпараметрами\n",
|
329 |
+
" - LSTM проигрывает Transformer и по времени, и по качеству, поэтому в следующий этап сравнения не пойдет\n",
|
330 |
+
" \n",
|
331 |
+
"Цель следующего исследования: найти идеальный баланс между временем и качеством. Поставим больше эпох, меньший lr для обеих моделей, увеличим датасет (в текущем сетапе было 10'000 сэмплов на трейн и по 5'000 на валидацию/тест)"
|
332 |
+
]
|
333 |
+
}
|
334 |
+
],
|
335 |
+
"metadata": {
|
336 |
+
"kernelspec": {
|
337 |
+
"display_name": "monkey-coding-dl-project-rj23F0vJ-py3.12",
|
338 |
+
"language": "python",
|
339 |
+
"name": "python3"
|
340 |
+
},
|
341 |
+
"language_info": {
|
342 |
+
"codemirror_mode": {
|
343 |
+
"name": "ipython",
|
344 |
+
"version": 3
|
345 |
+
},
|
346 |
+
"file_extension": ".py",
|
347 |
+
"mimetype": "text/x-python",
|
348 |
+
"name": "python",
|
349 |
+
"nbconvert_exporter": "python",
|
350 |
+
"pygments_lexer": "ipython3",
|
351 |
+
"version": "3.12.5"
|
352 |
+
}
|
353 |
+
},
|
354 |
+
"nbformat": 4,
|
355 |
+
"nbformat_minor": 5
|
356 |
+
}
|
notebooks/train.ipynb
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import warnings\n",
|
10 |
+
"for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
|
11 |
+
"\n",
|
12 |
+
"import os\n",
|
13 |
+
"import time\n",
|
14 |
+
"import json\n",
|
15 |
+
"import torch\n",
|
16 |
+
"import torch.nn as nn\n",
|
17 |
+
"import torch.optim as optim\n",
|
18 |
+
"\n",
|
19 |
+
"from torch.utils.data import DataLoader, TensorDataset\n",
|
20 |
+
"\n",
|
21 |
+
"# Импортируем классы моделей из нашего файла\n",
|
22 |
+
"from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier\n"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": 2,
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"MODEL_TO_TRAIN = 'Transformer' \n",
|
32 |
+
"\n",
|
33 |
+
"# Гиперпараметры данных и модели\n",
|
34 |
+
"MAX_SEQ_LEN = 300\n",
|
35 |
+
"EMBEDDING_DIM = 128\n",
|
36 |
+
"BATCH_SIZE = 32\n",
|
37 |
+
"LEARNING_RATE = 1e-4\n",
|
38 |
+
"NUM_EPOCHS = 5 # Увеличим для лучшего результата\n",
|
39 |
+
"\n",
|
40 |
+
"# Пути для сохранения артефактов\n",
|
41 |
+
"SAVE_DIR = \"../pretrained\"\n",
|
42 |
+
"os.makedirs(SAVE_DIR, exist_ok=True)\n",
|
43 |
+
"MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n",
|
44 |
+
"VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n",
|
45 |
+
"CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n",
|
46 |
+
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 3,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"from src.data_utils.dataset_generator import DatasetGenerator\n",
|
56 |
+
"from src.data_utils.dataset_params import DatasetName\n",
|
57 |
+
"\n",
|
58 |
+
"generator = DatasetGenerator(DatasetName.IMDB)\n",
|
59 |
+
"(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
|
60 |
+
"X_train, y_train, X_val, y_val, X_test, y_test = X_train[:1000], y_train[:1000], X_val[:100], y_val[:100], X_test[:100], y_test[:100]"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 4,
|
66 |
+
"metadata": {},
|
67 |
+
"outputs": [],
|
68 |
+
"source": [
|
69 |
+
"def create_dataloader(X, y, batch_size):\n",
|
70 |
+
" dataset = TensorDataset(torch.tensor(X, dtype=torch.long), torch.tensor(y, dtype=torch.long))\n",
|
71 |
+
" return DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
|
72 |
+
"train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
|
73 |
+
"val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 5,
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [],
|
81 |
+
"source": [
|
82 |
+
"model_params = {}\n",
|
83 |
+
"if MODEL_TO_TRAIN == 'Transformer':\n",
|
84 |
+
" model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, 'num_layers': 2, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
|
85 |
+
" model = TransformerClassifier(**model_params)\n",
|
86 |
+
"elif MODEL_TO_TRAIN == 'Mamba':\n",
|
87 |
+
" model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'mamba_d_state': 16, 'mamba_d_conv': 4, 'mamba_expand': 2, 'num_classes': 2}\n",
|
88 |
+
" model = MambaClassifier(**model_params)\n",
|
89 |
+
"elif MODEL_TO_TRAIN == 'LSTM':\n",
|
90 |
+
" model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 256, 'num_layers': 2, 'num_classes': 2, 'dropout': 0.5}\n",
|
91 |
+
" model = LSTMClassifier(**model_params)\n",
|
92 |
+
"else:\n",
|
93 |
+
" raise ValueError(\"Неизвестный тип модели. Выберите 'Transformer', 'Mamba' или 'LSTM'\")"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "code",
|
98 |
+
"execution_count": 6,
|
99 |
+
"metadata": {},
|
100 |
+
"outputs": [
|
101 |
+
{
|
102 |
+
"name": "stdout",
|
103 |
+
"output_type": "stream",
|
104 |
+
"text": [
|
105 |
+
"--- Начало обучения модели: Transformer ---\n",
|
106 |
+
"Эпоха 1/5 | Время: 17.06с | Train Loss: 0.7023 | Val Loss: 0.7095 | Val Acc: 0.4000\n",
|
107 |
+
" -> Модель сохранена, новая лучшая Val Loss: 0.7095\n",
|
108 |
+
"Эпоха 2/5 | Время: 16.40с | Train Loss: 0.6682 | Val Loss: 0.6937 | Val Acc: 0.4800\n",
|
109 |
+
" -> Модель сохранена, новая лучшая Val Loss: 0.6937\n",
|
110 |
+
"Эпоха 3/5 | Время: 16.13с | Train Loss: 0.6471 | Val Loss: 0.7075 | Val Acc: 0.4100\n",
|
111 |
+
"Эпоха 4/5 | Время: 16.36с | Train Loss: 0.6283 | Val Loss: 0.6917 | Val Acc: 0.5300\n",
|
112 |
+
" -> Модель сохранена, новая лучшая Val Loss: 0.6917\n",
|
113 |
+
"Эпоха 5/5 | Время: 16.39с | Train Loss: 0.6050 | Val Loss: 0.6871 | Val Acc: 0.5300\n",
|
114 |
+
" -> Модель сохранена, новая лучшая Val Loss: 0.6871\n"
|
115 |
+
]
|
116 |
+
}
|
117 |
+
],
|
118 |
+
"source": [
|
119 |
+
"model.to(DEVICE)\n",
|
120 |
+
"optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
|
121 |
+
"criterion = nn.CrossEntropyLoss()\n",
|
122 |
+
"\n",
|
123 |
+
"best_val_loss = float('inf')\n",
|
124 |
+
"print(f\"--- Начало обучения модели: {MODEL_TO_TRAIN} ---\")\n",
|
125 |
+
"for epoch in range(NUM_EPOCHS):\n",
|
126 |
+
" model.train()\n",
|
127 |
+
" start_time = time.time()\n",
|
128 |
+
" total_train_loss = 0\n",
|
129 |
+
"\n",
|
130 |
+
" for batch_X, batch_y in train_loader:\n",
|
131 |
+
" batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
|
132 |
+
" optimizer.zero_grad()\n",
|
133 |
+
" outputs = model(batch_X)\n",
|
134 |
+
" loss = criterion(outputs, batch_y)\n",
|
135 |
+
" loss.backward()\n",
|
136 |
+
" optimizer.step()\n",
|
137 |
+
" total_train_loss += loss.item()\n",
|
138 |
+
" avg_train_loss = total_train_loss / len(train_loader)\n",
|
139 |
+
" \n",
|
140 |
+
" model.eval()\n",
|
141 |
+
" total_val_loss, correct_val, total_val = 0, 0, 0\n",
|
142 |
+
" with torch.no_grad():\n",
|
143 |
+
" for batch_X, batch_y in val_loader:\n",
|
144 |
+
" batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
|
145 |
+
" outputs = model(batch_X)\n",
|
146 |
+
" loss = criterion(outputs, batch_y)\n",
|
147 |
+
" total_val_loss += loss.item()\n",
|
148 |
+
" _, predicted = torch.max(outputs.data, 1)\n",
|
149 |
+
" total_val += batch_y.size(0)\n",
|
150 |
+
" correct_val += (predicted == batch_y).sum().item()\n",
|
151 |
+
" avg_val_loss = total_val_loss / len(val_loader)\n",
|
152 |
+
" val_accuracy = correct_val / total_val\n",
|
153 |
+
"\n",
|
154 |
+
" epoch_time = time.time() - start_time\n",
|
155 |
+
" print(f\"Эпоха {epoch+1}/{NUM_EPOCHS} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}\")\n",
|
156 |
+
" \n",
|
157 |
+
" if avg_val_loss < best_val_loss:\n",
|
158 |
+
" best_val_loss = avg_val_loss\n",
|
159 |
+
" torch.save(model.state_dict(), MODEL_SAVE_PATH)\n",
|
160 |
+
" print(f\" -> Модель сохранена, новая лучшая Val Loss: {best_val_loss:.4f}\")"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "code",
|
165 |
+
"execution_count": 7,
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [
|
168 |
+
{
|
169 |
+
"name": "stdout",
|
170 |
+
"output_type": "stream",
|
171 |
+
"text": [
|
172 |
+
"Конфигурация модели сохранена в: ../pretrained/config.json\n"
|
173 |
+
]
|
174 |
+
}
|
175 |
+
],
|
176 |
+
"source": [
|
177 |
+
"with open(VOCAB_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
|
178 |
+
" json.dump(generator.vocab, f, ensure_ascii=False, indent=4)\n",
|
179 |
+
"\n",
|
180 |
+
"config = {\n",
|
181 |
+
" \"model_type\": MODEL_TO_TRAIN,\n",
|
182 |
+
" \"max_seq_len\": MAX_SEQ_LEN,\n",
|
183 |
+
" \"model_params\": model_params,\n",
|
184 |
+
"}\n",
|
185 |
+
"with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
|
186 |
+
" json.dump(config, f, ensure_ascii=False, indent=4)\n",
|
187 |
+
"print(f\"Конфигурация модели сохранена в: {CONFIG_SAVE_PATH}\")\n"
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"cell_type": "code",
|
192 |
+
"execution_count": null,
|
193 |
+
"metadata": {},
|
194 |
+
"outputs": [],
|
195 |
+
"source": []
|
196 |
+
}
|
197 |
+
],
|
198 |
+
"metadata": {
|
199 |
+
"kernelspec": {
|
200 |
+
"display_name": "monkey-coding-dl-project-OWiM8ypK-py3.12",
|
201 |
+
"language": "python",
|
202 |
+
"name": "python3"
|
203 |
+
},
|
204 |
+
"language_info": {
|
205 |
+
"codemirror_mode": {
|
206 |
+
"name": "ipython",
|
207 |
+
"version": 3
|
208 |
+
},
|
209 |
+
"file_extension": ".py",
|
210 |
+
"mimetype": "text/x-python",
|
211 |
+
"name": "python",
|
212 |
+
"nbconvert_exporter": "python",
|
213 |
+
"pygments_lexer": "ipython3",
|
214 |
+
"version": "3.12.3"
|
215 |
+
}
|
216 |
+
},
|
217 |
+
"nbformat": 4,
|
218 |
+
"nbformat_minor": 2
|
219 |
+
}
|
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pretrained/best_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7636cf4c7205b64df4b91b2a23620d443de468a91211e074760e64adb24751ba
|
3 |
+
size 37445685
|
pretrained/config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "Transformer",
|
3 |
+
"max_seq_len": 300,
|
4 |
+
"model_params": {
|
5 |
+
"vocab_size": 69715,
|
6 |
+
"embed_dim": 128,
|
7 |
+
"num_heads": 4,
|
8 |
+
"num_layers": 2,
|
9 |
+
"num_classes": 2,
|
10 |
+
"max_seq_len": 300
|
11 |
+
}
|
12 |
+
}
|
pretrained/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "monkey-coding-dl-project"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "HSE DL spring 2025 project"
|
5 |
+
authors = [
|
6 |
+
"Michael Litvinov",
|
7 |
+
"Kamil Gabidullin",
|
8 |
+
]
|
9 |
+
readme = "README.md"
|
10 |
+
packages = [
|
11 |
+
{ include = "src" },
|
12 |
+
]
|
13 |
+
|
14 |
+
[tool.poetry.dependencies]
|
15 |
+
python = "^3.12"
|
16 |
+
datasets = "3.6.0"
|
17 |
+
matplotlib = "3.10.3"
|
18 |
+
nltk = "3.9.1"
|
19 |
+
numpy = "2.3.0"
|
20 |
+
pandas = "2.3.0"
|
21 |
+
scikit-learn = "1.7.0"
|
22 |
+
seaborn = "0.13.2"
|
23 |
+
torch = "2.7.1"
|
24 |
+
transformers = "4.52.4"
|
25 |
+
jupyter = "^1.1.1"
|
26 |
+
ipykernel = "^6.29.5"
|
27 |
+
gradio = "^5.33.2"
|
28 |
+
|
29 |
+
|
30 |
+
[build-system]
|
31 |
+
requires = ["poetry-core"]
|
32 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
argcomplete==3.1.4
|
2 |
+
attrs==23.2.0
|
3 |
+
Automat==22.10.0
|
4 |
+
Babel==2.10.3
|
5 |
+
bcc==0.29.1
|
6 |
+
bcrypt==3.2.2
|
7 |
+
blinker==1.7.0
|
8 |
+
boto3==1.34.46
|
9 |
+
botocore==1.34.46
|
10 |
+
build==1.0.3
|
11 |
+
CacheControl==0.14.0
|
12 |
+
certifi==2023.11.17
|
13 |
+
cffi==1.17.1
|
14 |
+
chardet==5.2.0
|
15 |
+
cleo==2.1.0
|
16 |
+
click==8.1.6
|
17 |
+
cloud-init==24.4.1
|
18 |
+
colorama==0.4.6
|
19 |
+
command-not-found==0.3
|
20 |
+
configobj==5.0.8
|
21 |
+
constantly==23.10.4
|
22 |
+
cpplint==2.0.0
|
23 |
+
crashtest==0.4.1
|
24 |
+
cryptography==41.0.7
|
25 |
+
dbus-python==1.3.2
|
26 |
+
distlib==0.3.9
|
27 |
+
distro==1.9.0
|
28 |
+
distro-info==1.7+build1
|
29 |
+
docopt==0.6.2
|
30 |
+
dulwich==0.21.6
|
31 |
+
fastimport==0.9.14
|
32 |
+
fastjsonschema==2.19.0
|
33 |
+
filelock==3.17.0
|
34 |
+
gyp==0.1
|
35 |
+
h11==0.16.0
|
36 |
+
httplib2==0.20.4
|
37 |
+
hyperlink==21.0.0
|
38 |
+
idna==3.6
|
39 |
+
importlib-metadata==4.12.0
|
40 |
+
incremental==22.10.0
|
41 |
+
iniconfig==2.0.0
|
42 |
+
installer==0.7.0
|
43 |
+
jaraco.classes==3.2.1
|
44 |
+
jeepney==0.8.0
|
45 |
+
Jinja2==3.1.2
|
46 |
+
jmespath==1.0.1
|
47 |
+
jsonpatch==1.32
|
48 |
+
jsonpointer==2.0
|
49 |
+
jsonschema==4.10.3
|
50 |
+
keyring==24.3.1
|
51 |
+
launchpadlib==1.11.0
|
52 |
+
lazr.restfulclient==0.14.6
|
53 |
+
lazr.uri==1.0.6
|
54 |
+
lockfile==0.12.2
|
55 |
+
markdown-it-py==3.0.0
|
56 |
+
MarkupSafe==2.1.5
|
57 |
+
mdurl==0.1.2
|
58 |
+
more-itertools==10.2.0
|
59 |
+
msgpack==1.0.3
|
60 |
+
netaddr==0.8.0
|
61 |
+
netifaces==0.11.0
|
62 |
+
numpy==2.2.2
|
63 |
+
oauthlib==3.2.2
|
64 |
+
packaging==24.0
|
65 |
+
pexpect==4.9.0
|
66 |
+
pipenv==2024.4.1
|
67 |
+
pipx==1.4.3
|
68 |
+
pkginfo==1.9.6
|
69 |
+
platformdirs==4.3.6
|
70 |
+
pluggy==1.5.0
|
71 |
+
poetry-core==1.9.0
|
72 |
+
psutil==5.9.8
|
73 |
+
ptyprocess==0.7.0
|
74 |
+
pyasn1==0.4.8
|
75 |
+
pyasn1-modules==0.2.8
|
76 |
+
pycparser==2.22
|
77 |
+
Pygments==2.17.2
|
78 |
+
PyGObject==3.48.2
|
79 |
+
PyHamcrest==2.1.0
|
80 |
+
PyJWT==2.7.0
|
81 |
+
pylev==1.4.0
|
82 |
+
pyOpenSSL==23.2.0
|
83 |
+
pyparsing==3.1.1
|
84 |
+
pyproject_hooks==1.0.0
|
85 |
+
pyrsistent==0.20.0
|
86 |
+
pyserial==3.5
|
87 |
+
pytest==8.3.4
|
88 |
+
python-apt==2.7.7+ubuntu4
|
89 |
+
python-dateutil==2.8.2
|
90 |
+
python-debian==0.1.49+ubuntu2
|
91 |
+
python-magic==0.4.27
|
92 |
+
pytz==2024.1
|
93 |
+
PyYAML==6.0.1
|
94 |
+
requests==2.31.0
|
95 |
+
requests-toolbelt==1.0.0
|
96 |
+
rich==13.7.1
|
97 |
+
s3transfer==0.10.1
|
98 |
+
SecretStorage==3.3.3
|
99 |
+
service-identity==24.1.0
|
100 |
+
setuptools==68.1.2
|
101 |
+
shellingham==1.5.4
|
102 |
+
six==1.16.0
|
103 |
+
sniffio==1.3.1
|
104 |
+
sos==4.8.2
|
105 |
+
ssh-import-id==5.11
|
106 |
+
systemd-python==235
|
107 |
+
toml==0.10.2
|
108 |
+
tomlkit==0.12.4
|
109 |
+
trove-classifiers==2024.1.31
|
110 |
+
Twisted==24.3.0
|
111 |
+
typing_extensions==4.14.0
|
112 |
+
ubuntu-drivers-common==0.0.0
|
113 |
+
ubuntu-pro-client==8001
|
114 |
+
ufw==0.36.2
|
115 |
+
unattended-upgrades==0.1
|
116 |
+
urllib3==2.0.7
|
117 |
+
userpath==1.9.1
|
118 |
+
virtualenv==20.29.2
|
119 |
+
wadllib==1.3.6
|
120 |
+
wheel==0.42.0
|
121 |
+
xkit==0.0.0
|
122 |
+
zipp==1.0.0
|
123 |
+
zope.interface==6.1
|
src/app/__init__.py
ADDED
File without changes
|
src/app/app.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from src.app.config import AppConfig
|
8 |
+
from src.data_utils.config import TextProcessorConfig
|
9 |
+
from src.data_utils.text_processor import TextProcessor
|
10 |
+
|
11 |
+
|
12 |
+
class App:
    """
    Gradio application that serves a pretrained sentiment classifier

    On construction the model weights and the vocabulary are loaded from the
    paths given in the config, so a ready instance can answer `predict` calls.

    Args:
        config: Application settings (artifact paths, max sequence length, server options)
    """

    def __init__(self, config: AppConfig):
        self.config = config
        self.model: Optional[torch.nn.Module] = None
        self.text_processor: Optional[TextProcessor] = None

        self._load_model()
        self._load_text_processor()


    def _load_model(self):
        """
        Load model with params from config

        Raises:
            ValueError: If the model type stored in the config JSON is unknown
        """

        with open(self.config.config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model_type = config['model_type']
        # Values are class names looked up in src.models.models; the Mamba
        # implementation there is named CustomMambaClassifier.
        model_classes = {
            'Transformer': 'TransformerClassifier',
            'LSTM': 'LSTMClassifier',
            'Mamba': 'CustomMambaClassifier'
        }

        if model_type not in model_classes:
            raise ValueError(f"Unknown model type: {model_type}")

        class_name = model_classes[model_type]
        module = __import__('src.models.models', fromlist=[class_name])
        model_class = getattr(module, class_name)

        self.model = model_class(**config['model_params'])
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a CUDA device.
        state_dict = torch.load(self.config.model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.eval()


    def _load_text_processor(self):
        """
        Load the vocabulary and build the text processor used for inference
        """

        # The vocabulary is written as UTF-8 with ensure_ascii=False, so it
        # must be read back as UTF-8 regardless of the platform locale.
        with open(self.config.vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        processor_config = TextProcessorConfig(
            max_seq_len=self.config.max_seq_len,
            lowercase=True,
            remove_punct=False
        )

        self.text_processor = TextProcessor(
            vocab=vocab,
            config=processor_config
        )


    def predict(self, text: str) -> dict:
        """
        Evaluating the tone of the text

        Args:
            text: Raw user text

        Returns:
            Mapping with "Negative" and "Positive" probabilities; an empty or
            whitespace-only input yields 0.5 for both
        """

        if not text or not text.strip():
            return {"Negative": 0.5, "Positive": 0.5}

        # unsqueeze adds the batch dimension expected by the model
        input_tensor = self.text_processor.text_to_tensor(text).unsqueeze(0)

        with torch.no_grad():
            output = self.model(input_tensor)
            proba = torch.softmax(output, dim=1)[0].tolist()

        return {"Negative": proba[0], "Positive": proba[1]}


    def launch(self):
        """
        Launch interface

        With `config.local` set the server binds to the configured host/port,
        otherwise a public share link is requested from Gradio.
        """

        interface = gr.Interface(
            fn=self.predict,
            inputs=gr.Textbox(label="Enter your text"),
            outputs=gr.Label(label="Result"),
            title="Evaluating the tone of the text",
            examples=["Very good! Increadble! So fantastic",
                      "Thw worst thing in the world!"]
        )

        if self.config.local:
            interface.launch(
                server_name=self.config.host,
                server_port=self.config.port
            )
        else:
            interface.launch(
                share=True
            )
|
src/app/config.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
class AppConfig:
    """
    Application settings

    Attributes:
        model_path: Path to the saved model state dict
        vocab_path: Path to the vocabulary JSON
        config_path: Path to the model config JSON
        max_seq_len: Sequence length passed to the text processor
        local: Serve on host/port when True, request a share link otherwise
        host: Host to bind in local mode
        port: Port to bind in local mode
    """

    model_path: str
    vocab_path: str
    config_path: str
    max_seq_len: int = 300
    local: bool = True
    host: str = "0.0.0.0"
    port: int = 7860

    @classmethod
    def from_yaml(cls, config_path: str) -> 'AppConfig':
        """
        AppConfig from path string

        Args:
            config_path: path string

        Returns:
            AppConfig object

        Raises:
            KeyError: If a required top-level key is missing from the YAML
        """

        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)

        # An omitted ``server`` key and an empty one (parsed as None) both
        # fall back to the defaults.
        server = config_data.get('server') or {}

        return cls(
            model_path=config_data['model_path'],
            vocab_path=config_data['vocab_path'],
            config_path=config_data['config_path'],
            max_seq_len=int(config_data['max_seq_len']),
            local=server.get('local', True),
            host=server.get('host', "0.0.0.0"),
            port=server.get('port', 7860)
        )
|
src/app/model_utils/factory.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Dict, Any, Optional
|
5 |
+
|
6 |
+
from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
|
7 |
+
|
8 |
+
|
9 |
+
class ModelFactory:
    """
    Builds classifier instances from a model-type name
    """

    @staticmethod
    def create_model(
        model_type: str,
        model_params: Dict[str, Any],
        state_dict_path: Optional[Path] = None
    ) -> torch.nn.Module:
        """
        Instantiate a classifier and optionally restore its weights

        Args:
            model_type: One of 'Transformer', 'Mamba', 'LSTM'
            model_params: Keyword arguments forwarded to the model constructor
            state_dict_path: Location of a saved state dict; skipped when falsy

        Returns:
            The model, switched to evaluation mode

        Raises:
            ValueError: If model_type is unknown
        """

        # NOTE(review): relies on the module-level import providing
        # MambaClassifier from src.models.models - confirm that name exists.
        registry = {
            "Transformer": TransformerClassifier,
            "Mamba": MambaClassifier,
            "LSTM": LSTMClassifier
        }

        model_cls = registry.get(model_type)
        if model_cls is None:
            raise ValueError(f"Unknown model type: {model_type}")

        model = model_cls(**model_params)

        if state_dict_path:
            # Weights are always mapped onto the CPU
            weights = torch.load(state_dict_path, map_location="cpu")
            model.load_state_dict(weights)

        model.eval()
        return model
|
src/app/model_utils/manager.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Dict, Any
|
6 |
+
|
7 |
+
from src.app.model_utils.factory import ModelFactory
|
8 |
+
|
9 |
+
|
10 |
+
class ModelManager:
    """
    Manages model loading and inference operations

    Args:
        model_dir: Directory containing model artifacts
            (config.json, vocab.json, best_model.pth)
    """

    def __init__(self, model_dir: str = "../pretrained") -> None:
        self.model_dir = Path(model_dir)
        # Cache of instantiated models keyed by model type
        self.loaded_models: Dict[str, Any] = {}
        self._load_model_artifacts()


    def _load_model_artifacts(self) -> None:
        """
        Load model configuration and vocabulary
        """

        # Both files are read as UTF-8 explicitly: the vocabulary can contain
        # non-ASCII tokens and the platform default encoding is not reliable.
        with open(self.model_dir / "config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

        with open(self.model_dir / "vocab.json", "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        self.idx_to_label = {0: "Negative", 1: "Positive"}


    def get_model(self) -> torch.nn.Module:
        """
        Get the loaded model (cached for performance)

        The model is built on first call from the stored config and
        best_model.pth, then reused on later calls.

        Returns:
            Loaded PyTorch model in evaluation mode
        """

        model_type = self.config["model_type"]

        if model_type not in self.loaded_models:
            model = ModelFactory.create_model(
                model_type=model_type,
                model_params=self.config["model_params"],
                state_dict_path=self.model_dir / "best_model.pth"
            )
            self.loaded_models[model_type] = model

        return self.loaded_models[model_type]


    def get_vocab(self) -> Dict[str, int]:
        """
        Get vocabulary mapping
        """

        return self.vocab


    def get_config(self) -> Dict[str, Any]:
        """
        Get model configuration
        """

        return self.config
|
src/data_utils/__init__.py
ADDED
File without changes
|
src/data_utils/config.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
class DatasetConfig:
    """
    Settings that control dataset loading, sampling and text preprocessing

    Attributes:
        embedding_dim: Dimension for embedding layer output
        train_size: Number of samples in training set
        val_size: Number of samples in validation set
        test_size: Number of samples in test set
        random_state: Random seed for reproducibility
        min_word_freq: Minimum word frequency to include in vocabulary
        load_from_disk: Load the dataset from a local directory; when False it
            is downloaded from Hugging Face
        path_to_data: Path to local dataset data
        max_seq_len: Maximum sequence length (will be padded/truncated to this)
        lowercase: Whether to convert text to lowercase
        remove_punct: Whether to remove punctuation
        pad_token: Padding token
        unk_token: Unknown token
    """

    # Dataset loading and sampling
    embedding_dim: int = 64
    train_size: int = 10000
    val_size: int = 5000
    test_size: int = 5000
    random_state: int = 42
    min_word_freq: int = 1
    load_from_disk: bool = False
    path_to_data: str = "./datasets"

    # Text preprocessing (mirrored by TextProcessorConfig)
    max_seq_len: int = 300
    lowercase: bool = True
    remove_punct: bool = False
    pad_token: str = "<PAD>"
    unk_token: str = "<UNK>"
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
class TextProcessorConfig:
    """
    Settings for the text processor

    The values should match the dataset config used to build the vocabulary.

    Attributes:
        max_seq_len: Maximum sequence length (will be padded/truncated to this)
        lowercase: Whether to convert text to lowercase
        remove_punct: Whether to remove punctuation
        pad_token: Padding token
        unk_token: Unknown token
    """

    # Sequence shaping
    max_seq_len: int = 300

    # Token normalisation
    lowercase: bool = True
    remove_punct: bool = False

    # Special tokens
    pad_token: str = "<PAD>"
    unk_token: str = "<UNK>"
|
src/data_utils/dataset_generator.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import Counter
|
2 |
+
from typing import Dict, Tuple, List
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from datasets import load_dataset, load_from_disk
|
8 |
+
from sklearn.model_selection import train_test_split
|
9 |
+
|
10 |
+
import src.data_utils.dataset_params as dataset_params
|
11 |
+
|
12 |
+
from src.data_utils.config import DatasetConfig, TextProcessorConfig
|
13 |
+
from src.data_utils.text_processor import TextProcessor
|
14 |
+
|
15 |
+
|
16 |
+
class DatasetGenerator:
    """
    Main dataset generator class

    Provides methods to load, build vocabulary, convert text datasets
    into tensor format suitable for deep learning models.

    Args:
        dataset_name: Name of dataset from DatasetName enum
        config: Configuration object with preprocessing parameters;
            a fresh DatasetConfig is created when omitted
        device: Torch device stored on the generator; defaults to cuda when
            available, cpu otherwise
    """

    def __init__(
        self,
        dataset_name: dataset_params.DatasetName,
        config: DatasetConfig | None = None,
        device: torch.device | None = None
    ):
        # Defaults are resolved per instance: a default object built in the
        # signature is evaluated once at import time and would be shared (and
        # mutated) across every generator.
        if config is None:
            config = DatasetConfig()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.dataset_params = dataset_params.get_dataset_params_by_name(dataset_name=dataset_name)
        self.config = config
        self.device = device
        # The processor starts without a vocabulary; build_vocabulary sets it
        self.text_processor = TextProcessor(
            vocab=None,
            config=TextProcessorConfig(
                max_seq_len=self.config.max_seq_len,
                lowercase=self.config.lowercase,
                remove_punct=self.config.remove_punct,
                pad_token=self.config.pad_token,
                unk_token=self.config.unk_token,
            )
        )
        self.vocab = None
        self.id2word = None
        self.embedding_layer = None


    def load_raw_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Load raw dataset from source

        The validation set is carved out of the dataset's "test" split
        (stratified 50/50), then each split is down-sampled to the configured
        size.

        Returns:
            Tuple of (train_df, val_df, test_df) DataFrames
        """
        if self.config.load_from_disk:
            dataset = load_from_disk(f"{self.config.path_to_data}/{self.dataset_params.local_path}")
        else:
            dataset = load_dataset(self.dataset_params.hugging_face_name)
        train_df = pd.DataFrame(dataset["train"])
        test_df = pd.DataFrame(dataset["test"])
        val_df, test_df = train_test_split(
            test_df,
            test_size=0.5,
            random_state=self.config.random_state,
            stratify=test_df[self.dataset_params.label_col_name]
        )

        # Sample configured sizes
        train_df = train_df.sample(n=self.config.train_size, random_state=self.config.random_state)
        val_df = val_df.sample(n=self.config.val_size, random_state=self.config.random_state)
        test_df = test_df.sample(n=self.config.test_size, random_state=self.config.random_state)

        return train_df, val_df, test_df


    def build_vocabulary(self, tokenized_texts: List[List[str]]) -> Tuple[Dict[str, int], Dict[int, str]]:
        """
        Build vocabulary from tokenized texts

        Ids 0 and 1 are reserved for the pad and unknown tokens; remaining
        words are numbered in order of first appearance. The resulting
        word-to-id mapping is also installed on the text processor.

        Args:
            tokenized_texts: List of tokenized texts

        Returns:
            Tuple of (word_to_id, id_to_word) mappings
        """

        word_counts = Counter(token for tokens in tokenized_texts for token in tokens)

        filtered_words = [word for word, count in word_counts.items()
                          if count >= self.config.min_word_freq]

        word_to_id = {self.config.pad_token: 0, self.config.unk_token: 1}
        id_to_word = {0: self.config.pad_token, 1: self.config.unk_token}

        for idx, word in enumerate(filtered_words, start=2):
            word_to_id[word] = idx
            id_to_word[idx] = word

        self.text_processor.vocab = word_to_id

        return word_to_id, id_to_word


    def generate_dataset(self) -> Tuple[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor]
    ]:
        """
        Main method to generate the full dataset

        The vocabulary is built from the training texts only and stored on
        the generator (see get_vocabulary).

        Returns:
            Tuple containing:
            - (train_features, train_labels)
            - (val_features, val_labels)
            - (test_features, test_labels)
        """

        train_df, val_df, test_df = self.load_raw_data()

        train_texts = train_df[self.dataset_params.content_col_name].tolist()
        train_tokens = [self.text_processor.preprocess_text(text) for text in train_texts]

        self.vocab, self.id2word = self.build_vocabulary(train_tokens)

        X_train = torch.stack([self.text_processor.text_to_tensor(text) for text in train_texts])

        val_texts = val_df[self.dataset_params.content_col_name].tolist()
        X_val = torch.stack([self.text_processor.text_to_tensor(text) for text in val_texts])

        test_texts = test_df[self.dataset_params.content_col_name].tolist()
        X_test = torch.stack([self.text_processor.text_to_tensor(text) for text in test_texts])

        y_train = torch.tensor(train_df[self.dataset_params.label_col_name].values, dtype=torch.long)
        y_val = torch.tensor(val_df[self.dataset_params.label_col_name].values, dtype=torch.long)
        y_test = torch.tensor(test_df[self.dataset_params.label_col_name].values, dtype=torch.long)

        return (X_train, y_train), (X_val, y_val), (X_test, y_test)


    def get_vocabulary(self) -> Tuple[Dict[str, int], Dict[int, str]]:
        """
        Get vocabulary mappings

        Returns:
            Tuple of (word_to_id, id_to_word) dictionaries; both are None
            until a vocabulary has been generated
        """

        return self.vocab, self.id2word


    def get_config(self) -> DatasetConfig:
        """
        Get current configuration

        Returns:
            DatasetConfig object
        """

        return self.config


    def get_text_processor(self) -> TextProcessor:
        """
        Get the text processor for inference usage

        Returns:
            TextProcessor object
        """
        return self.text_processor
|
src/data_utils/dataset_params.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
|
3 |
+
|
4 |
+
class DatasetName(enum.Enum):
    """
    Names of the datasets this project can load
    """

    # Resolved to ImbdParams by get_dataset_params_by_name
    IMDB = "imdb"
    # Resolved to PolarityParams by get_dataset_params_by_name
    POLARITY = "polarity"
|
11 |
+
|
12 |
+
|
13 |
+
class DatasetParams:
    """
    Base class for dataset parameters

    Subclasses override the class attributes below; here they are all empty.
    """

    # Dataset identifier passed to load_dataset
    hugging_face_name = ""
    # Column holding the text
    content_col_name = ""
    # Column holding the class label
    label_col_name = ""
    # Sub-directory used when loading the dataset from disk
    local_path = ""
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def get_dataset_params_by_name(dataset_name: DatasetName) -> DatasetParams:
    """
    Return a fresh params object for the requested dataset

    Args:
        dataset_name: Member of the DatasetName enum

    Returns:
        Instance of the matching DatasetParams subclass

    Raises:
        ValueError: If the dataset name is not supported
    """

    known = (
        (DatasetName.IMDB, ImbdParams),
        (DatasetName.POLARITY, PolarityParams),
    )

    for candidate, params_cls in known:
        if dataset_name == candidate:
            return params_cls()

    raise ValueError(f"Unsupported dataset: {dataset_name}")
|
32 |
+
|
33 |
+
|
34 |
+
class ImbdParams(DatasetParams):
    """
    Parameters of the IMDB dataset
    """

    # Hugging Face dataset identifier
    hugging_face_name = "stanfordnlp/imdb"
    # Text and label columns
    content_col_name = "text"
    label_col_name = "label"
    # Sub-directory used when loading from disk
    local_path = "imdb"
|
43 |
+
|
44 |
+
|
45 |
+
class PolarityParams(DatasetParams):
    """
    Parameters of the POLARITY dataset
    """

    # Hugging Face dataset identifier
    hugging_face_name = "fancyzhx/amazon_polarity"
    # Text and label columns
    content_col_name = "content"
    label_col_name = "label"
    # Sub-directory used when loading from disk
    local_path = "polarity"
|
src/data_utils/text_processor.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict
|
2 |
+
|
3 |
+
import nltk
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from nltk.tokenize import word_tokenize
|
7 |
+
|
8 |
+
from src.data_utils.config import TextProcessorConfig
|
9 |
+
|
10 |
+
|
11 |
+
class TextProcessor:
    """
    Main text preprocessor class

    Args:
        vocab: Vocabulary dictionary (token -> id); may be None at
            construction and assigned later through the `vocab` attribute
        config: Configuration object
    """

    def __init__(self, vocab: Dict[str, int] | None, config: TextProcessorConfig):
        self.vocab = vocab
        self.config = config
        self._ensure_nltk_downloaded()


    def _ensure_nltk_downloaded(self):
        """
        Download the NLTK tokenizer data if tokenization fails without it
        """

        try:
            word_tokenize("test")
        except LookupError:
            # Recent NLTK releases tokenize with the "punkt_tab" resource,
            # older ones with "punkt"; fetch both so either version works.
            for resource in ("punkt_tab", "punkt"):
                nltk.download(resource)


    def preprocess_text(self, text: str) -> List[str]:
        """
        Tokenize and preprocess single text string

        Args:
            text: Your text

        Returns:
            List of preprocessed tokens
        """

        if self.config.lowercase:
            text = text.lower()

        tokens = word_tokenize(text)

        if self.config.remove_punct:
            # Keeps purely alphabetic tokens only, so numbers are dropped too
            tokens = [t for t in tokens if t.isalpha()]

        return tokens


    def text_to_tensor(self, text: str) -> torch.Tensor:
        """
        Convert raw text to tensor

        Args:
            text: Your text

        Returns:
            1-D long tensor of token ids with exactly max_seq_len entries

        Raises:
            ValueError: If no vocabulary has been set
        """

        if self.vocab is None:
            raise ValueError("Vocabulary is not set; build or load it before converting text")

        tokens = self.preprocess_text(text)
        unk_id = self.vocab[self.config.unk_token]
        ids = [self.vocab.get(token, unk_id) for token in tokens]

        # Pad or truncate
        if len(ids) < self.config.max_seq_len:
            ids = ids + [self.vocab[self.config.pad_token]] * (self.config.max_seq_len - len(ids))
        else:
            ids = ids[:self.config.max_seq_len]

        return torch.tensor(ids, dtype=torch.long)
|
src/models/__init__.py
ADDED
File without changes
|
src/models/models.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class TransformerClassifier(nn.Module):
    """
    Transformer-encoder text classifier with learned positional embeddings

    Args:
        vocab_size: Number of token ids (id 0 is the padding index)
        embed_dim: Embedding / model dimension
        num_heads: Attention heads per encoder layer
        num_layers: Number of encoder layers
        num_classes: Number of output classes
        max_seq_len: Longest sequence the positional table covers
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, max_seq_len):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Learned positional table, one row per position up to max_seq_len
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        encoder_layers = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True, dim_feedforward=embed_dim*4)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """
        Compute class logits

        Args:
            x: Token ids of shape (batch, seq_len) with seq_len <= max_seq_len

        Returns:
            Logits of shape (batch, num_classes)
        """
        # Id 0 is padding; those positions are ignored as attention keys
        padding_mask = (x == 0)
        # Slice the positional table to the actual length so inputs shorter
        # than max_seq_len work; full-length inputs behave exactly as before.
        x = self.embedding(x) + self.pos_encoder[:, :x.size(1)]
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        # Mean over all positions (padding included)
        x = x.mean(dim=1)
        x = self.fc(x)
        return x
|
21 |
+
|
22 |
+
class LSTMClassifier(nn.Module):
    """
    Bidirectional LSTM text classifier

    The final hidden states of the last layer (both directions) are
    concatenated and fed to a linear head.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Inter-layer dropout is only passed to the LSTM when it is stacked
        recurrent_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=recurrent_dropout,
        )
        # Two directions -> twice the hidden size
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        token_vectors = self.dropout(self.embedding(x))
        _, (final_hidden, _) = self.lstm(token_vectors)
        # The last two entries hold the top layer's two directions
        top_layer_a = final_hidden[-2]
        top_layer_b = final_hidden[-1]
        features = torch.cat([top_layer_a, top_layer_b], dim=1)
        return self.fc(self.dropout(features))
|
39 |
+
|
40 |
+
|
41 |
+
class SimpleMambaBlock(nn.Module):
    """
    Pipeline: projection -> 1D convolution -> activation -> selective SSM -> output projection

    Args:
        d_model: Size of the last dimension of the block's input and output
        d_state: State size per inner channel used by the SSM
        d_conv: Kernel size of the depthwise 1D convolution
        expand: Multiplier giving the inner width (d_inner = expand * d_model)
    """
    def __init__(self, d_model, d_state, d_conv, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand

        d_inner = int(self.expand * self.d_model)

        # One projection yields both halves split apart in forward (x and z)
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        # Depthwise convolution (groups == channels); padded by d_conv - 1,
        # forward keeps only the first L outputs
        self.conv1d = nn.Conv1d(
            in_channels=d_inner, out_channels=d_inner,
            kernel_size=d_conv, padding=d_conv - 1,
            groups=d_inner, bias=True
        )

        # Produces, per position, 1 delta value plus d_state B and d_state C values
        self.x_proj = nn.Linear(d_inner, self.d_state + self.d_state + 1, bias=False)
        # Expands the single delta value to one step size per inner channel
        self.dt_proj = nn.Linear(1, d_inner, bias=True)

        # A is stored as a log; ssm() uses -exp(A_log), i.e. rows of -1..-d_state at init
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel weight of the skip term added in ssm()
        self.D = nn.Parameter(torch.ones(d_inner))

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        """
        Apply the block to a 3-D input and return a tensor of the same leading shape

        The input is unpacked as (B, L, D); only L (sequence length) is used below.
        """
        B, L, D = x.shape

        xz = self.in_proj(x)
        # x feeds the conv/SSM path, z is the gate
        x, z = xz.chunk(2, dim=-1)

        # Conv1d expects channels before length
        x = x.transpose(1, 2)
        # Drop the extra outputs created by the padding
        x = self.conv1d(x)[:, :, :L]
        x = x.transpose(1, 2)

        x = F.silu(x)

        y = self.ssm(x)

        # Gate the SSM output with the activated z branch
        y = y * F.silu(z)
        y = self.out_proj(y)

        return y

    def ssm(self, x):
        """
        Run the state-space recurrence position by position

        Args:
            x: Tensor unpacked as (batch_size, seq_len, d_inner)

        Returns:
            Tensor with the same shape as x
        """
        batch_size, seq_len, d_inner = x.shape

        # Negative by construction
        A = -torch.exp(self.A_log.float())
        D = self.D.float()

        x_dbl = self.x_proj(x)
        delta, B_param, C_param = torch.split(x_dbl, [1, self.d_state, self.d_state], dim=-1)

        # softplus keeps the step size positive
        delta = F.softplus(self.dt_proj(delta))

        # State starts at zero: (batch, d_inner, d_state)
        h = torch.zeros(batch_size, d_inner, self.d_state, device=x.device)
        ys = []

        for i in range(seq_len):
            delta_i = delta[:, i, :]
            # Per-step decay and input terms, both (batch, d_inner, d_state)
            A_i = torch.exp(delta_i.unsqueeze(-1) * A)
            B_i = delta_i.unsqueeze(-1) * B_param[:, i, :].unsqueeze(1)

            h = A_i * h + B_i * x[:, i, :].unsqueeze(-1)

            # Contract the state dimension with C -> (batch, d_inner)
            y_i = (h @ C_param[:, i, :].unsqueeze(-1)).squeeze(-1)
            ys.append(y_i)

        y = torch.stack(ys, dim=1)
        # Skip connection scaled per channel by D
        y = y + x * D

        return y
|
118 |
+
|
119 |
+
|
120 |
+
class CustomMambaClassifier(nn.Module):
    """Sequence classifier built from a stack of ``SimpleMambaBlock`` layers.

    Token ids are embedded, passed through ``num_layers`` blocks in order,
    averaged over dimension 1 and projected to ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding size and feature width of every block.
        d_state: State size forwarded to each ``SimpleMambaBlock``.
        d_conv: Convolution kernel size forwarded to each block.
        num_layers: Number of stacked blocks.
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, d_model, d_state, d_conv, num_layers, num_classes):
        super().__init__()
        # Index 0 is treated as padding by the embedding layer.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)

        self.layers = nn.ModuleList(
            SimpleMambaBlock(d_model, d_state, d_conv) for _ in range(num_layers)
        )

        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences."""
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)

        # NOTE(review): the mean runs over every position along dim 1, so
        # padded positions are averaged in as well - confirm this is intended.
        pooled = torch.mean(hidden, dim=1)
        return self.fc(pooled)
|
src/models/predict.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from nltk.tokenize import word_tokenize
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
|
7 |
+
|
8 |
+
# Directory holding the saved artefacts. The path is relative, so it is
# resolved against the current working directory.
SAVE_DIR = "pretrained"
# Model weights, model configuration and vocabulary read by load_artifacts().
MODEL_PATH = f"{SAVE_DIR}/best_model.pth"
CONFIG_PATH = f"{SAVE_DIR}/config.json"
VOCAB_PATH = f"{SAVE_DIR}/vocab.json"

# Maps a predicted class index to its human-readable label.
ID_TO_LABEL = {0: "Negative", 1: "Positive"}
|
14 |
+
|
15 |
+
def load_artifacts():
    """Load the saved configuration, vocabulary and model for inference.

    Reads ``CONFIG_PATH`` and ``VOCAB_PATH`` as JSON, builds the classifier
    named by ``config['model_type']`` from ``config['model_params']``, loads
    the weights stored at ``MODEL_PATH`` and puts the model in eval mode.

    Returns:
        A ``(model, vocab, config, device)`` tuple. ``device`` is CUDA when
        available, otherwise CPU.

    Raises:
        FileNotFoundError: If one of the artefact files does not exist.
        ValueError: If ``config['model_type']`` is not ``'Transformer'``,
            ``'Mamba'`` or ``'LSTM'``.
    """
    # Read JSON explicitly as UTF-8 instead of the platform's locale
    # encoding, which may fail to decode non-ASCII vocabulary tokens.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = json.load(f)
    with open(VOCAB_PATH, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    model_type = config['model_type']
    model_params = config['model_params']

    # NOTE(review): confirm that `MambaClassifier` is really exported by
    # src.models.models and matches the saved 'Mamba' checkpoints.
    if model_type == 'Transformer':
        model = TransformerClassifier(**model_params)
    elif model_type == 'Mamba':
        model = MambaClassifier(**model_params)
    elif model_type == 'LSTM':
        model = LSTMClassifier(**model_params)
    else:
        raise ValueError("Неизвестный тип модели в файле конфигурации.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: torch.load unpickles the checkpoint, so only load files that
    # come from a trusted source.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.to(device)
    # Inference mode for layers such as dropout.
    model.eval()

    return model, vocab, config, device
|
39 |
+
|
40 |
+
def preprocess_text(text, vocab, max_len):
    """Convert raw text into a fixed-length batch of token ids.

    The text is lower-cased and tokenized with ``word_tokenize``. Tokens
    missing from ``vocab`` are mapped to the ``'<UNK>'`` id. The id list is
    then right-padded with the ``'<PAD>'`` id or truncated so that it holds
    exactly ``max_len`` entries.

    Args:
        text: Input string.
        vocab: Mapping from token to integer id; must contain ``'<UNK>'``
            and, when padding is needed, ``'<PAD>'``.
        max_len: Target sequence length.

    Returns:
        Integer (``torch.long``) tensor with a leading batch dimension of 1.
    """
    tokens = word_tokenize(text.lower())
    if tokens:
        # Look the fallback id up once instead of once per token.
        unk_id = vocab['<UNK>']
        ids = [vocab.get(token, unk_id) for token in tokens]
    else:
        ids = []

    if len(ids) < max_len:
        ids.extend([vocab['<PAD>']] * (max_len - len(ids)))
    else:
        ids = ids[:max_len]

    # Explicit dtype: an empty list would otherwise become a float tensor,
    # which an embedding layer cannot index with.
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
|
48 |
+
|
49 |
+
def predict(text, model, vocab, config, device):
    """Classify a single text and report the model's confidence.

    Args:
        text: Input string to classify.
        model: Callable model returning one row of class logits per input.
        vocab: Token-to-id mapping passed to ``preprocess_text``.
        config: Mapping providing ``'max_seq_len'``.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(label, confidence)``: the ``ID_TO_LABEL`` entry of the most
        probable class and its softmax probability as a float.
    """
    batch = preprocess_text(text, vocab, config['max_seq_len']).to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        class_id = torch.argmax(probs, dim=1).item()

    label = ID_TO_LABEL[class_id]
    score = probs[0][class_id].item()
    return label, score
|
62 |
+
|
63 |
+
if __name__ == "__main__":
    # Command-line entry point: classify the text given as the single
    # positional argument and print the label with its confidence.
    parser = argparse.ArgumentParser(description="Предсказать тональность текста с помощью обученной модели.")
    parser.add_argument("text", type=str, help="Текст для анализа (в кавычках).")
    args = parser.parse_args()

    print("Загрузка модели и артефактов...")
    try:
        loaded_model, loaded_vocab, loaded_config, device = load_artifacts()
    except FileNotFoundError:
        print("\nОШИБКА: Файлы модели не найдены!")
        print("Сначала запустите скрипт train.py для обучения и сохранения модели.")
        # Exit with a non-zero status so callers can detect the failure;
        # the site-provided exit() helper would have reported success.
        raise SystemExit(1) from None
    print(f"Модель '{loaded_config['model_type']}' успешно загружена на устройство {device}.")

    label, conf = predict(args.text, loaded_model, loaded_vocab, loaded_config, device)

    print("\n--- Результат предсказания ---")
    print(f"Текст: '{args.text}'")
    print(f"Тональность: {label}")
    print(f"Уверенность: {conf:.2%}")
|
83 |
+
|