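"""Image dataset preprocessing pipeline.

Loads a TSV of posts with image URLs, derives a binary label from the
6-way label, balances the two classes, downloads and validates the
referenced images in parallel, resizes them to a fixed size, and tops
up the minority class from the leftover pool of class-1 samples.
"""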
import argparse
import numpy as np
import pandas as pd
import os
from urllib import request
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from sklearn.utils import resample
from torchvision.transforms import v2
from PIL import Image
def load_and_prepare_data(file_path):
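    """Read the TSV, drop unused label/title columns, and derive a binary
    label (0 for class 0 of the 6-way label, 1 for everything else)."""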
    df = pd.read_csv(file_path, sep="\t")
    df.drop(['2_way_label', '3_way_label', 'title'], axis=1, inplace=True)
    df['binary_label'] = df['6_way_label'].apply(lambda x: 0 if x == 0 else 1)
    df.reset_index(drop=True, inplace=True)
    return df
def balance_data(df, max_samples_per_class=35000):
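    """Keep only rows that have an image, downsample both classes to at most
    max_samples_per_class, and return the shuffled balanced frame plus the
    unused class-1 rows (kept as a reserve for later augmentation)."""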
    df_with_image = df[df['hasImage'] == True]
    df_class_0 = df_with_image[df_with_image['binary_label'] == 0]
    df_class_1 = df_with_image[df_with_image['binary_label'] == 1]
    target_count = min(len(df_class_0), len(df_class_1), max_samples_per_class)
    df_sample_0 = resample(df_class_0, replace=False, n_samples=target_count, random_state=42)
    df_sample_1 = resample(df_class_1, replace=False, n_samples=target_count, random_state=42)
    df_balanced = pd.concat([df_sample_0, df_sample_1])
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    # fillna('') already covers every NaN, so no separate replace() pass is needed.
    df_balanced = df_balanced.fillna('')
    return df_balanced, df_class_1[~df_class_1['id'].isin(df_sample_1['id'])]
def ensure_directory(path):
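    """Create the directory if it does not already exist."""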
    if not os.path.exists(path):
        os.makedirs(path)
def download_image(item, image_dir, skip_existing=False):
    """Download a single image for one (index, row) pair from df.iterrows().

    Returns the row index on failure so the caller can drop the row, or
    None on success (or when the row has no usable image URL).
    """
    index, row = item
    if row["hasImage"] and row["image_url"] not in ["", "nan"]:
        image_url = row["image_url"]
        path = os.path.join(image_dir, f"{row['id']}.jpg")
        if skip_existing and os.path.exists(path):
            return None
        try:
            with open(path, 'wb') as f:
                f.write(request.urlopen(image_url, timeout=5).read())
        except Exception:
            # Network/HTTP errors: report the index so the row gets dropped.
            return index
    return None
def download_images_fast(df, image_dir, max_workers=16, skip_existing=False):
    """Download all images in df concurrently and drop rows whose download failed."""
    failed_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_image, item, image_dir, skip_existing)
                   for item in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Downloading images"):
            result = f.result()
            if result is not None:
                failed_indices.append(result)
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def validate_image(row, image_dir):
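    """Check that the downloaded file for one (index, row) pair is a readable
    image; delete the file and return the index if it is corrupt or missing."""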
    index, row = row
    image_path = os.path.join(image_dir, f"{row['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
        return None
    except Exception:
        if os.path.exists(image_path):
            os.remove(image_path)
        return index
def validate_images_fast(df, image_dir, max_workers=16):
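    """Validate all images in df concurrently; drop rows whose file is corrupt
    or missing and return the cleaned frame plus the dropped indices."""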
    corrupted_indices = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(validate_image, row, image_dir) for row in df.iterrows()]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Validating images"):
            result = f.result()
            if result is not None:
                corrupted_indices.append(result)
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices
def resize_images(df, image_dir, size=(256, 256)):
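    """Resize every image referenced by df to `size` in place on disk,
    dropping rows whose file cannot be opened or resized."""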
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{row['id']}.jpg")
        try:
            with Image.open(image_path) as image:
                resized_image = resize_transform(image.convert("RGB"))
            resized_image.save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            # Collect failures and drop after the loop; dropping while
            # iterating mutates the frame mid-iteration.
            failed_indices.append(index)
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
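    """Top up class 1 with extra samples from the reserve pool until it matches
    class 0, downloading, validating, and resizing each batch as it is drawn."""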
    needed = len(df_balanced[df_balanced['binary_label'] == 0]) - len(df_balanced[df_balanced['binary_label'] == 1])
    collected = []
    total_collected = 0
    print(f"Need to add {needed} more class 1 samples...")
    # Track the number of collected rows, not the number of batches.
    while total_collected < needed and len(df_remaining_class_1) > 0:
        batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
        df_remaining_class_1 = df_remaining_class_1.drop(batch.index)
        print(f"\nDownloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir)
        print("Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
        print("Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)
        collected.append(valid_batch)
        total_collected += len(valid_batch)
    if needed <= 0 or not collected:
        # Nothing to add: classes already balanced or reserve pool empty.
        return df_balanced
    df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
    # The pool may run dry before `needed` is reached, so cap the sample size.
    df_extra_class_1 = df_extra_class_1.sample(n=min(needed, len(df_extra_class_1)), random_state=42).reset_index(drop=True)
    df_balanced_updated = pd.concat([df_balanced, df_extra_class_1], ignore_index=True)
    df_balanced_updated = df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
    return df_balanced_updated
def main(args):
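    """Run the full pipeline: load, balance, download, validate, resize, augment."""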
    ensure_directory(args.image_dir)
    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    df_balanced.to_csv("./df.csv", index=False)
    df_balanced = download_images_fast(df_balanced, args.image_dir, skip_existing=args.skip_existing)
    print(f"Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)
    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)
    df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
    df_balanced_updated.to_csv(args.output_csv, index=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    parser.add_argument('--tsv_path', type=str, default="./multimodal_train.tsv", help='Path to the input TSV file')
    parser.add_argument('--image_dir', type=str, default="./images", help='Directory to save images')
    parser.add_argument('--output_csv', type=str, default="./final_output.csv", help='Path to save final balanced CSV')
    parser.add_argument('--max_samples', type=int, default=35000, help='Maximum number of samples per class')
    parser.add_argument('--skip_existing', action='store_true', help='Skip downloading if image already exists')
    args = parser.parse_args()
    main(args)