STron committed on
Commit
7df2acb
·
0 Parent(s):

Added Roberta and Vit

Files changed (8)
  1. .gitignore +4 -0
  2. app.py +354 -0
  3. get_data.py +157 -0
  4. readme.md +12 -0
  5. requirements.txt +0 -0
  6. test.py +106 -0
  7. train_model.ipynb +0 -0
  8. validate.py +138 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .gradio/
+ __pycache__/
+
+ .gitattributes
app.py ADDED
@@ -0,0 +1,354 @@
+ import gradio as gr
+ import onnxruntime as ort
+ from transformers import RobertaTokenizer, ViTImageProcessor
+ from PIL import Image
+ import numpy as np
+ import torch
+ import os
+ import time
+ import logging
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+
+ model_path = "./multimodal_model.onnx"
+ try:
+     if not os.path.exists(model_path):
+         raise FileNotFoundError(f"ONNX model not found at {model_path}")
+
+     logger.info(f"Loading ONNX model from {model_path}")
+     sess_options = ort.SessionOptions()
+     sess_options.log_severity_level = 0
+     ort_session = ort.InferenceSession(
+         model_path,
+         sess_options=sess_options,
+         providers=['CPUExecutionProvider']
+     )
+     logger.info("ONNX model loaded successfully")
+
+     input_names = [inp.name for inp in ort_session.get_inputs()]
+     input_shapes = {inp.name: inp.shape for inp in ort_session.get_inputs()}
+     output_names = [out.name for out in ort_session.get_outputs()]
+
+     logger.info(f"Model inputs: {input_names} with shapes {input_shapes}")
+     logger.info(f"Model outputs: {output_names}")
+
+ except Exception as e:
+     logger.error(f"Error loading ONNX model: {e}")
+     raise
+
+ labels = ["Real", "Real Text with fake image", "Fake"]
+
+ def softmax(x):
+     """Compute softmax values for each set of scores in x."""
+     e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
+     return e_x / e_x.sum(axis=1, keepdims=True)
+
+ def image_with_prediction(img, label, confidence):
+     """Return the original image with an overlay showing the prediction."""
+     from PIL import Image, ImageDraw, ImageFont
+
+     img_copy = img.copy()
+     draw = ImageDraw.Draw(img_copy)
+
+     width, height = img_copy.size
+
+     # Semi-transparent banner along the bottom edge
+     overlay = Image.new('RGBA', (width, 40), (0, 0, 0, 150))
+     img_copy.paste(overlay, (0, height - 40), overlay)
+
+     text = f"{label}: {confidence:.1%}"
+
+     try:
+         font = ImageFont.truetype("arial.ttf", 20)
+     except IOError:
+         font = ImageFont.load_default()
+
+     try:
+         text_width = draw.textlength(text, font=font)
+     except AttributeError:
+         # Older Pillow releases lack ImageDraw.textlength
+         text_width = font.getsize(text)[0] if hasattr(font, 'getsize') else 200
+
+     text_position = ((width - text_width) // 2, height - 35)
+     draw.text(text_position, text, fill=(255, 255, 255), font=font)
+
+     return img_copy
+
+ def predict_news(text, image):
+     if text is None or text.strip() == "":
+         return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please enter some text to analyze."
+
+     if image is None:
+         return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please upload an image to analyze."
+
+     try:
+         logger.info(f"Processing text: {text[:50]}...")
+         logger.info(f"Processing image size: {image.size}")
+
+         # Process text input
+         inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='np',
+                                        max_length=80, truncation=True, padding='max_length')
+
+         input_ids = inputs['input_ids']
+         attention_mask = inputs['attention_mask']
+
+         logger.info(f"Input IDs shape: {input_ids.shape}")
+         logger.info(f"Attention mask shape: {attention_mask.shape}")
+
+         # Process image input
+         image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
+         logger.info(f"Processed image shape: {image_processed.shape}")
+
+         # Match tensors to the ONNX graph's inputs by name
+         ort_inputs = {}
+         for input_meta in ort_session.get_inputs():
+             input_name = input_meta.name
+             if 'ids' in input_name.lower() or input_name == 'text_input_ids':
+                 ort_inputs[input_name] = input_ids
+             elif 'mask' in input_name.lower() or input_name == 'text_attention_mask':
+                 ort_inputs[input_name] = attention_mask
+             elif 'image' in input_name.lower() or input_name == 'image_input':
+                 ort_inputs[input_name] = image_processed
+
+         logger.info(f"ONNX input keys: {list(ort_inputs.keys())}")
+
+         # Run inference
+         start_time = time.time()
+         logger.info("Starting inference")
+         outputs = ort_session.run(None, ort_inputs)
+         inference_time = time.time() - start_time
+         logger.info(f"Inference completed in {inference_time:.3f}s")
+
+         # Process model outputs
+         logits = outputs[0]
+         logger.info(f"Raw output shape: {logits.shape}, values: {logits}")
+
+         probs = softmax(logits)[0]
+         logger.info(f"Probabilities: {probs}")
+
+         pred_idx = int(np.argmax(probs))
+         confidence = float(probs[pred_idx])
+
+         if pred_idx == 1:
+             color = "orange"
+             message = f"This content appears to be **REAL TEXT WITH FAKE IMAGE** with {confidence:.1%} confidence."
+         elif pred_idx == 2:
+             color = "red"
+             message = f"This content appears to be **FAKE** with {confidence:.1%} confidence."
+         else:
+             color = "green"
+             message = f"This content appears to be **REAL** with {confidence:.1%} confidence."
+
+         analysis = f"""
+         <div style='text-align: center; padding: 10px; background-color: {color}15; border-radius: 5px; margin-top: 10px;'>
+             <span style='font-size: 18px; color: {color}; font-weight: bold;'>{message}</span>
+             <p>Inference time: {inference_time:.3f} seconds</p>
+         </div>
+         """
+
+         result = {
+             labels[0]: float(probs[0]),
+             labels[1]: float(probs[1]),
+             labels[2]: float(probs[2])
+         }
+
+         interpretation = image_with_prediction(image, labels[pred_idx], confidence)
+
+         return result, interpretation, analysis
+
+     except Exception as e:
+         logger.error(f"Error during analysis: {str(e)}", exc_info=True)
+         return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, f"Error during analysis: {str(e)}"
+
+ examples = [
+     ["COVID-19 vaccine causes severe side effects in 80% of recipients", "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop"],
+     ["Scientists discover new species of deep-sea fish", "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop"],
+ ]
+
+ # Build Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # 📰 Fake News Detector (RoBERTa + ViT)
+
+         This multimodal AI system analyzes both text and images to detect potentially fake news content.
+         Upload an image and enter a news headline to see if the combination is likely to be real or fake news.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             text_input = gr.Textbox(
+                 label="News Headline / Text",
+                 placeholder="Enter the news headline or text here...",
+                 lines=3
+             )
+             image_input = gr.Image(type="pil", label="Associated Image")
+
+             analyze_btn = gr.Button("Analyze Content", variant="primary")
+
+         with gr.Column(scale=1):
+             label_output = gr.Label(label="Prediction Probabilities")
+             image_output = gr.Image(type="pil", label="Visual Analysis")
+             analysis_html = gr.HTML(label="Analysis")
+
+     gr.Examples(
+         examples=examples,
+         inputs=[text_input, image_input],
+         outputs=[label_output, image_output, analysis_html],
+         fn=predict_news,
+         cache_examples=True,
+     )
+
+     gr.Markdown(
+         """
+         ### How it works
+
+         This system combines:
+         - **RoBERTa**: Analyzes the textual content
+         - **ViT**: Processes the image data
+         - **Multimodal Fusion**: Combines both signals to make a prediction
+
+         The model was trained on the Fakeddit dataset containing real and fake news pairs with both text and images.
+         """
+     )
+
+     analyze_btn.click(
+         predict_news,
+         inputs=[text_input, image_input],
+         outputs=[label_output, image_output, analysis_html]
+     )
+
+ if __name__ == "__main__":
+     logger.info("Starting Gradio application")
+     demo.launch()
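A quick way to exercise the app outside the UI is to call predict_news directly. A minimal smoke test, assuming multimodal_model.onnx sits next to app.py (the solid-gray image is just a placeholder input; importing app builds the Blocks UI without launching it):

    # Hypothetical smoke test; requires ./multimodal_model.onnx in the repo root.
    from PIL import Image
    from app import predict_news

    placeholder = Image.new("RGB", (224, 224), color="gray")
    result, annotated, analysis = predict_news(
        "Scientists discover new species of deep-sea fish", placeholder
    )
    print(result)    # probability per label: Real / Real Text with fake image / Fake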
get_data.py ADDED
@@ -0,0 +1,157 @@
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import os
+ from urllib import request
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from tqdm import tqdm
+ from sklearn.utils import resample
+ from torchvision.transforms import v2
+ from PIL import Image
+
+ def load_and_prepare_data(file_path):
+     df = pd.read_csv(file_path, sep="\t")
+     df.drop(['2_way_label', '3_way_label', 'title'], axis=1, inplace=True)
+     # Collapse the 6-way Fakeddit labels into binary: 0 = real, 1 = any fake category
+     df['binary_label'] = df['6_way_label'].apply(lambda x: 0 if x == 0 else 1)
+     df.reset_index(drop=True, inplace=True)
+     return df
+
+ def balance_data(df, max_samples_per_class=35000):
+     df_with_image = df[df['hasImage'] == True]
+     df_class_0 = df_with_image[df_with_image['binary_label'] == 0]
+     df_class_1 = df_with_image[df_with_image['binary_label'] == 1]
+     target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)
+
+     df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=42)
+     df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=42)
+
+     df_balanced = pd.concat([df_sample_0, df_sample_1])
+     df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
+     df_balanced = df_balanced.replace(np.nan, '', regex=True)
+     df_balanced.fillna('', inplace=True)
+     # Also return the unused class-1 rows so the minority class can be topped up later
+     return df_balanced, df_class_1[~df_class_1['id'].isin(df_sample_1['id'])]
+
+ def ensure_directory(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+ def download_image(row, image_dir):
+     index, row = row
+     if row["hasImage"] and row["image_url"] not in ["", "nan"]:
+         image_url = row["image_url"]
+         path = os.path.join(image_dir, f"{row['id']}.jpg")
+         try:
+             with open(path, 'wb') as f:
+                 f.write(request.urlopen(image_url, timeout=5).read())
+         except Exception:
+             return index
+     return None
+
+ def download_images_fast(df, image_dir, max_workers=16):
+     failed_indices = []
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = [executor.submit(download_image, row, image_dir) for row in df.iterrows()]
+         for f in tqdm(as_completed(futures), total=len(futures), desc="Downloading images"):
+             result = f.result()
+             if result is not None:
+                 failed_indices.append(result)
+     df.drop(index=failed_indices, inplace=True)
+     df.reset_index(drop=True, inplace=True)
+     return df
+
+ def validate_image(row, image_dir):
+     index, row = row
+     image_path = os.path.join(image_dir, f"{row['id']}.jpg")
+     try:
+         with Image.open(image_path) as img:
+             img.verify()
+         return None
+     except Exception:
+         if os.path.exists(image_path):
+             os.remove(image_path)
+         return index
+
+ def validate_images_fast(df, image_dir, max_workers=16):
+     corrupted_indices = []
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = [executor.submit(validate_image, row, image_dir) for row in df.iterrows()]
+         for f in tqdm(as_completed(futures), total=len(futures), desc="Validating images"):
+             result = f.result()
+             if result is not None:
+                 corrupted_indices.append(result)
+     df.drop(index=corrupted_indices, inplace=True)
+     df.reset_index(drop=True, inplace=True)
+     return df, corrupted_indices
+
+ def resize_images(df, image_dir, size=(256, 256)):
+     resize_transform = v2.Resize(size)
+     failed_indices = []
+     for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
+         image_path = os.path.join(image_dir, f"{row['id']}.jpg")
+         try:
+             image = Image.open(image_path).convert("RGB")
+             resized_image = resize_transform(image)
+             resized_image.save(image_path)
+         except Exception as e:
+             print(f"Failed to resize {image_path}: {e}")
+             failed_indices.append(index)
+     # Drop failures after iterating instead of mutating the frame mid-loop
+     df.drop(index=failed_indices, inplace=True)
+     df.reset_index(drop=True, inplace=True)
+     return df
+
+ def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
+     needed = len(df_balanced[df_balanced['binary_label'] == 0]) - len(df_balanced[df_balanced['binary_label'] == 1])
+     collected = []
+     print(f"Need to add {needed} more class 1 samples...")
+     # Compare collected *rows* (not batches) against the shortfall
+     while sum(len(d) for d in collected) < needed and len(df_remaining_class_1) > 0:
+         batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
+         df_remaining_class_1 = df_remaining_class_1.drop(batch.index)
+
+         print(f"\n🌀 Downloading batch of {len(batch)} images...")
+         batch = download_images_fast(batch.copy(), image_dir)
+
+         print("🔎 Validating downloaded images...")
+         valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
+
+         print("🎨 Resizing valid images...")
+         valid_batch = resize_images(valid_batch, image_dir)
+
+         collected.append(valid_batch)
+
+     df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
+     # Guard against the pool running dry before the shortfall is covered
+     df_extra_class_1 = df_extra_class_1.sample(n=min(needed, len(df_extra_class_1)), random_state=42).reset_index(drop=True)
+
+     df_balanced_updated = pd.concat([df_balanced, df_extra_class_1], ignore_index=True)
+     df_balanced_updated = df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
+     return df_balanced_updated
+
+ def main(args):
+     ensure_directory(args.image_dir)
+
+     df = load_and_prepare_data(args.tsv_path)
+     df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
+     df_balanced.to_csv("./df.csv", index=False)
+
+     df_balanced = download_images_fast(df_balanced, args.image_dir)
+     print(f"✅ Finished downloading. Remaining rows: {len(df_balanced)}")
+     df_balanced.to_csv("./df_balanced.csv", index=False)
+
+     df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
+     df_balanced = resize_images(df_balanced, args.image_dir)
+     df_balanced.to_csv("./df_balanced_resized.csv", index=False)
+
+     df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
+     df_balanced_updated.to_csv(args.output_csv, index=False)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
+     parser.add_argument('--tsv_path', type=str, default="./multimodal_train.tsv", help='Path to the input TSV file')
+     parser.add_argument('--image_dir', type=str, default="./images", help='Directory to save images')
+     parser.add_argument('--output_csv', type=str, default="./final_output.csv", help='Path to save final balanced CSV')
+     parser.add_argument('--max_samples', type=int, default=35000, help='Maximum number of samples per class')
+     parser.add_argument('--skip_existing', action='store_true', help='Skip downloading if image already exists')
+     args = parser.parse_args()
+     main(args)
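For reference, a typical invocation of the pipeline above, using only the flags the parser defines (note that --skip_existing is parsed but not yet consulted by the download step):

    python get_data.py --tsv_path ./multimodal_train.tsv --image_dir ./images --output_csv ./final_output.csv --max_samples 35000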
readme.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Fake News Detection Demo
+ emoji: 📚
+ colorFrom: blue
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.29.1
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ short_description: Multimodal fake news classification on the Fakeddit dataset.
+ ---
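The block above is Hugging Face Spaces front matter. A minimal local-run recipe, assuming requirements.txt (the binary file below) pins the packages app.py imports:

    pip install -r requirements.txt
    python app.py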
requirements.txt ADDED
Binary file (934 Bytes).
 
test.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ from transformers import BertTokenizer, BertModel
+ import torch.nn as nn
+ from torchvision.models import resnet50, ResNet50_Weights
+ from PIL import Image
+ from torchvision.transforms import v2
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("\n🚀 Using device:", device)
+
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ def get_bert_embedding(text):
+     inputs = tokenizer.encode_plus(
+         text, add_special_tokens=True,
+         return_tensors='pt', max_length=80,
+         truncation=True, padding='max_length'
+     )
+     return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)
+
+ class SelfAttentionFusion(nn.Module):
+     def __init__(self, embed_dim):
+         super().__init__()
+         self.attn = nn.Linear(embed_dim * 2, 2)
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x_text, x_img):
+         # Learn one weight per modality, then take the weighted sum
+         stacked = torch.stack([x_text, x_img], dim=1)
+         attn_weights = self.softmax(self.attn(torch.cat([x_text, x_img], dim=1))).unsqueeze(2)
+         fused = (attn_weights * stacked).sum(dim=1)
+         return fused
+
+ class BERTResNetClassifier(nn.Module):
+     def __init__(self, num_classes=2):
+         super().__init__()
+         self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+         self.fc_image = nn.Linear(1000, 512)
+         self.drop_img = nn.Dropout(0.3)
+
+         self.text_model = BertModel.from_pretrained("bert-base-uncased")
+         self.fc_text = nn.Linear(self.text_model.config.hidden_size, 512)
+         self.drop_text = nn.Dropout(0.3)
+
+         self.fusion = SelfAttentionFusion(512)
+         self.fc_final = nn.Linear(512, num_classes)
+
+     def forward(self, image, input_ids, attention_mask):
+         x_img = self.image_model(image)
+         x_img = self.drop_img(x_img)
+         x_img = self.fc_image(x_img)
+
+         # Use the [CLS] token embedding as the text representation
+         x_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)[0][:, 0, :]
+         x_text = self.drop_text(x_text)
+         x_text = self.fc_text(x_text)
+
+         x_fused = self.fusion(x_text, x_img)
+         return self.fc_final(x_fused)
+
+ def remove_module_prefix(state_dict):
+     from collections import OrderedDict
+     new_state_dict = OrderedDict()
+     for k, v in state_dict.items():
+         name = k.replace('module.', '')
+         new_state_dict[name] = v
+     return new_state_dict
+
+ print("📦 Loading model weights...")
+ state_dict = torch.load("state_dict.pth", map_location=device)
+ clean_state_dict = remove_module_prefix(state_dict)
+
+ model = BERTResNetClassifier(num_classes=2)
+ model.load_state_dict(clean_state_dict)
+ model.to(device)
+ model.eval()
+ print("✅ Model loaded successfully.")
+
+ text = "The Traditionalists - Whole Roasted Kitten"
+ image_address = "./image.png"
+
+ image = Image.open(image_address).convert("RGB")
+ transform = v2.Compose([
+     v2.Resize((256, 256)),
+     v2.ToImage(),
+     v2.ToDtype(torch.float32, scale=True),
+     v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ image = transform(image).unsqueeze(0)
+ input_ids, attention_mask = get_bert_embedding(text)
+ input_ids = input_ids.unsqueeze(0)
+ attention_mask = attention_mask.unsqueeze(0)
+
+ # .to(device) returns a new tensor, so the results must be reassigned
+ image = image.to(device)
+ attention_mask = attention_mask.to(device)
+ input_ids = input_ids.to(device)
+
+ with torch.no_grad():
+     output = model(image, input_ids, attention_mask)
+
+ # PRINT OUTPUT
+ classes = ["Fake", "Real"]
+
+ probabilities = torch.nn.functional.softmax(output, dim=1)
+ prob_values = [f"{prob:.2%}" for prob in probabilities[0].tolist()]
+ print("Probabilities:", prob_values)
+
+ prediction_id = torch.argmax(output, dim=1).item()
+ print("Prediction:", classes[prediction_id])
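app.py consumes an ONNX export of a multimodal model, while this script works with the PyTorch checkpoint; the actual export presumably lives in train_model.ipynb (too large to render below). A hedged sketch of such an export for the model above, where the input/output names are assumptions chosen to satisfy app.py's name-based input matching, not values taken from the notebook:

    # Hypothetical export step; names and dynamic axes are assumptions.
    # Shapes follow the preprocessing used in this script (256x256, 80 tokens).
    dummy_image = torch.randn(1, 3, 256, 256, device=device)
    dummy_ids = torch.ones(1, 80, dtype=torch.long, device=device)
    dummy_mask = torch.ones(1, 80, dtype=torch.long, device=device)
    torch.onnx.export(
        model,
        (dummy_image, dummy_ids, dummy_mask),
        "multimodal_model.onnx",
        input_names=["image", "input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={"image": {0: "batch"}, "input_ids": {0: "batch"},
                      "attention_mask": {0: "batch"}},
    )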
train_model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
validate.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import BertTokenizer, BertModel
+ from torchvision.models import resnet50, ResNet50_Weights
+ from torchvision.transforms import v2
+ from PIL import Image
+ import pandas as pd
+ from tqdm import tqdm
+
+ # DEVICE SETUP
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("\n🚀 Using device:", device)
+
+ # Load tokenizer
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ # ----- HELPER FUNCTIONS -----
+ def get_bert_embedding(text):
+     inputs = tokenizer.encode_plus(
+         text, add_special_tokens=True,
+         return_tensors='pt', max_length=80,
+         truncation=True, padding='max_length'
+     )
+     return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)
+
+ # ----- DATASET CLASS -----
+ class FakedditDataset(Dataset):
+     def __init__(self, df, text_field="clean_title", label_field="binary_label", image_id="id"):
+         self.df = df.reset_index(drop=True)
+         self.text_field = text_field
+         self.label_field = label_field
+         self.image_id = image_id
+
+         self.transform = v2.Compose([
+             v2.Resize((256, 256)),
+             v2.ToImage(),
+             v2.ToDtype(torch.float32, scale=True),
+             v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         text = self.df.at[idx, self.text_field]
+         label = self.df.at[idx, self.label_field]
+         image_path = f"./val_images/{self.df.at[idx, self.image_id]}.jpg"
+
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+         input_ids, attention_mask = get_bert_embedding(str(text))
+
+         return image, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)
+
+ # ----- MODEL CLASSES -----
+ class SelfAttentionFusion(nn.Module):
+     def __init__(self, embed_dim):
+         super().__init__()
+         self.attn = nn.Linear(embed_dim * 2, 2)
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x_text, x_img):
+         stacked = torch.stack([x_text, x_img], dim=1)
+         attn_weights = self.softmax(self.attn(torch.cat([x_text, x_img], dim=1))).unsqueeze(2)
+         fused = (attn_weights * stacked).sum(dim=1)
+         return fused
+
+ class BERTResNetClassifier(nn.Module):
+     def __init__(self, num_classes=2):
+         super().__init__()
+         self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+         self.fc_image = nn.Linear(1000, 512)
+         self.drop_img = nn.Dropout(0.3)
+
+         self.text_model = BertModel.from_pretrained("bert-base-uncased")
+         self.fc_text = nn.Linear(self.text_model.config.hidden_size, 512)
+         self.drop_text = nn.Dropout(0.3)
+
+         self.fusion = SelfAttentionFusion(512)
+         self.fc_final = nn.Linear(512, num_classes)
+
+     def forward(self, image, input_ids, attention_mask):
+         x_img = self.image_model(image)
+         x_img = self.drop_img(x_img)
+         x_img = self.fc_image(x_img)
+
+         x_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)[0][:, 0, :]
+         x_text = self.drop_text(x_text)
+         x_text = self.fc_text(x_text)
+
+         x_fused = self.fusion(x_text, x_img)
+         return self.fc_final(x_fused)
+
+ # ----- LOAD DATA -----
+ df = pd.read_csv("./val_output.csv")
+ print("📄 Loaded validation CSV with", len(df), "samples")
+ val_dataset = FakedditDataset(df)
+ val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
+
+ # ----- LOAD MODEL STATE -----
+ def remove_module_prefix(state_dict):
+     from collections import OrderedDict
+     new_state_dict = OrderedDict()
+     for k, v in state_dict.items():
+         name = k.replace('module.', '')
+         new_state_dict[name] = v
+     return new_state_dict
+
+ print("📦 Loading model weights...")
+ state_dict = torch.load("state_dict.pth", map_location=device)
+ clean_state_dict = remove_module_prefix(state_dict)
+
+ model = BERTResNetClassifier(num_classes=2)
+ model.load_state_dict(clean_state_dict)
+ model.to(device)
+ model.eval()
+ print("✅ Model loaded and ready for evaluation")
+
+ # ----- EVALUATION -----
+ correct = 0
+ total = 0
+ print("\n🔍 Starting evaluation...")
+ with torch.no_grad():
+     for batch in tqdm(val_loader, desc="Evaluating"):
+         images, input_ids, attention_mask, labels = batch
+         images = images.to(device)
+         input_ids = input_ids.to(device)
+         attention_mask = attention_mask.to(device)
+         labels = labels.to(device)
+
+         outputs = model(images, input_ids, attention_mask)
+         preds = torch.argmax(outputs, dim=1)
+         correct += (preds == labels).sum().item()
+         total += labels.size(0)
+
+ accuracy = correct / total * 100
+ print(f"\n🎯 Final Validation Accuracy: {accuracy:.2f}%")
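Accuracy alone can hide per-class behavior even on a balanced split. An optional sketch that reruns the loop above collecting predictions for a scikit-learn report (sklearn is already a dependency via get_data.py); labels are left numeric because get_data.py maps 0 to real while test.py's classes list puts "Fake" at index 0:

    # Optional extension: per-class precision/recall/F1. Labels stay as 0/1
    # since the scripts in this commit disagree on which name index 0 carries.
    from sklearn.metrics import classification_report

    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in tqdm(val_loader, desc="Evaluating"):
            outputs = model(images.to(device), input_ids.to(device), attention_mask.to(device))
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())

    print(classification_report(all_labels, all_preds))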