|
|
""" |
|
|
Model architecture for Bengali meme classification using Multimodal Attention Fusion (MAF).
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel |
|
|
|
|
|
|
|
|
class MultiheadAttention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` used for cross-modal fusion.

    The wrapped layer is built without ``batch_first``, so inputs and the
    returned tensor use the ``(seq_len, batch, d_model)`` layout.  Only the
    attended output is returned; the attention-weight matrix is discarded.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        """Build the underlying attention layer.

        Args:
            d_model: Embedding size shared by query, key and value.
            nhead: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        # Attribute name is part of the state-dict key prefix; keep it stable.
        self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value`` and return the output only.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            mask: Optional mask forwarded as ``attn_mask``.

        Returns:
            The attention output (first element of the wrapped layer's result).
        """
        # Index 0 is the attended output; index 1 (the weights) is not needed.
        return self.attention(query, key, value, attn_mask=mask)[0]
|
|
|
|
|
|
|
|
class MAF(nn.Module):
    """Multimodal Attention Fusion (MAF) classifier for image + text inputs.

    An image embedding (from ``clip_model``) is projected to the text hidden
    size, repeated along the token axis, fused with the BERT token states
    through multi-head attention, and the concatenation of the three streams
    is mean-pooled over tokens and classified by a small MLP.
    """

    def __init__(self, clip_model, num_classes=4, num_heads=16):
        """Build the fusion model.

        Args:
            clip_model: Callable/module mapping ``image_input`` to a 2-D
                tensor whose last dimension is 512 (it feeds
                ``nn.Linear(512, 768)``).
            num_classes: Number of output logits.
            num_heads: Number of attention heads; must divide 768.
        """
        super().__init__()

        # Submodules are created in the original order so that parameter
        # registration (and random initialisation order) is unchanged.

        # Visual branch: image encoder plus projection into the text space.
        self.clip = clip_model
        self.visual_linear = nn.Linear(512, 768)

        # Text branch: pretrained Bangla BERT encoder.
        self.bert = AutoModel.from_pretrained("sagorsarker/bangla-bert-base")

        # Cross-modal attention over 768-dimensional features.
        self.attention = MultiheadAttention(d_model=768, nhead=num_heads)

        # Classifier over [attended | image | text] features.
        self.fc = nn.Sequential(
            nn.Linear(768 + 768 + 768, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, image_input, input_ids, attention_mask):
        """Compute class logits for a batch of (image, text) pairs.

        Args:
            image_input: Input passed unchanged to ``self.clip``.
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT.

        Returns:
            Tensor of logits with one row per batch element and
            ``num_classes`` columns.
        """
        # Image branch first, then text, matching the original call order.
        image_vector = self.visual_linear(self.clip(image_input))

        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_output = bert_outputs.last_hidden_state

        # Repeat the single image vector once per text token.  The previous
        # implementation pooled a length-1 axis up to a hard-coded 70, which
        # yields the same replicated values but only worked when the text was
        # exactly 70 tokens long; the length now follows the text batch.
        seq_len = bert_output.size(1)
        image_features = image_vector.unsqueeze(1).expand(-1, seq_len, -1)

        # The attention wrapper expects (seq, batch, dim).
        # NOTE(review): the values are the replicated image rows, so with
        # dropout inactive every attention-weighted average equals the same
        # projected image vector regardless of the text keys.  Using the text
        # states as values may have been intended -- confirm before changing,
        # since it would alter the behavior of trained checkpoints.
        # NOTE(review): attention_mask is only given to BERT; padded tokens
        # are not masked here nor excluded from the mean below.
        image_seq_first = image_features.permute(1, 0, 2)
        attention_output = self.attention(
            query=image_seq_first,
            key=bert_output.permute(1, 0, 2),
            value=image_seq_first,
            mask=None,
        )
        attention_output = attention_output.permute(1, 0, 2)

        # Concatenate along the feature axis, average over tokens, classify.
        fusion_input = torch.cat([attention_output, image_features, bert_output], dim=2)
        output = self.fc(fusion_input.mean(1))
        return output
|
|
|