# Listing metadata: file size 13,892 bytes, revision b81f538.
# train_utils.py
import torch
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import joblib
from config import DEVICE, LABEL_COLUMNS, NUM_EPOCHS, LEARNING_RATE, MODEL_SAVE_DIR
def get_class_weights(data_df, field, label_encoder):
    """
    Compute balanced class weights for one target field.

    The weights are meant to be passed to a weighted loss to mitigate class
    imbalance. A class that occurs ``count`` times in ``data_df[field]`` gets
    ``n_samples / (n_seen_classes * count)``, where ``n_seen_classes`` is the
    number of encoder classes that occur at least once. Encoder classes that
    never occur keep a neutral weight of 1.

    Args:
        data_df (pd.DataFrame): DataFrame holding the original (unencoded) labels.
        field (str): Name of the label column to compute weights for.
        label_encoder: Fitted encoder exposing ``classes_`` and ``transform``
            (e.g. ``sklearn.preprocessing.LabelEncoder``).

    Returns:
        torch.Tensor: Float tensor with one weight per encoder class; index
            ``i`` holds the weight of encoded label ``i``.
    """
    y = data_df[field].values

    # `transform` raises ValueError for labels the encoder was not fitted on,
    # so fall back to the subset of labels it knows about.
    try:
        y_encoded = label_encoder.transform(y)
    except ValueError as e:
        print(f"Warning: {e}")
        print("Using only seen labels for class weights calculation")
        seen_labels = set(label_encoder.classes_)
        y_filtered = [label for label in y if label in seen_labels]
        y_encoded = label_encoder.transform(y_filtered)

    # An empty transform result may come back as float; bincount needs ints.
    y_encoded = np.asarray(y_encoded).astype(int)

    n_classes = len(label_encoder.classes_)
    # Single pass over the data instead of one full scan per class.
    # The slice keeps exactly one count per encoder class.
    class_counts = np.bincount(y_encoded, minlength=n_classes)[:n_classes]

    total_samples = len(y_encoded)
    class_weights = np.ones(n_classes)  # Default weight of 1 for classes with no samples
    seen_classes = class_counts > 0
    if np.any(seen_classes):
        class_weights[seen_classes] = total_samples / (
            np.sum(seen_classes) * class_counts[seen_classes]
        )
    return torch.tensor(class_weights, dtype=torch.float)
def initialize_criterions(data_df, label_encoders):
    """
    Build one class-weighted CrossEntropyLoss per label column.

    Args:
        data_df (pd.DataFrame): The original (unencoded) DataFrame; passed to
            `get_class_weights` to derive the per-class weights.
        label_encoders (dict): Maps each name in `LABEL_COLUMNS` to its encoder.

    Returns:
        dict: Maps each label column name to a `torch.nn.CrossEntropyLoss`
            whose weight tensor has been moved to `DEVICE`.
    """
    def _weighted_loss(column):
        # Weights are computed per column, then moved to the target device.
        class_weights = get_class_weights(data_df, column, label_encoders[column])
        return torch.nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

    return {column: _weighted_loss(column) for column in LABEL_COLUMNS}
