# bengali-political-maf/model_architecture.py
"""
Model architecture for Bengali Memes Classification using Multimodal Attention Fusion (MAF)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class MultiheadAttention(nn.Module):
    """Multi-head attention mechanism for cross-modal fusion"""

    def __init__(self, d_model, nhead, dropout=0.1):
        super(MultiheadAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def forward(self, query, key, value, mask=None):
        # nn.MultiheadAttention returns (output, attn_weights); the weights are discarded
        output, _ = self.attention(query, key, value, attn_mask=mask)
        return output
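

# A minimal shape sketch (assumption, not in the original file): with the default
# batch_first=False, nn.MultiheadAttention consumes (seq_len, batch, embed_dim)
# tensors, which is why MAF.forward permutes its inputs below.
#
#   attn = MultiheadAttention(d_model=768, nhead=16)
#   q = torch.randn(70, 2, 768)   # 70 visual steps, batch of 2
#   k = torch.randn(70, 2, 768)   # 70 text tokens
#   out = attn(q, k, q)           # -> (70, 2, 768)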


class MAF(nn.Module):
    """Multimodal Attention Fusion (MAF) Model"""

    def __init__(self, clip_model, num_classes=4, num_heads=16):
        super(MAF, self).__init__()
        # Visual feature extractor (CLIP); expected to map images to 512-dim embeddings
        self.clip = clip_model
        # Project CLIP's 512-dim visual features into BERT's 768-dim space
        self.visual_linear = nn.Linear(512, 768)
        # Textual feature extractor (Bangla BERT)
        self.bert = AutoModel.from_pretrained("sagorsarker/bangla-bert-base")
        # Multi-head attention for cross-modal fusion
        self.attention = MultiheadAttention(d_model=768, nhead=num_heads)
        # Classification head over the concatenated attention, visual, and text features
        self.fc = nn.Sequential(
            nn.Linear(768 + 768 + 768, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, image_input, input_ids, attention_mask):
        # Extract visual features using CLIP: (batch, 512) -> (batch, 768)
        image_features = self.clip(image_input)
        image_features = self.visual_linear(image_features)
        # Replicate the single visual token into a 70-step sequence:
        # (batch, 768) -> (batch, 1, 768) -> (batch, 70, 768). With an input
        # length of 1, adaptive average pooling simply repeats the vector.
        image_features = image_features.unsqueeze(1)
        image_features = F.adaptive_avg_pool1d(image_features.permute(0, 2, 1), 70).permute(0, 2, 1)
        # Extract BERT embeddings; input_ids must be padded/truncated to 70 tokens
        # so the text sequence aligns with the visual sequence below
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_output = bert_outputs.last_hidden_state
        # Cross-modal attention: visual queries/values attend over text keys.
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim), hence the permutes.
        attention_output = self.attention(
            query=image_features.permute(1, 0, 2),
            key=bert_output.permute(1, 0, 2),
            value=image_features.permute(1, 0, 2),
            mask=None,
        )
        attention_output = attention_output.permute(1, 0, 2)
        # Concatenate all three streams along the feature axis, mean-pool over the
        # sequence axis, and classify: (batch, 70, 2304) -> (batch, 2304) -> logits
        fusion_input = torch.cat([attention_output, image_features, bert_output], dim=2)
        output = self.fc(fusion_input.mean(1))
        return output
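

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. It assumes a CLIP-style
    # image encoder that returns 512-dim embeddings (here the vision tower of
    # openai/clip-vit-base-patch32, whose projection dim is 512) and a tokenizer
    # padded/truncated to 70 tokens so all fusion shapes align.
    from transformers import AutoTokenizer, CLIPVisionModelWithProjection

    class CLIPImageEncoder(nn.Module):
        """Wraps a CLIP vision tower so MAF can call it as self.clip(images)."""

        def __init__(self):
            super().__init__()
            self.model = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-base-patch32"
            )

        def forward(self, pixel_values):
            # image_embeds: (batch, 512)
            return self.model(pixel_values=pixel_values).image_embeds

    tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
    encoded = tokenizer(
        ["একটি উদাহরণ ক্যাপশন"],  # an example Bengali caption
        padding="max_length",
        truncation=True,
        max_length=70,
        return_tensors="pt",
    )

    model = MAF(clip_model=CLIPImageEncoder(), num_classes=4)
    model.eval()

    images = torch.randn(1, 3, 224, 224)  # dummy batch at CLIP ViT-B/32 resolution
    with torch.no_grad():
        logits = model(images, encoded["input_ids"], encoded["attention_mask"])
    print(logits.shape)  # torch.Size([1, 4])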