"""Image dataset preprocessing pipeline.

Reads a tab-separated file of posts, derives a binary label from
``6_way_label``, balances the two classes among rows that have an image,
downloads / validates / resizes those images, tops up class 1 from the unused
class-1 pool to replace rows lost along the way, and writes the result to CSV.
"""
import argparse
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib import request

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.utils import resample
from torchvision.transforms import v2
from tqdm import tqdm


def load_and_prepare_data(file_path):
    """Load the TSV at *file_path* and add a ``binary_label`` column.

    The ``2_way_label``, ``3_way_label`` and ``title`` columns are dropped, and
    ``binary_label`` is 0 where ``6_way_label == 0`` and 1 everywhere else
    (including missing values).

    Returns:
        The prepared DataFrame with a fresh ``RangeIndex``.

    Raises:
        KeyError: if a dropped column or ``6_way_label`` is missing.
    """
    df = pd.read_csv(file_path, sep="\t")
    df = df.drop(columns=["2_way_label", "3_way_label", "title"])
    # Vectorised form of ``0 if x == 0 else 1`` (NaN != 0, so NaN maps to 1).
    df["binary_label"] = (df["6_way_label"] != 0).astype("int64")
    return df.reset_index(drop=True)


def balance_data(df, max_samples_per_class=35000):
    """Down-sample image-bearing rows into two equally sized binary classes.

    Only rows whose ``hasImage`` equals True are used. Each class is sampled
    without replacement to ``min(len(class_0), len(class_1),
    max_samples_per_class)`` rows (``random_state=42``); the union is shuffled
    and missing values are replaced by ``''``.

    Returns:
        ``(df_balanced, df_remaining_class_1)``: the balanced, shuffled frame
        with a fresh index, and the image-bearing class-1 rows whose ``id`` was
        not sampled (original index kept). The remaining frame gets the same
        ``''`` fill so the two frames agree when concatenated later.
    """
    # ``.eq(True)`` rather than plain boolean indexing so a column that also
    # holds NaN still yields a valid mask.
    df_with_image = df[df["hasImage"].eq(True)]
    df_class_0 = df_with_image[df_with_image["binary_label"] == 0]
    df_class_1 = df_with_image[df_with_image["binary_label"] == 1]

    target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)
    df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=42)
    df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=42)

    df_balanced = pd.concat([df_sample_0, df_sample_1])
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    df_balanced = df_balanced.fillna("")

    df_remaining_class_1 = df_class_1[~df_class_1["id"].isin(df_sample_1["id"])].fillna("")
    return df_balanced, df_remaining_class_1


def ensure_directory(path):
    """Create directory *path* (and any missing parents) if it does not exist."""
    # exist_ok avoids the race in ``if not exists(path): makedirs(path)``.
    os.makedirs(path, exist_ok=True)


def _remove_quietly(path):
    """Delete *path* if present; a failure to delete is ignored (best effort)."""
    try:
        os.remove(path)
    except OSError:
        pass


def download_image(row, image_dir, skip_existing=False):
    """Download the image for one ``DataFrame.iterrows()`` item.

    Args:
        row: ``(index, Series)`` tuple as produced by ``DataFrame.iterrows()``.
        image_dir: directory the image is saved to as ``<id>.jpg``.
        skip_existing: if True and the target file already exists, it is kept
            and no request is made.

    Returns:
        The row's index label when the download or write failed, else ``None``.
        Rows with a falsy ``hasImage`` or no usable ``image_url`` (``''``,
        ``'nan'`` or a non-string such as a float NaN) are skipped and return
        ``None``.
    """
    index, data = row
    if not data["hasImage"]:
        return None
    image_url = data["image_url"]
    if not isinstance(image_url, str) or image_url in ("", "nan"):
        return None

    path = os.path.join(image_dir, f"{data['id']}.jpg")
    if skip_existing and os.path.exists(path):
        return None

    try:
        # Fetch fully before touching the disk so a failed request cannot
        # leave an empty or truncated file behind.
        with request.urlopen(image_url, timeout=5) as response:
            payload = response.read()
    except Exception:
        # Best effort: URL, network, timeout and HTTP errors all mark the row
        # as failed (http.client errors are not all OSError subclasses).
        return index

    try:
        with open(path, "wb") as f:
            f.write(payload)
    except OSError:
        _remove_quietly(path)
        return index
    return None


def _collect_failures(worker, df, extra_args, desc, max_workers):
    """Run ``worker(item, *extra_args)`` for each ``df.iterrows()`` item in a
    thread pool and return the non-``None`` results (failed index labels)."""
    failures = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, item, *extra_args) for item in df.iterrows()]
        for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
            result = future.result()
            if result is not None:
                failures.append(result)
    return failures


