# dataset_utils.py

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
import pickle
import os

from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS

class ComplianceDataset(Dataset):
    """
    Custom Dataset class for handling text and multi-output labels for PyTorch models.
    """
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the given index.
        Tokenizes the text and converts labels to a PyTorch tensor.
        """
        text = str(self.texts[idx])
        # Tokenize the text, padding to max_length and truncating if longer.
        # return_tensors="pt" ensures PyTorch tensors are returned.
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        # Squeeze removes the batch dimension (which is 1 here because we process one sample at a time)
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        # Convert labels to a PyTorch long tensor
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs, labels
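
# Minimal usage sketch (illustrative: `train_texts` and `train_labels` are
# placeholders, not defined in this module):
#
#   tokenizer = get_tokenizer("bert-base-uncased")
#   dataset = ComplianceDataset(train_texts, train_labels, tokenizer, MAX_LEN)
#   loader = DataLoader(dataset, batch_size=16, shuffle=True)
#   inputs, labels = next(iter(loader))
#   # inputs["input_ids"].shape -> (16, MAX_LEN); labels.shape -> (16, len(LABEL_COLUMNS))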

class ComplianceDatasetWithMetadata(Dataset):
    """
    Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
    Used for hybrid models combining text and tabular features.
    """
    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts = texts
        self.metadata = metadata # Expects metadata as a NumPy array or list of lists
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample, its metadata, and labels from the dataset at the given index.
        Tokenizes text, converts metadata and labels to PyTorch tensors.
        """
        text = str(self.texts[idx])
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        # Convert metadata for the current sample to a float tensor
        metadata = torch.tensor(self.metadata[idx], dtype=torch.float)
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs, metadata, labels
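
# Usage sketch for the hybrid dataset (illustrative: `train_meta` is assumed to
# be a numeric array of shape (num_samples, num_metadata_features)):
#
#   dataset = ComplianceDatasetWithMetadata(train_texts, train_meta, train_labels, tokenizer, MAX_LEN)
#   inputs, metadata, labels = dataset[0]
#   # metadata is a 1-D float tensor of length num_metadata_features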

def load_and_preprocess_data(data_path):
    """
    Loads data from a CSV, fills missing values, and encodes categorical labels.
    Also handles converting specified METADATA_COLUMNS to numeric.

    Args:
        data_path (str): Path to the CSV data file.

    Returns:
        tuple: A tuple containing:
            - data (pd.DataFrame): The preprocessed DataFrame.
            - label_encoders (dict): A dictionary of LabelEncoder objects for each label column.
    """
    data = pd.read_csv(data_path)
    data.fillna("Unknown", inplace=True) # Fill any missing text values with "Unknown"

    # Convert metadata columns to numeric, coercing errors and filling NaNs with 0
    # This ensures metadata is suitable for neural networks.
    for col in METADATA_COLUMNS:
        if col in data.columns:
            data[col] = pd.to_numeric(data[col], errors='coerce').fillna(0) # Non-numeric entries (including the "Unknown" fill) become NaN, then 0

    label_encoders = {col: LabelEncoder() for col in LABEL_COLUMNS}
    for col in LABEL_COLUMNS:
        # Fit and transform each label column using its respective LabelEncoder
        data[col] = label_encoders[col].fit_transform(data[col])
    return data, label_encoders
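
# Typical call (sketch; "train.csv" is a placeholder path):
#
#   data, label_encoders = load_and_preprocess_data("train.csv")
#   texts = data[TEXT_COLUMN].tolist()
#   labels = data[LABEL_COLUMNS].values        # shape: (num_samples, len(LABEL_COLUMNS))
#   metadata = data[METADATA_COLUMNS].values   # numeric, NaNs already filled with 0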

def get_tokenizer(model_name):
    """
    Returns the appropriate Hugging Face tokenizer based on the model name.

    Args:
        model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').

    Returns:
        transformers.PreTrainedTokenizer: The initialized tokenizer.
    """
    if "roberta" in model_name.lower():
        return RobertaTokenizer.from_pretrained(model_name)
    elif "bert" in model_name.lower():
        return BertTokenizer.from_pretrained(model_name)
    elif "deberta" in model_name.lower():
        return DebertaTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported tokenizer for model: {model_name}")

def save_label_encoders(label_encoders):
    """
    Saves a dictionary of label encoders to a pickle file.
    This is crucial for decoding predictions back to original labels.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.
    """
    with open(LABEL_ENCODERS_PATH, "wb") as f:
        pickle.dump(label_encoders, f)
    print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")

def load_label_encoders():
    """
    Loads a dictionary of label encoders from a pickle file.

    Returns:
        dict: Loaded dictionary of LabelEncoder objects.
    """
    with open(LABEL_ENCODERS_PATH, "rb") as f:
        return pickle.load(f)
    print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")


def get_num_labels(label_encoders):
    """
    Returns a list containing the number of unique classes for each label column.
    This list is used to define the output dimensions of the model's classification heads.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.

    Returns:
        list: A list of integers, where each integer is the number of classes for a label.
    """
    return [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
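
# Sketch: the returned list typically sizes one classification head per label
# column (`MultiOutputModel` is a hypothetical consumer, not defined here):
#
#   num_labels = get_num_labels(label_encoders)   # e.g. [3, 5, 2]
#   model = MultiOutputModel(num_labels=num_labels)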