# STron
# Added RoBERTa and ViT
# 7df2acb
import argparse
import numpy as np
import pandas as pd
import os
from urllib import request
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from sklearn.utils import resample
from torchvision.transforms import v2
from PIL import Image
def load_and_prepare_data(file_path):
    """Read the tab-separated source file and derive a binary label.

    The ``2_way_label``, ``3_way_label`` and ``title`` columns are discarded.
    ``binary_label`` is 0 wherever ``6_way_label`` equals 0 and 1 for every
    other value.

    Args:
        file_path: Path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        The prepared DataFrame with a fresh 0..n-1 index.
    """
    frame = pd.read_csv(file_path, sep="\t")
    frame = frame.drop(columns=['2_way_label', '3_way_label', 'title'])
    # Vectorized equivalent of "0 if x == 0 else 1"; int64 matches apply()'s output dtype.
    frame['binary_label'] = (frame['6_way_label'] != 0).astype('int64')
    return frame.reset_index(drop=True)
def balance_data(df, max_samples_per_class=35000, random_state=42):
    """Down-sample image-bearing rows to an equal count per binary class.

    Only rows whose ``hasImage`` equals True are considered. Each class is
    drawn without replacement to ``min(#class 0, #class 1,
    max_samples_per_class)`` rows. The union is shuffled, and missing values
    are replaced with ``''``.

    Args:
        df: Frame with ``hasImage``, ``binary_label`` and ``id`` columns.
        max_samples_per_class: Upper bound on rows drawn from each class.
        random_state: Seed used for both per-class draws and the final
            shuffle. The default of 42 reproduces the previous behaviour.

    Returns:
        ``(df_balanced, df_remaining_class_1)``. The second frame holds the
        image-bearing class-1 rows whose ``id`` was not drawn. It keeps its
        original index and is not NaN-filled.
    """
    df_with_image = df[df['hasImage'] == True]  # noqa: E712 - keep only values equal to True
    df_class_0 = df_with_image[df_with_image['binary_label'] == 0]
    df_class_1 = df_with_image[df_with_image['binary_label'] == 1]

    target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)
    df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=random_state)
    df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=random_state)

    df_balanced = pd.concat([df_sample_0, df_sample_1])
    df_balanced = df_balanced.sample(frac=1, random_state=random_state).reset_index(drop=True)
    # A single fillna covers NaN/None; the previous replace(np.nan, ...) pass
    # followed by fillna did the same work twice.
    df_balanced = df_balanced.fillna('')

    df_remaining_class_1 = df_class_1[~df_class_1['id'].isin(df_sample_1['id'])]
    return df_balanced, df_remaining_class_1