def download_images_fast(df, image_dir, max_workers=16, skip_existing=False):
    """Download the images for all rows of *df* concurrently.

    Rows whose download failed are dropped from *df* in place and its index is
    reset; the same (mutated) DataFrame is returned.
    """
    failed_indices = _collect_failures(
        download_image, df, (image_dir, skip_existing), "Downloading images", max_workers
    )
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df


def validate_image(row, image_dir):
    """Check that ``<image_dir>/<id>.jpg`` for one ``iterrows()`` item is readable.

    Uses Pillow's ``Image.verify``. If opening or verifying fails (including a
    missing file), any file at that path is deleted.

    Returns:
        The row's index label when the image is invalid, else ``None``.
    """
    index, data = row
    image_path = os.path.join(image_dir, f"{data['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
    except Exception:
        # Pillow raises a variety of exception types for broken files.
        _remove_quietly(image_path)
        return index
    return None


def validate_images_fast(df, image_dir, max_workers=16):
    """Validate the images for all rows of *df* concurrently.

    Rows with an invalid image are dropped from *df* in place and its index is
    reset.

    Returns:
        ``(df, corrupted_indices)``; the indices are labels from *df*'s index
        before the reset, in completion order.
    """
    corrupted_indices = _collect_failures(
        validate_image, df, (image_dir,), "Validating images", max_workers
    )
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices


def resize_images(df, image_dir, size=(256, 256)):
    """Convert each row's image to RGB, resize it and overwrite it in place.

    *size* is passed to torchvision ``v2.Resize``. Rows whose image cannot be
    processed are reported, then dropped from *df* in place (after the loop,
    not while iterating), and the index is reset.
    """
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, data in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{data['id']}.jpg")
        try:
            # Close the source file before saving over the same path.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            resize_transform(image).save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            failed_indices.append(index)
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df


def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000,
                           skip_existing=False):
    """Top up class 1 until it matches class 0 in *df_balanced*.

    Batches of up to *batch_size* rows are drawn from *df_remaining_class_1*,
    then downloaded, validated and resized, until enough valid rows exist or
    the pool is exhausted. If the pool runs out, all valid rows collected are
    added and a warning is printed. If class 1 is not smaller than class 0,
    nothing is added.

    Returns:
        The combined frame, shuffled (``random_state=42``) with a fresh index.
    """
    labels = df_balanced["binary_label"]
    needed = int((labels == 0).sum() - (labels == 1).sum())
    if needed > 0:
        print(f"Need to add {needed} more class 1 samples...")
    else:
        print("Class 1 is not the minority; no class 1 samples to add.")

    collected = []
    collected_count = 0  # number of valid rows, not number of batches
    remaining = df_remaining_class_1
    while collected_count < needed and len(remaining) > 0:
        batch = remaining.sample(n=min(batch_size, len(remaining)), random_state=42)
        remaining = remaining.drop(batch.index)

        print(f"\n🌀 Downloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir, skip_existing=skip_existing)
        print("🔎 Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
        print("🎨 Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)

        collected.append(valid_batch)
        collected_count += len(valid_batch)

    if 0 < needed and collected_count < needed:
        print(f"Warning: only {collected_count} of {needed} class 1 samples could be added.")

    frames = [df_balanced]
    if collected_count > 0:
        df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
        take = min(needed, len(df_extra_class_1))
        frames.append(df_extra_class_1.sample(n=take, random_state=42).reset_index(drop=True))

    df_balanced_updated = pd.concat(frames, ignore_index=True)
    return df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)


def main(args):
    """Run the full pipeline using the options parsed from the command line.

    Intermediate CSV snapshots are written to ``./df.csv``,
    ``./df_balanced.csv`` and ``./df_balanced_resized.csv``; the final frame
    goes to ``args.output_csv``.
    """
    # getattr keeps callers that build a Namespace without this flag working.
    skip_existing = getattr(args, "skip_existing", False)
    ensure_directory(args.image_dir)

    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    df_balanced.to_csv("./df.csv", index=False)

    df_balanced = download_images_fast(df_balanced, args.image_dir, skip_existing=skip_existing)
    print(f"✅ Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)

    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)

    df_balanced_updated = augment_minority_class(
        df_balanced, df_remaining_class_1, args.image_dir, skip_existing=skip_existing
    )
    df_balanced_updated.to_csv(args.output_csv, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    parser.add_argument('--tsv_path', type=str, default="./multimodal_train.tsv",
                        help='Path to the input TSV file')
    parser.add_argument('--image_dir', type=str, default="./images",
                        help='Directory to save images')
    parser.add_argument('--output_csv', type=str, default="./final_output.csv",
                        help='Path to save final balanced CSV')
    parser.add_argument('--max_samples', type=int, default=35000,
                        help='Maximum number of samples per class')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip downloading if image already exists')
    args = parser.parse_args()
    main(args)