import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import v2
from PIL import Image
import pandas as pd
from tqdm import tqdm
# DEVICE SETUP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\nπ Using device:", device)
# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# ----- HELPER FUNCTIONS -----
def get_bert_embedding(text):
    inputs = tokenizer.encode_plus(
        text, add_special_tokens=True,
        return_tensors='pt', max_length=80,
        truncation=True, padding='max_length'
    )
    return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)
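# Both returned tensors have shape [80]: input_ids holds token ids, and
# attention_mask is 1 over real tokens and 0 over padding (padding='max_length'
# above pads or truncates every title to exactly 80 tokens).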
# ----- DATASET CLASS -----
class FakedditDataset(Dataset):
    def __init__(self, df, text_field="clean_title", label_field="binary_label", image_id="id"):
        self.df = df.reset_index(drop=True)
        self.text_field = text_field
        self.label_field = label_field
        self.image_id = image_id
        self.transform = v2.Compose([
            v2.Resize((256, 256)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.at[idx, self.text_field]
        label = self.df.at[idx, self.label_field]
        image_path = f"./val_images/{self.df.at[idx, self.image_id]}.jpg"
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        input_ids, attention_mask = get_bert_embedding(str(text))
        return image, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)
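# Expected data layout: val_output.csv provides "id", "clean_title", and
# "binary_label" columns (configurable via the constructor arguments above),
# and each image lives at ./val_images/<id>.jpg. A row whose image file is
# missing raises FileNotFoundError at Image.open.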
# ----- MODEL CLASSES -----
class SelfAttentionFusion(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.attn = nn.Linear(embed_dim * 2, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x_text, x_img):
        stacked = torch.stack([x_text, x_img], dim=1)
        attn_weights = self.softmax(self.attn(torch.cat([x_text, x_img], dim=1))).unsqueeze(2)
        fused = (attn_weights * stacked).sum(dim=1)
        return fused
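# The fusion computes two scalar weights from the concatenated features
# ([B, 2*D] -> [B, 2], softmaxed so they sum to 1), then takes the weighted
# sum of the stacked text/image features ([B, 2, D] -> [B, D]). Quick shape
# check:
#
#   fusion = SelfAttentionFusion(512)
#   out = fusion(torch.randn(4, 512), torch.randn(4, 512))  # -> [4, 512]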
class BERTResNetClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.fc_image = nn.Linear(1000, 512)
        self.drop_img = nn.Dropout(0.3)
        self.text_model = BertModel.from_pretrained("bert-base-uncased")
        self.fc_text = nn.Linear(self.text_model.config.hidden_size, 512)
        self.drop_text = nn.Dropout(0.3)
        self.fusion = SelfAttentionFusion(512)
        self.fc_final = nn.Linear(512, num_classes)

    def forward(self, image, input_ids, attention_mask):
        x_img = self.image_model(image)
        x_img = self.drop_img(x_img)
        x_img = self.fc_image(x_img)
        x_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)[0][:, 0, :]
        x_text = self.drop_text(x_text)
        x_text = self.fc_text(x_text)
        x_fused = self.fusion(x_text, x_img)
        return self.fc_final(x_fused)
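# Shape walkthrough for a batch of size B:
#   image [B, 3, 256, 256] -> resnet50 class logits [B, 1000] -> fc_image -> [B, 512]
#   input_ids [B, 80] -> BERT last_hidden_state [B, 80, 768] -> [CLS] slice [B, 768]
#     -> fc_text -> [B, 512]
#   fusion -> [B, 512] -> fc_final -> [B, num_classes] raw logits (no softmax;
#   argmax over logits gives the same predicted class).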
# ----- LOAD DATA -----
df = pd.read_csv("./val_output.csv")
print("π Loaded validation CSV with", len(df), "samples")
val_dataset = FakedditDataset(df)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
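# shuffle=False keeps predictions aligned with the CSV row order; num_workers
# is left at its default (0) here, but could be raised to parallelize image I/O.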
# ----- LOAD MODEL STATE -----
def remove_module_prefix(state_dict):
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v
    return new_state_dict
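# The 'module.' prefix appears on every key when a model is saved while
# wrapped in nn.DataParallel (or DistributedDataParallel); stripping it lets
# the weights load into the bare BERTResNetClassifier below.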
print("π¦ Loading model weights...")
state_dict = torch.load("state_dict.pth", map_location=device)
clean_state_dict = remove_module_prefix(state_dict)
model = BERTResNetClassifier(num_classes=2)
model.load_state_dict(clean_state_dict)
model.to(device)
model.eval()
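# eval() switches the Dropout layers (and the ResNet's batch-norm statistics)
# to inference behavior; combined with torch.no_grad() below, this avoids
# stochastic predictions and gradient bookkeeping during evaluation.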
print("β
Model loaded and ready for evaluation")
# ----- EVALUATION -----
correct = 0
total = 0
print("\nπ Starting evaluation...")
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        images, input_ids, attention_mask, labels = batch
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
accuracy = correct / total * 100
print(f"\nπ― Final Validation Accuracy: {accuracy:.2f}%") |