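"""Preprocessing pipeline for a two-class image dataset: balance the classes,
download and validate the images, resize them to a fixed size, and top up the
minority class from held-out rows until the classes are equal again.

Example invocation (the script name here is illustrative):

    python preprocess_images.py --tsv_path ./multimodal_train.tsv \
        --image_dir ./images --output_csv ./final_output.csv \
        --max_samples 35000 --skip_existing
"""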
import argparse
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib import request

import pandas as pd
from PIL import Image
from sklearn.utils import resample
from torchvision.transforms import v2
from tqdm import tqdm


def load_and_prepare_data(file_path):
    """Read the TSV, drop unused columns, and derive a binary label.

    Rows with 6_way_label == 0 keep label 0; every other value maps to 1.
    """
    df = pd.read_csv(file_path, sep="\t")
    df.drop(['2_way_label', '3_way_label', 'title'], axis=1, inplace=True)
    df['binary_label'] = df['6_way_label'].apply(lambda x: 0 if x == 0 else 1)
    df.reset_index(drop=True, inplace=True)
    return df


def balance_data(df, max_samples_per_class=35000):
    """Downsample both classes to the same size, capped at max_samples_per_class.

    Returns the balanced frame plus the unused class-1 rows, which are kept as
    a reserve for augment_minority_class().
    """
    df_with_image = df[df['hasImage'] == True]
    df_class_0 = df_with_image[df_with_image['binary_label'] == 0]
    df_class_1 = df_with_image[df_with_image['binary_label'] == 1]
    target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)

    df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=42)
    df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=42)

    df_balanced = pd.concat([df_sample_0, df_sample_1])
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    df_balanced = df_balanced.fillna('')
    return df_balanced, df_class_1[~df_class_1['id'].isin(df_sample_1['id'])]


def ensure_directory(path):
    os.makedirs(path, exist_ok=True)


def download_image(row_tuple, image_dir, skip_existing=False):
    """Fetch one image; return the row index on failure so the caller can drop it."""
    index, row = row_tuple
    if row["hasImage"] and row["image_url"] not in ["", "nan"]:
        image_url = row["image_url"]
        path = os.path.join(image_dir, f"{row['id']}.jpg")
        if skip_existing and os.path.exists(path):
            return None
        try:
            with open(path, 'wb') as f:
                f.write(request.urlopen(image_url, timeout=5).read())
        except Exception:
            return index
    return None


def download_images_fast(df, image_dir, max_workers=16, skip_existing=False):
    """Download all images concurrently and drop rows whose download failed."""
    failed_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_image, row, image_dir, skip_existing) for row in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Downloading images"):
            result = f.result()
            if result is not None:
                failed_indices.append(result)
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df


def validate_image(row_tuple, image_dir):
    """Verify an image file on disk; delete it and return the row index if corrupt."""
    index, row = row_tuple
    image_path = os.path.join(image_dir, f"{row['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
        return None
    except Exception:
        if os.path.exists(image_path):
            os.remove(image_path)
        return index


def validate_images_fast(df, image_dir, max_workers=16):
    """Validate all images concurrently; drop rows whose file is missing or corrupt."""
    corrupted_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(validate_image, row, image_dir) for row in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Validating images"):
            result = f.result()
            if result is not None:
                corrupted_indices.append(result)
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices


def resize_images(df, image_dir, size=(256, 256)):
    """Resize every image in place; drop rows whose image cannot be processed."""
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{row['id']}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
            resized_image = resize_transform(image)
            resized_image.save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            failed_indices.append(index)
    # Drop after the loop rather than mutating df while iterating over it.
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df


def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
    """Top up class 1 from the reserve rows until both classes are equal again.

    Reserve images are downloaded, validated, and resized in batches because
    many downloads fail and only valid images should count toward the target.
    """
    needed = len(df_balanced[df_balanced['binary_label'] == 0]) - len(df_balanced[df_balanced['binary_label'] == 1])
    if needed <= 0:
        return df_balanced

    collected = []
    print(f"Need to add {needed} more class 1 samples...")
    while sum(len(d) for d in collected) < needed and len(df_remaining_class_1) > 0:
        batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
        df_remaining_class_1 = df_remaining_class_1.drop(batch.index)

        print(f"\nDownloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir)

        print("Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)

        print("Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)

        collected.append(valid_batch)

    if not collected:
        return df_balanced

    df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
    df_extra_class_1 = df_extra_class_1.sample(n=min(needed, len(df_extra_class_1)), random_state=42).reset_index(drop=True)

    df_balanced_updated = pd.concat([df_balanced, df_extra_class_1], ignore_index=True)
    df_balanced_updated = df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
    return df_balanced_updated


def main(args):
    ensure_directory(args.image_dir)

    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    df_balanced.to_csv("./df.csv", index=False)

    df_balanced = download_images_fast(df_balanced, args.image_dir, skip_existing=args.skip_existing)
    print(f"Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)

    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)

    df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
    df_balanced_updated.to_csv(args.output_csv, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    parser.add_argument('--tsv_path', type=str, default="./multimodal_train.tsv", help='Path to the input TSV file')
    parser.add_argument('--image_dir', type=str, default="./images", help='Directory to save images')
    parser.add_argument('--output_csv', type=str, default="./final_output.csv", help='Path to save final balanced CSV')
    parser.add_argument('--max_samples', type=int, default=35000, help='Maximum number of samples per class')
    parser.add_argument('--skip_existing', action='store_true', help='Skip downloading if image already exists')
    args = parser.parse_args()
    main(args)