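# NOTE(review): the six lines below look like web file-viewer residue (uploader,
# commit message, commit hash, UI links, file size) rather than program text.
# They are kept verbatim as comments so the module can be parsed.
# STron
# Added Roberta and Vit
# 7df2acb
# raw
# history blame
# 4.81 kB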
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import v2
from PIL import Image
import pandas as pd
from tqdm import tqdm
# ----- DEVICE SETUP -----
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The emoji in this literal had been mis-encoded (mojibake); restored to UTF-8.
print("\n🚀 Using device:", device)

# ----- TOKENIZER -----
# Module-level tokenizer shared by get_bert_embedding(); "bert-base-uncased"
# matches the checkpoint name used for the text encoder in the classifier.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# ----- HELPER FUNCTIONS -----
def get_bert_embedding(text, max_length=80):
    """Tokenize ``text`` into fixed-length BERT inputs.

    Despite the name, no embedding is computed here: the function only runs
    the module-level ``tokenizer`` and returns its ids and attention mask.

    Args:
        text: String to tokenize.
        max_length: Length every sequence is truncated or padded to. Defaults
            to 80, the value that used to be hard-coded in this function.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of tensors. The leading batch
        dimension produced by ``return_tensors='pt'`` is squeezed away.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the keyword arguments are unchanged.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )
    return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)
# ----- DATASET CLASS -----
class FakedditDataset(Dataset):
    """Dataset yielding ``(image, input_ids, attention_mask, label)`` per row.

    Each item pairs the text of one DataFrame row with the JPEG whose file
    name is that row's image id.

    Args:
        df: DataFrame holding the text, label and image-id columns. Its index
            is reset so rows can be addressed by position.
        text_field: Column containing the text to tokenize.
        label_field: Column containing the class label.
        image_id: Column containing the image file stem (without ``.jpg``).
        image_dir: Directory containing the ``<image id>.jpg`` files. Defaults
            to ``./val_images``, the location previously hard-coded here.
    """

    def __init__(self, df, text_field="clean_title", label_field="binary_label",
                 image_id="id", image_dir="./val_images"):
        self.df = df.reset_index(drop=True)
        self.text_field = text_field
        self.label_field = label_field
        self.image_id = image_id
        self.image_dir = image_dir
        # Resize to 256x256, convert to a float tensor scaled to [0, 1], then
        # normalize per channel with the given mean/std (the values commonly
        # used with ImageNet-pretrained backbones).
        self.transform = v2.Compose([
            v2.Resize((256, 256)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and tokenize the sample at row ``idx``.

        Returns:
            ``(image, input_ids, attention_mask, label)`` where ``label`` is a
            ``torch.long`` tensor.
        """
        text = self.df.at[idx, self.text_field]
        label = self.df.at[idx, self.label_field]
        image_path = f"{self.image_dir}/{self.df.at[idx, self.image_id]}.jpg"
        # Close the file handle deterministically; ``convert`` returns an
        # independent copy, so the image stays usable after the ``with`` exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        # ``str(...)`` guards against non-string cells (e.g. NaN) in the text column.
        input_ids, attention_mask = get_bert_embedding(str(text))
        return image, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)
# ----- MODEL CLASSES -----
class SelfAttentionFusion(nn.Module):
    """Blend a text vector and an image vector with learned per-sample weights.

    One linear layer reads the concatenated pair and emits two scores. After a
    softmax over those two scores, they weight the two inputs, which are then
    summed into a single vector of the same width.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Two logits per sample: one for the text input, one for the image input.
        self.attn = nn.Linear(2 * embed_dim, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x_text, x_img):
        """Return the softmax-weighted sum of ``x_text`` and ``x_img``."""
        pair = torch.cat([x_text, x_img], dim=1)
        scores = self.attn(pair)
        # Trailing axis added so the weights broadcast over the feature axis.
        weights = self.softmax(scores).unsqueeze(2)
        candidates = torch.stack([x_text, x_img], dim=1)
        return (weights * candidates).sum(dim=1)
class BERTResNetClassifier(nn.Module):
    """Two-branch classifier: ResNet-50 for the image, BERT for the text.

    Each branch is projected to 512 features, the two projections are merged
    by ``SelfAttentionFusion``, and a final linear layer produces
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Image branch: the ResNet's 1000-way output is projected down to 512.
        self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.fc_image = nn.Linear(1000, 512)
        self.drop_img = nn.Dropout(0.3)

        # Text branch: BERT hidden size is projected down to 512.
        self.text_model = BertModel.from_pretrained("bert-base-uncased")
        self.fc_text = nn.Linear(self.text_model.config.hidden_size, 512)
        self.drop_text = nn.Dropout(0.3)

        # Fusion of the two 512-wide vectors, then the classification head.
        self.fusion = SelfAttentionFusion(512)
        self.fc_final = nn.Linear(512, num_classes)

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of image/text pairs."""
        img_feats = self.fc_image(self.drop_img(self.image_model(image)))

        bert_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # First output of the model, position 0 of every sequence (presumably
        # the [CLS] token, since the tokenizer adds special tokens).
        first_token = bert_out[0][:, 0, :]
        txt_feats = self.fc_text(self.drop_text(first_token))

        return self.fc_final(self.fusion(txt_feats, img_feats))
# ----- LOAD DATA -----
# Read the validation table; FakedditDataset looks up its text, label and
# image-id columns in this frame.
df = pd.read_csv("./val_output.csv")
# The emoji in this literal had been mis-encoded (mojibake); restored to UTF-8.
print("📄 Loaded validation CSV with", len(df), "samples")
val_dataset = FakedditDataset(df)
# No shuffling: batches are drawn in row order.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
# ----- LOAD MODEL STATE -----
def remove_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with one leading ``module.`` stripped.

    Checkpoints saved from a wrapped model (e.g. ``nn.DataParallel``) carry a
    ``module.`` prefix on every key. Only a prefix at the very start of a key
    is removed, and only one level of it; a ``module.`` occurring later in a
    key (such as ``encoder.module.bias``) is left untouched.

    Args:
        state_dict: Mapping of parameter names to values. It is not modified.

    Returns:
        An ``OrderedDict`` with the same values and order, under cleaned keys.
    """
    from collections import OrderedDict

    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
# The emoji in the two literals below had been mis-encoded (mojibake); restored to UTF-8.
print("📦 Loading model weights...")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load("state_dict.pth", map_location=device, weights_only=True)
# Drop the "module." key prefix left by wrapped-model checkpoints.
clean_state_dict = remove_module_prefix(state_dict)
model = BERTResNetClassifier(num_classes=2)
model.load_state_dict(clean_state_dict)
model.to(device)
# eval() switches off dropout so predictions are deterministic.
model.eval()
print("✅ Model loaded and ready for evaluation")
# ----- EVALUATION -----
# Running counts of correctly classified samples and of all samples seen.
correct = 0
total = 0
# The emoji in this script's literals had been mis-encoded (mojibake); this one
# lost a byte, so the magnifying glass is the most plausible reconstruction.
print("\n🔍 Starting evaluation...")
# Gradients are not needed for evaluation; no_grad() avoids tracking them.
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        images, input_ids, attention_mask, labels = batch
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        # Predicted class is the index of the largest logit.
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
# Guard against an empty loader, which would otherwise raise ZeroDivisionError.
accuracy = (correct / total * 100) if total else 0.0
print(f"\n🎯 Final Validation Accuracy: {accuracy:.2f}%")