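"""Preprocessing pipeline for a Fakeddit-style multimodal TSV (the
`2_way_label`/`6_way_label`/`hasImage`/`image_url` columns used below match
that dataset's schema). The pipeline collapses the 6-way label to a binary
one, downsamples to a balanced set, downloads the images in parallel,
validates and resizes them, then tops the minority class back up from the
leftover pool.
"""
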
import argparse
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib import request

import pandas as pd
from PIL import Image
from sklearn.utils import resample
from torchvision.transforms import v2
from tqdm import tqdm


def load_and_prepare_data(file_path):
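    """Read the TSV and collapse `6_way_label` into a binary label
    (0 = authentic, anything else = 1)."""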
    df = pd.read_csv(file_path, sep="\t")
    df.drop(['2_way_label', '3_way_label', 'title'], axis=1, inplace=True)
    df['binary_label'] = df['6_way_label'].apply(lambda x: 0 if x == 0 else 1)
    df.reset_index(drop=True, inplace=True)
    return df

def balance_data(df, max_samples_per_class=35000):
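    """Downsample both classes to the same size (capped at
    `max_samples_per_class`) and return the balanced frame plus the unused
    class-1 rows, which serve as a reserve for augment_minority_class()."""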
    df_with_image = df[df['hasImage']]
    df_class_0 = df_with_image[df_with_image['binary_label'] == 0]
    df_class_1 = df_with_image[df_with_image['binary_label'] == 1]
    target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)
    df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=42)
    df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=42)
    df_balanced = pd.concat([df_sample_0, df_sample_1])
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    df_balanced.fillna('', inplace=True)
    return df_balanced, df_class_1[~df_class_1['id'].isin(df_sample_1['id'])]

def ensure_directory(path):
    os.makedirs(path, exist_ok=True)

def download_image(row_tuple, image_dir, skip_existing=False):
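    """Fetch one image to `<image_dir>/<id>.jpg`. Returns the row index on
    failure so the caller can drop the row, or None on success."""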
    index, row = row_tuple
    if row["hasImage"] and row["image_url"] not in ["", "nan"]:
        path = os.path.join(image_dir, f"{row['id']}.jpg")
        if skip_existing and os.path.exists(path):
            return None
        try:
            with open(path, 'wb') as f:
                f.write(request.urlopen(row["image_url"], timeout=5).read())
        except Exception:
            return index
    return None

def download_images_fast(df, image_dir, max_workers=16, skip_existing=False):
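    """Download all images for `df` on a thread pool and drop the rows whose
    download failed."""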
    failed_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_image, row, image_dir, skip_existing)
                   for row in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Downloading images"):
            result = f.result()
            if result is not None:
                failed_indices.append(result)
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df

def validate_image(row_tuple, image_dir):
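    """Open and verify one image with PIL; delete the file and return the
    row index if it is corrupt, else return None."""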
    index, row = row_tuple
    image_path = os.path.join(image_dir, f"{row['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
        return None
    except Exception:
        if os.path.exists(image_path):
            os.remove(image_path)
        return index

def validate_images_fast(df, image_dir, max_workers=16):
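    """Verify all images on a thread pool; drop corrupt rows and also return
    their indices."""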
    corrupted_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(validate_image, row, image_dir) for row in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Validating images"):
            result = f.result()
            if result is not None:
                corrupted_indices.append(result)
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices

def resize_images(df, image_dir, size=(256, 256)):
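    """Resize every image in place with torchvision's v2.Resize and drop rows
    whose image cannot be processed."""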
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{row['id']}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
            resized_image = resize_transform(image)
            resized_image.save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            failed_indices.append(index)
    # Drop failures after the loop: mutating df while iterrows() is running
    # is unsafe.
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df

def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
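    """Replace class-1 rows lost to failed downloads: pull batches from the
    reserve, run them through the same download/validate/resize steps, and
    append just enough survivors to even the classes out again."""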
    needed = len(df_balanced[df_balanced['binary_label'] == 0]) - len(df_balanced[df_balanced['binary_label'] == 1])
    if needed <= 0:
        return df_balanced
    collected = []
    total_collected = 0
    print(f"Need to add {needed} more class 1 samples...")
    while total_collected < needed and len(df_remaining_class_1) > 0:
        batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
        df_remaining_class_1 = df_remaining_class_1.drop(batch.index)
        print(f"\nDownloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir)
        print("Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
        print("Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)
        collected.append(valid_batch)
        # Count surviving rows, not batches, against the shortfall.
        total_collected += len(valid_batch)
    df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
    df_extra_class_1 = df_extra_class_1.sample(n=min(needed, len(df_extra_class_1)), random_state=42).reset_index(drop=True)
    df_balanced_updated = pd.concat([df_balanced, df_extra_class_1], ignore_index=True)
    df_balanced_updated = df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
    return df_balanced_updated

def main(args):
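    """Run the full pipeline; intermediate CSVs are written after each stage
    so a partial run can be inspected."""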
    ensure_directory(args.image_dir)
    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    df_balanced.to_csv("./df.csv", index=False)
    df_balanced = download_images_fast(df_balanced, args.image_dir, skip_existing=args.skip_existing)
    print(f"Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)
    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)
    df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
    df_balanced_updated.to_csv(args.output_csv, index=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    parser.add_argument('--tsv_path', type=str, default="./multimodal_train.tsv", help='Path to the input TSV file')
    parser.add_argument('--image_dir', type=str, default="./images", help='Directory to save images')
    parser.add_argument('--output_csv', type=str, default="./final_output.csv", help='Path to save final balanced CSV')
    parser.add_argument('--max_samples', type=int, default=35000, help='Maximum number of samples per class')
    parser.add_argument('--skip_existing', action='store_true', help='Skip downloading if image already exists')
    args = parser.parse_args()
    main(args)
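
# Example invocation (the script filename is hypothetical; the flags and
# defaults are the ones defined by the argparse block above):
#   python preprocess.py --tsv_path ./multimodal_train.tsv --image_dir ./images \
#       --output_csv ./final_output.csv --max_samples 35000 --skip_existing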