# models/bert_model.py
import torch
import torch.nn as nn
from transformers import BertModel
from config import DROPOUT_RATE, BERT_MODEL_NAME # Import BERT_MODEL_NAME from config


class BertMultiOutputModel(nn.Module):
    """
    BERT-based model for multi-output classification.

    It uses a pre-trained BERT model as its backbone and adds a dropout layer
    followed by separate linear classification heads for each target label.
    """

    # Statically set tokenizer name for easy access in main.py
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """
        Initializes the BertMultiOutputModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                for a corresponding label column.
        """
        super().__init__()
        # Load the pre-trained BERT model.
        # BertModel provides contextual embeddings and a pooled output for classification.
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization
        # Create a list of classification heads, one for each label column.
        # Each head is a linear layer mapping BERT's pooled output size to the
        # number of classes for that label.
        self.classifiers = nn.ModuleList([
            nn.Linear(self.bert.config.hidden_size, n_classes) for n_classes in num_labels
        ])
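        # Illustrative note: with num_labels=[3, 5] and a bert-base backbone
        # (hidden_size=768), this builds Linear(768, 3) and Linear(768, 5),
        # both fed by the same pooled representation.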

    def forward(self, input_ids, attention_mask):
        """
        Performs the forward pass of the model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs (from the tokenizer).
            attention_mask (torch.Tensor): Tensor indicating which tokens to attend
                to (from the tokenizer).

        Returns:
            list: A list of logit tensors, one for each classification head.
                Each tensor has shape (batch_size, num_classes_for_that_label).
        """
        # Pass input_ids and attention_mask through BERT.
        # .pooler_output is the hidden state of the [CLS] token, passed through a
        # linear layer and tanh activation; it is commonly used for classification.
        pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
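        # Note: pooled_output has shape (batch_size, hidden_size),
        # e.g. (batch_size, 768) for a bert-base backbone.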
        # Apply dropout for regularization.
        pooled_output = self.dropout(pooled_output)
        # Pass the pooled output through each classification head. The result is
        # a list of logits (raw scores before softmax/sigmoid), one per label.
        return [classifier(pooled_output) for classifier in self.classifiers]
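

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the model to a matching tokenizer. The label
# cardinalities [3, 5] below are hypothetical; substitute the class counts of
# your own label columns. Assumes config provides a valid BERT_MODEL_NAME.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(BertMultiOutputModel.tokenizer_name)
    model = BertMultiOutputModel(num_labels=[3, 5])  # hypothetical label sizes
    model.eval()  # disable dropout for a deterministic forward pass

    encoding = tokenizer(
        ["an example sentence"],
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(encoding["input_ids"], encoding["attention_mask"])

    # One logits tensor per head: shapes (1, 3) and (1, 5) here.
    print([tuple(t.shape) for t in logits])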