# dataset_utils.py

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
import pickle
import os

from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS

class ComplianceDataset(Dataset):
    """
    Map-style PyTorch dataset that pairs raw texts with multi-output label rows.

    Tokenization happens on every item access; nothing is pre-encoded or cached.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        # Both containers are only ever indexed with ``[idx]``; ``texts`` also
        # provides the dataset length.
        self.texts, self.labels = texts, labels
        # Callable used to encode one text; invoked with Hugging Face-style
        # keyword arguments (padding / truncation / max_length / return_tensors).
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Number of samples, i.e. ``len(texts)``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Build the model inputs and the label tensor for sample ``idx``.

        Returns:
            tuple: ``(features, target)`` where ``features`` maps each tokenizer
            output name to its tensor with the leading batch axis removed, and
            ``target`` is a ``torch.long`` tensor built from ``labels[idx]``.
        """
        # The sample is coerced to ``str`` so non-string entries can still be tokenized.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        # One text is encoded per call, so each returned tensor carries a
        # leading batch axis of size 1; drop it.
        features = {}
        for name, tensor in encoded.items():
            features[name] = tensor.squeeze(0)

        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, target

class ComplianceDatasetWithMetadata(Dataset):
    """
    Map-style PyTorch dataset yielding tokenized text, numeric metadata and labels.

    Intended for hybrid models that combine a text encoder with tabular features.
    """

    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        # Per-sample numeric features; expected to be a NumPy array or a list of lists.
        self.metadata = metadata
        # Callable used to encode one text; invoked with Hugging Face-style
        # keyword arguments (padding / truncation / max_length / return_tensors).
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Number of samples, i.e. ``len(texts)``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Build the model inputs, metadata tensor and label tensor for sample ``idx``.

        Returns:
            tuple: ``(features, meta_tensor, target)`` where ``features`` maps each
            tokenizer output name to its tensor with the leading batch axis
            removed, ``meta_tensor`` is a ``torch.float`` tensor built from
            ``metadata[idx]`` and ``target`` is a ``torch.long`` tensor built
            from ``labels[idx]``.
        """
        # The sample is coerced to ``str`` so non-string entries can still be tokenized.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        # One text is encoded per call, so each returned tensor carries a
        # leading batch axis of size 1; drop it.
        features = {}
        for name, tensor in encoded.items():
            features[name] = tensor.squeeze(0)

        meta_tensor = torch.tensor(self.metadata[idx], dtype=torch.float)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, meta_tensor, target

def load_and_preprocess_data(data_path):
    """
    Read a CSV file and prepare it for training.

    Processing steps, in order:
      1. Every missing value in the frame (all columns) is replaced by the
         string "Unknown".
      2. Each column named in METADATA_COLUMNS that exists in the frame is
         converted to numeric; values that cannot be parsed (including the
         "Unknown" placeholder from step 1) become 0.
      3. Each column named in LABEL_COLUMNS is replaced by integer codes from
         a freshly fitted LabelEncoder.

    Args:
        data_path (str): Path to the CSV data file.

    Returns:
        tuple: ``(data, label_encoders)`` where ``data`` is the preprocessed
        pd.DataFrame and ``label_encoders`` maps each label column name to the
        LabelEncoder fitted on it.
    """
    frame = pd.read_csv(data_path)
    frame.fillna("Unknown", inplace=True)

    # Metadata columns absent from the CSV are skipped silently.
    for column in METADATA_COLUMNS:
        if column not in frame.columns:
            continue
        as_numbers = pd.to_numeric(frame[column], errors='coerce')
        frame[column] = as_numbers.fillna(0)

    # One encoder per label column, kept so predictions can be decoded later.
    label_encoders = {}
    for column in LABEL_COLUMNS:
        encoder = LabelEncoder()
        frame[column] = encoder.fit_transform(frame[column])
        label_encoders[column] = encoder

    return frame, label_encoders

def get_tokenizer(model_name):
    """
    Returns the appropriate Hugging Face tokenizer based on the model name.

    The family is detected by case-insensitive substring match on
    ``model_name``. Because "roberta" and "deberta" both contain "bert", the
    specific families are tested before the generic "bert" check; otherwise
    DeBERTa models would be handed a BertTokenizer.

    Args:
        model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').

    Returns:
        transformers.PreTrainedTokenizer: The initialized tokenizer.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    name = model_name.lower()

    if "roberta" in name:
        return RobertaTokenizer.from_pretrained(model_name)
    # NOTE(review): DeBERTa v2/v3 checkpoints may need a different tokenizer
    # class than DebertaTokenizer -- confirm before using them here.
    if "deberta" in name:
        return DebertaTokenizer.from_pretrained(model_name)
    # Generic check last: it would also match the two families above.
    if "bert" in name:
        return BertTokenizer.from_pretrained(model_name)

    raise ValueError(f"Unsupported tokenizer for model: {model_name}")

def save_label_encoders(label_encoders):
    """
    Saves a dictionary of label encoders to a pickle file.
    This is crucial for decoding predictions back to original labels.

    The file is written to LABEL_ENCODERS_PATH. Its parent directory is
    created first when it does not exist, so saving works on a fresh checkout
    or a new output folder.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.
    """
    # dirname is "" for a bare filename; os.makedirs("") would raise, so skip it.
    parent_dir = os.path.dirname(LABEL_ENCODERS_PATH)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(LABEL_ENCODERS_PATH, "wb") as f:
        pickle.dump(label_encoders, f)
    print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")

def load_label_encoders():
    """
    Loads a dictionary of label encoders from a pickle file.

    The file is read from LABEL_ENCODERS_PATH and a confirmation message is
    printed once loading succeeds.

    Returns:
        dict: Loaded dictionary of LabelEncoder objects.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that this
    # project wrote itself.
    with open(LABEL_ENCODERS_PATH, "rb") as f:
        label_encoders = pickle.load(f)
    # Print before returning so the message is actually emitted.
    print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
    return label_encoders


def get_num_labels(label_encoders):
    """
    Count the classes known to each label column's encoder.

    The counts follow the order of LABEL_COLUMNS, not the order of the
    dictionary, so they line up with the model's classification heads.

    Args:
        label_encoders (dict): Dictionary of fitted LabelEncoder objects keyed
            by label column name.

    Returns:
        list: One integer per entry in LABEL_COLUMNS, giving the number of
        classes for that label.
    """
    class_counts = []
    for column in LABEL_COLUMNS:
        encoder = label_encoders[column]
        class_counts.append(len(encoder.classes_))
    return class_counts