{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "for warn in [UserWarning, FutureWarning]: warnings.filterwarnings(\"ignore\", category = warn)\n",
    "\n",
    "import os\n",
    "import time\n",
    "import json\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Импортируем классы моделей из нашего файла\n",
    "from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_TO_TRAIN = 'Transformer' \n",
    "\n",
    "# Гиперпараметры данных и модели\n",
    "MAX_SEQ_LEN = 300\n",
    "EMBEDDING_DIM = 128\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-4\n",
    "NUM_EPOCHS = 5 # Увеличим для лучшего результата\n",
    "\n",
    "# Пути для сохранения артефактов\n",
    "SAVE_DIR = \"../pretrained\"\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, \"best_model.pth\")\n",
    "VOCAB_SAVE_PATH = os.path.join(SAVE_DIR, \"vocab.json\")\n",
    "CONFIG_SAVE_PATH = os.path.join(SAVE_DIR, \"config.json\")\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data_utils.dataset_generator import DatasetGenerator\n",
    "from src.data_utils.dataset_params import DatasetName\n",
    "\n",
    "generator = DatasetGenerator(DatasetName.IMDB)\n",
    "(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()\n",
    "X_train, y_train, X_val, y_val, X_test, y_test = X_train[:1000], y_train[:1000], X_val[:100], y_val[:100], X_test[:100], y_test[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloader(X, y, batch_size):\n",
    "    dataset = TensorDataset(torch.tensor(X, dtype=torch.long), torch.tensor(y, dtype=torch.long))\n",
    "    return DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "train_loader = create_dataloader(X_train, y_train, BATCH_SIZE)\n",
    "val_loader = create_dataloader(X_val, y_val, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_params = {}\n",
    "if MODEL_TO_TRAIN == 'Transformer':\n",
    "    model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'num_heads': 4, 'num_layers': 2, 'num_classes': 2, 'max_seq_len': MAX_SEQ_LEN}\n",
    "    model = TransformerClassifier(**model_params)\n",
    "elif MODEL_TO_TRAIN == 'Mamba':\n",
    "    model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'mamba_d_state': 16, 'mamba_d_conv': 4, 'mamba_expand': 2, 'num_classes': 2}\n",
    "    model = MambaClassifier(**model_params)\n",
    "elif MODEL_TO_TRAIN == 'LSTM':\n",
    "    model_params = {'vocab_size': len(generator.vocab), 'embed_dim': EMBEDDING_DIM, 'hidden_dim': 256, 'num_layers': 2, 'num_classes': 2, 'dropout': 0.5}\n",
    "    model = LSTMClassifier(**model_params)\n",
    "else:\n",
    "    raise ValueError(\"Неизвестный тип модели. Выберите 'Transformer', 'Mamba' или 'LSTM'\")"
   ]
  },
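  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check before training: counting trainable parameters makes it easy to compare the three architectures at roughly matched sizes. This sketch only relies on the `model` instance built above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: report the number of trainable parameters of the selected model\n",
    "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"{MODEL_TO_TRAIN}: {num_params:,} trainable parameters\")"
   ]
  },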
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Начало обучения модели: Transformer ---\n",
      "Эпоха 1/5 | Время: 17.06с | Train Loss: 0.7023 | Val Loss: 0.7095 | Val Acc: 0.4000\n",
      "  -> Модель сохранена, новая лучшая Val Loss: 0.7095\n",
      "Эпоха 2/5 | Время: 16.40с | Train Loss: 0.6682 | Val Loss: 0.6937 | Val Acc: 0.4800\n",
      "  -> Модель сохранена, новая лучшая Val Loss: 0.6937\n",
      "Эпоха 3/5 | Время: 16.13с | Train Loss: 0.6471 | Val Loss: 0.7075 | Val Acc: 0.4100\n",
      "Эпоха 4/5 | Время: 16.36с | Train Loss: 0.6283 | Val Loss: 0.6917 | Val Acc: 0.5300\n",
      "  -> Модель сохранена, новая лучшая Val Loss: 0.6917\n",
      "Эпоха 5/5 | Время: 16.39с | Train Loss: 0.6050 | Val Loss: 0.6871 | Val Acc: 0.5300\n",
      "  -> Модель сохранена, новая лучшая Val Loss: 0.6871\n"
     ]
    }
   ],
   "source": [
    "model.to(DEVICE)\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_val_loss = float('inf')\n",
    "print(f\"--- Начало обучения модели: {MODEL_TO_TRAIN} ---\")\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    start_time = time.time()\n",
    "    total_train_loss = 0\n",
    "\n",
    "    for batch_X, batch_y in train_loader:\n",
    "        batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(batch_X)\n",
    "        loss = criterion(outputs, batch_y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_train_loss += loss.item()\n",
    "    avg_train_loss = total_train_loss / len(train_loader)\n",
    "    \n",
    "    model.eval()\n",
    "    total_val_loss, correct_val, total_val = 0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch_X, batch_y in val_loader:\n",
    "            batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
    "            outputs = model(batch_X)\n",
    "            loss = criterion(outputs, batch_y)\n",
    "            total_val_loss += loss.item()\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total_val += batch_y.size(0)\n",
    "            correct_val += (predicted == batch_y).sum().item()\n",
    "    avg_val_loss = total_val_loss / len(val_loader)\n",
    "    val_accuracy = correct_val / total_val\n",
    "\n",
    "    epoch_time = time.time() - start_time\n",
    "    print(f\"Эпоха {epoch+1}/{NUM_EPOCHS} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}\")\n",
    "    \n",
    "    if avg_val_loss < best_val_loss:\n",
    "        best_val_loss = avg_val_loss\n",
    "        torch.save(model.state_dict(), MODEL_SAVE_PATH)\n",
    "        print(f\"  -> Модель сохранена, новая лучшая Val Loss: {best_val_loss:.4f}\")"
   ]
  },
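  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test split loaded earlier is never used above, so here is a minimal evaluation sketch: reload the best checkpoint and measure accuracy on the held-out test set, reusing `create_dataloader` and the same forward pass as in validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the best saved checkpoint on the held-out test set\n",
    "test_loader = create_dataloader(X_test, y_test, BATCH_SIZE, shuffle=False)\n",
    "model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))\n",
    "model.eval()\n",
    "\n",
    "correct_test, total_test = 0, 0\n",
    "with torch.no_grad():\n",
    "    for batch_X, batch_y in test_loader:\n",
    "        batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)\n",
    "        predicted = model(batch_X).argmax(dim=1)\n",
    "        total_test += batch_y.size(0)\n",
    "        correct_test += (predicted == batch_y).sum().item()\n",
    "print(f\"Test Acc: {correct_test / total_test:.4f}\")"
   ]
  },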
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Конфигурация модели сохранена в: ../pretrained/config.json\n"
     ]
    }
   ],
   "source": [
    "with open(VOCAB_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
    "    json.dump(generator.vocab, f, ensure_ascii=False, indent=4)\n",
    "\n",
    "config = {\n",
    "    \"model_type\": MODEL_TO_TRAIN,\n",
    "    \"max_seq_len\": MAX_SEQ_LEN,\n",
    "    \"model_params\": model_params,\n",
    "}\n",
    "with open(CONFIG_SAVE_PATH, 'w', encoding='utf-8') as f:\n",
    "    json.dump(config, f, ensure_ascii=False, indent=4)\n",
    "print(f\"Конфигурация модели сохранена в: {CONFIG_SAVE_PATH}\")\n"
   ]
  },
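  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm the saved artifacts are sufficient to restore the model elsewhere, a minimal round-trip sketch: rebuild the classifier from `config.json` and load the checkpoint weights. It assumes the same model classes imported at the top are available at load time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip check: rebuild the model from the saved config and weights\n",
    "with open(CONFIG_SAVE_PATH, 'r', encoding='utf-8') as f:\n",
    "    loaded_config = json.load(f)\n",
    "\n",
    "model_classes = {\n",
    "    'Transformer': TransformerClassifier,\n",
    "    'Mamba': MambaClassifier,\n",
    "    'LSTM': LSTMClassifier,\n",
    "}\n",
    "restored = model_classes[loaded_config['model_type']](**loaded_config['model_params']).to(DEVICE)\n",
    "restored.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))\n",
    "restored.eval()\n",
    "print(f\"Restored {loaded_config['model_type']} from {SAVE_DIR}\")"
   ]
  }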
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "monkey-coding-dl-project-OWiM8ypK-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}