# Seed the global RNGs (via accelerate) before anything else in this module is
# imported, so randomness consumed during the later imports/setup is seeded too.
from accelerate.utils import set_seed

set_seed(seed=1024)


import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    SAMPLE_RATE,
)
from .model import TaikoConformer6
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoEnergyLoss
from huggingface_hub import upload_folder


# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """Log a plot of predicted vs. true energies for one sample to TensorBoard.

    All six energy arguments are 1D torch tensors for a single sample. They
    are detached, cast to float32 and moved to the CPU here, so tensors that
    still require grad, live on an accelerator, or use a reduced-precision
    dtype (which ``Tensor.numpy()`` cannot convert) are all accepted.

    Args:
        writer: TensorBoard ``SummaryWriter`` that receives the figure.
        tag_prefix: Title prefix; the figure is logged under the tag
            ``f"{tag_prefix}/Energy_Comparison"``.
        epoch: Step value stored with the figure; also shown in the title.
        pred_don, pred_ka, pred_drumroll: Predicted energy per frame.
        true_don, true_ka, true_drumroll: Target energy per frame.
        valid_length: Number of leading (unpadded) frames to plot. It is
            clamped to the shortest of the six series, so a length mismatch
            cannot make matplotlib reject the data.
        hop_sec: Seconds per frame, used to build the time axis.
    """
    series = (pred_don, pred_ka, pred_drumroll, true_don, true_ka, true_drumroll)
    # Never slice past the shortest series: x and y must have equal lengths.
    length = max(0, min(int(valid_length), *(len(s) for s in series)))
    (
        pred_don,
        pred_ka,
        pred_drumroll,
        true_don,
        true_ka,
        true_drumroll,
    ) = (s[:length].detach().float().cpu().numpy() for s in series)

    time_axis = np.arange(length) * hop_sec

    # One subplot per label: (name, prediction, target, target colour, prediction colour).
    panels = (
        ("Don", pred_don, true_don, "blue", "lightblue"),
        ("Ka", pred_ka, true_ka, "red", "lightcoral"),
        ("Drumroll", pred_drumroll, true_drumroll, "green", "lightgreen"),
    )

    fig, axs = plt.subplots(len(panels), 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    for ax, (name, pred, true, true_color, pred_color) in zip(axs, panels):
        ax.plot(
            time_axis, true, label=f"True {name}", color=true_color, linestyle="--"
        )
        ax.plot(time_axis, pred, label=f"Pred {name}", color=pred_color, alpha=0.8)
        ax.set_ylabel(f"{name} Energy")
        ax.legend()
        ax.grid(True)
    axs[-1].set_xlabel("Time (s)")

    # Leave the top 4% free for the suptitle. Lay out this figure explicitly
    # rather than relying on pyplot's global "current figure".
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


# Batch entries copied to DEVICE for the forward pass and the loss, in the
# order they are transferred.
_BATCH_TENSOR_KEYS = (
    "mel",
    "don_labels",
    "ka_labels",
    "drumroll_labels",
    "lengths",
    "nps",
    "difficulty",
    "level",
)


def _preprocess_all_difficulties(raw_ds):
    """Run ``preprocess`` on ``raw_ds`` once per difficulty and concatenate.

    The concatenation order (oni, hard, normal) is kept fixed because the
    seeded ``train_test_split`` downstream depends on row order.
    """
    per_difficulty = [
        raw_ds.map(
            preprocess,
            remove_columns=raw_ds.column_names,
            fn_kwargs={"difficulty": difficulty},
            writer_batch_size=10,
        )
        for difficulty in ("oni", "hard", "normal")
    ]
    return concatenate_datasets(per_difficulty)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by the train and val splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=16,
        persistent_workers=True,
        prefetch_factor=4,
    )


def _move_batch_to_device(batch):
    """Return a dict holding the ``_BATCH_TENSOR_KEYS`` entries of ``batch`` on DEVICE."""
    return {key: batch[key].to(DEVICE) for key in _BATCH_TENSOR_KEYS}


def _forward_and_loss(model, criterion, tensors):
    """Run the model on one on-device batch and compute the energy loss.

    Returns:
        ``(output_dict, loss)``. ``output_dict["presence"]`` holds the
        don/ka/drumroll energies (commented elsewhere as ``(B, T_out, 3)``).
    """
    output_dict = model(
        tensors["mel"],
        tensors["lengths"],  # lengths of the Conformer output frames
        tensors["nps"],
        tensors["difficulty"],
        tensors["level"],
    )
    loss_inputs = {
        "don_labels": tensors["don_labels"],
        "ka_labels": tensors["ka_labels"],
        "drumroll_labels": tensors["drumroll_labels"],
        "lengths": tensors["lengths"],  # used for masking inside the loss
    }
    return output_dict, criterion(output_dict, loss_inputs)


def _log_first_sample(writer, tag_prefix, epoch, output_dict, tensors, hop_sec):
    """Plot predicted vs. true energies of sample 0 of a batch to TensorBoard."""
    pred = output_dict["presence"]
    log_energy_plots_to_tensorboard(
        writer,
        tag_prefix,
        epoch,
        pred[0, :, 0],
        pred[0, :, 1],
        pred[0, :, 2],
        tensors["don_labels"][0, :],
        tensors["ka_labels"][0, :],
        tensors["drumroll_labels"][0, :],
        tensors["lengths"][0].item(),
        hop_sec,
    )


def _train_one_epoch(
    model, loader, optimizer, scheduler, criterion, writer, epoch, hop_sec
):
    """Train for one epoch with gradient accumulation; return the mean batch loss.

    The optimizer and scheduler step every ``GRAD_ACCUM_STEPS`` batches and
    once more on the final batch. This gives ``ceil(len(loader) / GRAD_ACCUM_STEPS)``
    scheduler steps per epoch, which is what OneCycleLR was sized for.
    """
    model.train()
    optimizer.zero_grad()
    num_batches = len(loader)
    total_loss = 0.0

    for idx, batch in enumerate(tqdm(loader, desc=f"Train Epoch {epoch}")):
        tensors = _move_batch_to_device(batch)
        output_dict, loss = _forward_and_loss(model, criterion, tensors)

        # Average over the batches actually in this accumulation window, so
        # a shorter final window is not under-scaled.
        window_start = (idx // GRAD_ACCUM_STEPS) * GRAD_ACCUM_STEPS
        window_size = min(GRAD_ACCUM_STEPS, num_batches - window_start)
        (loss / window_size).backward()

        if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item()

        if idx == 0:
            _log_first_sample(
                writer, "Train/Sample_0", epoch, output_dict, tensors, hop_sec
            )

    return total_loss / num_batches


def _validate(model, loader, criterion, writer, epoch, hop_sec):
    """Evaluate on ``loader``.

    Returns:
        ``(avg_loss, last_batch)``. ``last_batch`` is the last raw (CPU)
        batch seen, or ``None`` if the loader yielded nothing.
    """
    model.eval()
    total_loss = 0.0
    last_batch = None

    with torch.no_grad():
        for val_idx, batch in enumerate(tqdm(loader, desc=f"Val Epoch {epoch}")):
            tensors = _move_batch_to_device(batch)
            output_dict, val_loss = _forward_and_loss(model, criterion, tensors)
            total_loss += val_loss.item()

            if val_idx == 0:
                _log_first_sample(
                    writer, "Eval/Sample_0", epoch, output_dict, tensors, hop_sec
                )
            last_batch = batch

    return total_loss / len(loader), last_batch


def _upload_to_hub(model, model_id):
    """Best-effort upload of the model and the TensorBoard ``runs`` folder.

    Failures are reported and swallowed deliberately, so a finished training
    run is never lost to a Hub/network error.
    """
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")


def main():
    """Preprocess the data, train TaikoConformer6, then upload the best model.

    Steps:
    - Preprocess the module-level ``ds`` for the oni, hard and normal
      difficulties and concatenate them. This rebinds the global ``ds``.
    - Split 90/10 with seed 42.
    - Train with AdamW and OneCycleLR, using gradient accumulation and
      gradient-norm clipping at 1.0.
    - Early-stop on validation loss with patience 10, saving the best
      weights to ``best_model.pt``.
    - Restore those best weights and push them, plus the ``runs`` logs, to
      the Hub.
    """
    global ds

    # Assumes the model's output time axis matches the mel spectrogram's frames.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10
    pat_count = 0
    saved_best = False  # whether best_model.pt was written during this run

    ds = _preprocess_all_difficulties(ds)
    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    # Optional: ds_train_test.push_to_hub("JacobLinCool/taiko-conformer-6-ds")
    train_loader = _make_loader(ds_train_test["train"], shuffle=True)
    val_loader = _make_loader(ds_train_test["test"], shuffle=False)

    model = TaikoConformer6().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # The scheduler advances once per optimizer step, not once per batch.
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=EPOCHS * num_optimizer_steps_per_epoch
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        avg_train_loss = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            criterion,
            writer,
            epoch,
            output_frame_hop_sec,
        )
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        avg_val_loss, last_val_batch = _validate(
            model, val_loader, criterion, writer, epoch, output_frame_hop_sec
        )
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Reference value: mean ground-truth NPS of the last validation batch.
        if last_val_batch is not None:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", last_val_batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            saved_best = True
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break
    writer.close()

    # Upload the best checkpoint, not the last-epoch weights. After early
    # stopping, the last-epoch weights are `patience` epochs past the best.
    if saved_best:
        model.load_state_dict(torch.load("best_model.pt", map_location=DEVICE))
        print("Restored best_model.pt weights before upload")

    _upload_to_hub(model, "JacobLinCool/taiko-conformer-6")


if __name__ == "__main__":
    # Entry point. The relative imports at the top (e.g. `from .config`) mean
    # this file must be run as a module (`python -m <package>.<module>`).
    main()