def train_model(model, loader, optimizer, field_criterions, epoch):
    """
    Run one training epoch over `loader`.

    Each batch is either `(inputs, labels)` or `(inputs, metadata, labels)`,
    where `inputs` provides `'input_ids'` and `'attention_mask'`. The model is
    expected to return one logits tensor per label column; the batch loss is
    the sum of the per-column losses.

    Args:
        model (torch.nn.Module): The model to train.
        loader (torch.utils.data.DataLoader): DataLoader for training data.
        optimizer (torch.optim.Optimizer): Optimizer for model parameters.
        field_criterions (dict): Loss function for each name in `LABEL_COLUMNS`.
        epoch (int): Zero-based epoch index (shown as `epoch + 1` in the progress bar).

    Returns:
        float: Sum of the batch losses divided by `len(loader)`.

    Raises:
        ValueError: If a batch does not contain 2 or 3 items.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc=f"Epoch {epoch + 1} Training")

    for batch in progress:
        item_count = len(batch)
        if item_count not in (2, 3):
            raise ValueError("Unsupported batch format. Expected 2 or 3 items in batch.")

        has_metadata = item_count == 3
        if has_metadata:
            encodings, side_features, targets = batch
        else:
            encodings, targets = batch

        token_ids = encodings['input_ids'].to(DEVICE)
        mask = encodings['attention_mask'].to(DEVICE)
        if has_metadata:
            # Hybrid (text + metadata) forward pass.
            side_features = side_features.to(DEVICE)
            targets = targets.to(DEVICE)
            head_logits = model(token_ids, mask, side_features)
        else:
            # Text-only forward pass.
            targets = targets.to(DEVICE)
            head_logits = model(token_ids, mask)

        # Head `idx` is scored against column `idx` of the targets using the
        # criterion registered for LABEL_COLUMNS[idx].
        loss = sum(
            field_criterions[LABEL_COLUMNS[idx]](logits, targets[:, idx])
            for idx, logits in enumerate(head_logits)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

    return running_loss / len(loader)
def evaluate_model(model, loader):
    """
    Evaluate the model on a validation/test loader.

    Batches follow the same layout as in training: `(inputs, labels)` or
    `(inputs, metadata, labels)`. For every output head the argmax over
    dim 1 of the logits is taken as the prediction.

    Args:
        model (torch.nn.Module): The model to evaluate.
        loader (torch.utils.data.DataLoader): DataLoader for evaluation data.

    Returns:
        tuple: `(reports, truths, predictions)` where
            - reports (dict): Per-column `classification_report` output in dict
              form, or a zero-filled placeholder if report generation raised
              `ValueError`.
            - truths (list): One list of true labels per label column.
            - predictions (list): One list of predicted labels per label column.

    Raises:
        ValueError: If a batch does not contain 2 or 3 items.
    """
    model.eval()

    head_count = len(LABEL_COLUMNS)
    predictions = [[] for _ in range(head_count)]
    truths = [[] for _ in range(head_count)]

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluation"):
            item_count = len(batch)
            if item_count not in (2, 3):
                raise ValueError("Unsupported batch format.")

            has_metadata = item_count == 3
            if has_metadata:
                encodings, side_features, targets = batch
            else:
                encodings, targets = batch

            token_ids = encodings['input_ids'].to(DEVICE)
            mask = encodings['attention_mask'].to(DEVICE)
            if has_metadata:
                side_features = side_features.to(DEVICE)
                targets = targets.to(DEVICE)
                head_logits = model(token_ids, mask, side_features)
            else:
                targets = targets.to(DEVICE)
                head_logits = model(token_ids, mask)

            for head_idx, logits in enumerate(head_logits):
                predicted = torch.argmax(logits, dim=1).cpu().numpy()
                predictions[head_idx].extend(predicted)
                truths[head_idx].extend(targets[:, head_idx].cpu().numpy())

    reports = {}
    for col, col_truths, col_preds in zip(LABEL_COLUMNS, truths, predictions):
        try:
            # zero_division=0 keeps classes with no true/predicted samples from warning.
            reports[col] = classification_report(
                col_truths, col_preds, output_dict=True, zero_division=0
            )
        except ValueError:
            # Fall back to a zero-filled report so downstream summaries still work.
            print(f"Warning: Could not generate classification report for {col}. Skipping.")
            reports[col] = {
                'accuracy': 0,
                'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0},
            }
    return reports, truths, predictions
def summarize_metrics(metrics):
    """
    Summarize classification reports into a readable Pandas DataFrame.

    Args:
        metrics (dict): Maps each field name to a classification report in
            dict form, as returned by `evaluate_model`.

    Returns:
        pd.DataFrame: One row per field with the columns Field, Precision,
            Recall, F1-Score, Accuracy and Support. Precision, recall,
            f1-score and support come from the report's 'weighted avg' entry;
            any missing value defaults to 0. The columns are present even
            when `metrics` is empty.
    """
    columns = ["Field", "Precision", "Recall", "F1-Score", "Accuracy", "Support"]

    summary = []
    for field, report in metrics.items():
        # A report without a 'weighted avg' entry (e.g. an empty report)
        # contributes zeros instead of raising.
        weighted = report.get('weighted avg', {})
        summary.append({
            "Field": field,
            "Precision": weighted.get('precision', 0),
            "Recall": weighted.get('recall', 0),
            "F1-Score": weighted.get('f1-score', 0),
            "Accuracy": report.get('accuracy', 0),  # Accuracy is a top-level key
            "Support": weighted.get('support', 0),
        })

    # Passing `columns` keeps the schema stable when there are no rows.
    return pd.DataFrame(summary, columns=columns)
def save_model(model, model_name, save_format='pth'):
    """
    Save a model under `MODEL_SAVE_DIR`.

    With `save_format='pth'` only the model's `state_dict()` is written (to
    `<model_name>_model.pth`); with `save_format='pickle'` the whole object is
    dumped with joblib (to `<model_name>.pkl`). The target directory is
    created if it does not exist yet.

    Args:
        model: The trained model (a `torch.nn.Module` for 'pth').
        model_name (str): A descriptive name for the model (used for the filename).
        save_format (str): 'pth' for PyTorch models, 'pickle' for traditional ML models.

    Raises:
        ValueError: If `save_format` is neither 'pth' nor 'pickle'.
    """
    # Resolve the path (and reject bad formats) before touching the filesystem.
    if save_format == 'pth':
        model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")
    elif save_format == 'pickle':
        model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
    else:
        raise ValueError(f"Unsupported save format: {save_format}")

    # Without this, the first save into a fresh checkout fails because the
    # output directory does not exist.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if save_format == 'pth':
        torch.save(model.state_dict(), model_path)
    else:
        joblib.dump(model, model_path)
    print(f"Model saved to {model_path}")
def load_model_state(model, model_name, model_class, num_labels, metadata_dim=0):
    """
    Load saved weights into `model`, or fall back to a fresh instance.

    Args:
        model (torch.nn.Module): An initialized model instance (architecture).
        model_name (str): Name used to locate `<model_name>_model.pth` in `MODEL_SAVE_DIR`.
        model_class (class): Class used to build a replacement model when no
            checkpoint file exists.
        num_labels (list): Number of classes for each label; passed to `model_class`.
        metadata_dim (int): Dimensionality of metadata features; forwarded to
            `model_class` only when greater than 0.

    Returns:
        torch.nn.Module: If the checkpoint exists, `model` with the state dict
            loaded, moved to `DEVICE` and switched to eval mode. Otherwise a
            newly constructed `model_class` instance moved to `DEVICE`.
    """
    model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}_model.pth")

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        print(f"Model loaded from {model_path}")
        return model

    print(f"Warning: Model file not found at {model_path}. Returning a newly initialized model instance.")
    # NOTE(review): the passed-in `model` is ignored on this path and eval()
    # is not called on the replacement - confirm callers expect that.
    extra_kwargs = {"metadata_dim": metadata_dim} if metadata_dim > 0 else {}
    return model_class(num_labels, **extra_kwargs).to(DEVICE)
def predict_probabilities(model, loader):
    """
    Compute softmax probabilities for every label column.

    Used for confidence scoring and as input to a voting ensemble. Batches
    are `(inputs, labels)` or `(inputs, metadata, labels)`; the labels are
    ignored.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        loader (torch.utils.data.DataLoader): DataLoader for the data to predict on.

    Returns:
        list: One inner list per label column, holding one numpy array of
            softmax probabilities (over dim 1 of the logits) per sample.

    Raises:
        ValueError: If a batch does not contain 2 or 3 items.
    """
    model.eval()
    all_probabilities = [[] for _ in LABEL_COLUMNS]

    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting Probabilities"):
            item_count = len(batch)
            if item_count not in (2, 3):
                raise ValueError("Unsupported batch format.")

            has_metadata = item_count == 3
            if has_metadata:
                encodings, side_features, _ = batch
            else:
                encodings, _ = batch

            token_ids = encodings['input_ids'].to(DEVICE)
            mask = encodings['attention_mask'].to(DEVICE)
            if has_metadata:
                head_logits = model(token_ids, mask, side_features.to(DEVICE))
            else:
                head_logits = model(token_ids, mask)

            for head_idx, logits in enumerate(head_logits):
                # Extending with a 2-D array appends one probability row per sample.
                head_probs = torch.softmax(logits, dim=1).cpu().numpy()
                all_probabilities[head_idx].extend(head_probs)

    return all_probabilities