def ensure_directory(path):
    """Create *path*, including missing parents, if it does not exist yet.

    ``exist_ok=True`` avoids the check-then-create race of testing
    ``os.path.exists`` first. If *path* exists but is not a directory,
    ``os.makedirs`` raises ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)
def download_image(row, image_dir):
    """Fetch one row's image into ``<image_dir>/<id>.jpg``.

    Args:
        row: An ``(index, Series)`` pair as yielded by ``DataFrame.iterrows``.
            The Series must provide ``hasImage``, ``image_url`` and ``id``.
        image_dir: Destination directory.

    Returns:
        The row's index label if fetching or writing failed, otherwise
        ``None``. Rows are skipped and also return ``None`` when
        ``hasImage`` is falsy or ``image_url`` is ``""`` or ``"nan"``.
    """
    index, record = row
    if not record["hasImage"]:
        return None
    image_url = record["image_url"]
    if image_url in ("", "nan"):
        return None

    path = os.path.join(image_dir, f"{record['id']}.jpg")
    try:
        # Read the whole payload before touching the destination, so a failed
        # request never creates or truncates the file. The `with` closes the
        # response promptly.
        with request.urlopen(image_url, timeout=5) as response:
            payload = response.read()
    except Exception:
        # Best effort: any network/URL error marks the row as failed.
        return index

    try:
        with open(path, "wb") as fh:
            fh.write(payload)
    except OSError:
        # Do not leave a partially written file behind.
        if os.path.exists(path):
            os.remove(path)
        return index
    return None
def download_images_fast(df, image_dir, max_workers=16):
    """Download every row's image on a thread pool and drop the rows that failed.

    One ``download_image`` task is submitted per row. Results are consumed
    in completion order with a tqdm progress bar. *df* is modified in place
    (failed rows dropped, index reset) and is also returned.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(download_image, item, image_dir) for item in df.iterrows()]
        outcomes = (
            fut.result()
            for fut in tqdm(as_completed(pending), total=len(pending), desc="Downloading images")
        )
        failed = [label for label in outcomes if label is not None]
    df.drop(index=failed, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def validate_image(row, image_dir):
    """Check that ``<image_dir>/<id>.jpg`` is a verifiable image, and delete it if not.

    Args:
        row: An ``(index, Series)`` pair as yielded by ``DataFrame.iterrows``.
            The Series must provide ``id``.
        image_dir: Directory holding the downloaded images.

    Returns:
        ``None`` when PIL can open and ``verify()`` the file. Otherwise
        (including when the file is missing) returns the row's index label.
    """
    index, record = row
    image_path = os.path.join(image_dir, f"{record['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
    except Exception:
        # Best effort: any open/verify failure marks the row invalid. Unlike a
        # bare `except:`, this lets KeyboardInterrupt/SystemExit propagate.
        if os.path.exists(image_path):
            os.remove(image_path)
        return index
    return None
def validate_images_fast(df, image_dir, max_workers=16):
    """Verify every row's image on a thread pool and drop rows whose image is bad.

    Each row is checked by ``validate_image``, which also deletes bad files.
    *df* is modified in place (bad rows dropped, index reset).

    Returns:
        ``(df, corrupted_indices)``: the pruned frame, and the pre-reset index
        labels that were removed, in completion order.
    """
    corrupted_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        jobs = [executor.submit(validate_image, pair, image_dir) for pair in df.iterrows()]
        progress = tqdm(as_completed(jobs), total=len(jobs), desc="Validating images")
        for job in progress:
            bad_label = job.result()
            if bad_label is None:
                continue
            corrupted_indices.append(bad_label)
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices
def resize_images(df, image_dir, size=(256, 256)):
    """Resize each row's ``<id>.jpg`` in place and drop rows that fail.

    Each image is converted to RGB, passed through ``v2.Resize(size)``, and
    saved back over the original file.

    Args:
        df: Frame with an ``id`` column.
        image_dir: Directory holding the images.
        size: Target size forwarded to ``torchvision.transforms.v2.Resize``.

    Returns:
        *df*, modified in place: rows that failed are removed and the index is reset.
    """
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{row['id']}.jpg")
        try:
            # convert() loads the pixels, so the source handle can be closed
            # before the same path is overwritten with the resized result.
            with Image.open(image_path) as src:
                image = src.convert("RGB")
            resize_transform(image).save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            failed_indices.append(index)
    # Drop after the loop instead of mutating df while iterrows() is walking it.
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
    """Top up class 1 from a reserve pool until it matches class 0.

    Rows are drawn from *df_remaining_class_1* in batches of up to
    *batch_size*. Each batch is downloaded, validated and resized. Survivors
    accumulate until the deficit (#class 0 - #class 1) is covered or the
    pool is exhausted. Then exactly the deficit (or every survivor, if there
    are fewer) is sampled and merged into *df_balanced*.

    Args:
        df_balanced: Frame with a ``binary_label`` column.
        df_remaining_class_1: Candidate class-1 rows not yet in *df_balanced*.
        image_dir: Directory where images are written.
        batch_size: Maximum number of rows fetched per round.

    Returns:
        A frame shuffled with ``random_state=42`` and given a fresh index. It
        contains *df_balanced* plus the added rows. When there is no deficit
        or the pool is empty, *df_balanced* is returned shuffled, with
        nothing added.
    """
    labels = df_balanced['binary_label']
    needed = int((labels == 0).sum()) - int((labels == 1).sum())
    print(f"Need to add {needed} more class 1 samples...")

    collected = []
    n_collected = 0
    # Compare collected *rows* (not batches) against the deficit.
    while n_collected < needed and len(df_remaining_class_1) > 0:
        batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
        df_remaining_class_1 = df_remaining_class_1.drop(batch.index)
        print(f"\n🌀 Downloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir)
        print("🔎 Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
        print("🎨 Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)
        collected.append(valid_batch)
        n_collected += len(valid_batch)

    frames = [df_balanced]
    # `collected` is only non-empty when needed > 0 (the loop did not run otherwise).
    if collected:
        df_extra_class_1 = pd.concat(collected, ignore_index=True)
        if len(df_extra_class_1) < needed:
            print(f"Only {len(df_extra_class_1)} of {needed} extra class 1 samples are available.")
        n_take = min(needed, len(df_extra_class_1))
        frames.append(df_extra_class_1.sample(n=n_take, random_state=42).reset_index(drop=True))

    df_balanced_updated = pd.concat(frames, ignore_index=True)
    return df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
def main(args):
    """Run the preprocessing pipeline configured by the parsed CLI *args*.

    The steps are:
    1. Load the TSV.
    2. Balance the classes.
    3. Download, validate and resize the balanced set's images.
    4. Top up class 1 from the unused pool.
    5. Write the result to ``args.output_csv``.

    Intermediate snapshots are written to ``./df.csv``,
    ``./df_balanced.csv`` and ``./df_balanced_resized.csv``.

    When ``args.skip_existing`` is set, rows whose ``<image_dir>/<id>.jpg``
    already exists are not re-downloaded in the initial download step. The
    top-up step (``augment_minority_class``) still downloads its own rows.
    """
    ensure_directory(args.image_dir)
    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    df_balanced.to_csv("./df.csv", index=False)

    if getattr(args, "skip_existing", False):
        already_there = df_balanced['id'].map(
            lambda image_id: os.path.exists(os.path.join(args.image_dir, f"{image_id}.jpg"))
        ).astype(bool)
        fetched = download_images_fast(df_balanced[~already_there].copy(), args.image_dir)
        # Preserve the shuffled row order: keep rows whose image already existed
        # plus rows whose download succeeded.
        keep = already_there | df_balanced['id'].isin(fetched['id'])
        df_balanced = df_balanced[keep].reset_index(drop=True)
    else:
        df_balanced = download_images_fast(df_balanced, args.image_dir)
    print(f"✅ Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)

    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)

    df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
    df_balanced_updated.to_csv(args.output_csv, index=False)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand the
    # resulting namespace to main().
    _cli = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    for _flag, _options in (
        ('--tsv_path', dict(type=str, default="./multimodal_train.tsv", help='Path to the input TSV file')),
        ('--image_dir', dict(type=str, default="./images", help='Directory to save images')),
        ('--output_csv', dict(type=str, default="./final_output.csv", help='Path to save final balanced CSV')),
        ('--max_samples', dict(type=int, default=35000, help='Maximum number of samples per class')),
        ('--skip_existing', dict(action='store_true', help='Skip downloading if image already exists')),
    ):
        _cli.add_argument(_flag, **_options)
    main(_cli.parse_args())