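# Repository-page header captured with the file (not part of the script):
#   author: STron
#   commit message: Added Roberta and Vit
#   commit: 7df2acb
#   page links: raw / history / blame
#   size: 6.89 kB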
import argparse
import numpy as np
import pandas as pd
import os
from urllib import request
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from sklearn.utils import resample
from torchvision.transforms import v2
from PIL import Image
def load_and_prepare_data(file_path):
    """Read the tab-separated dataset and attach a binary target column.

    The ``2_way_label``, ``3_way_label`` and ``title`` columns are discarded.
    A new ``binary_label`` column is derived from ``6_way_label``: rows whose
    value equals 0 get 0, every other row gets 1.

    Args:
        file_path: Path of the TSV file to load.

    Returns:
        The prepared DataFrame with a fresh 0..n-1 index.
    """
    def _to_binary(label):
        # Only an exact 0 stays class 0; any other value collapses to 1.
        if label == 0:
            return 0
        return 1

    frame = pd.read_csv(file_path, sep="\t")
    frame = frame.drop(columns=['2_way_label', '3_way_label', 'title'])
    frame['binary_label'] = frame['6_way_label'].apply(_to_binary)
    return frame.reset_index(drop=True)
def balance_data(df, max_samples_per_class=35000):
    """Draw an equally sized sample of each binary class from rows with images.

    Only rows whose ``hasImage`` is True are considered. Each class is
    sampled without replacement down to the size of the smaller class,
    capped at ``max_samples_per_class``. The combined sample is shuffled,
    re-indexed, and has its missing values replaced by empty strings.

    Args:
        df: Frame containing ``hasImage``, ``binary_label`` and ``id`` columns.
        max_samples_per_class: Upper bound on the rows kept per class.

    Returns:
        A ``(balanced, leftover_class_1)`` tuple. ``leftover_class_1`` holds
        the class 1 rows (with images) whose ``id`` was not selected; it is
        returned without the missing-value replacement.
    """
    seed = 42
    with_image = df[df['hasImage'] == True]
    class_0_rows = with_image[with_image['binary_label'] == 0]
    class_1_rows = with_image[with_image['binary_label'] == 1]

    per_class = min(len(class_0_rows), len(class_1_rows), max_samples_per_class)
    picked_0, picked_1 = (
        resample(group, replace=False, n_samples=per_class, random_state=seed)
        for group in (class_0_rows, class_1_rows)
    )

    balanced = (
        pd.concat([picked_0, picked_1])
        .sample(frac=1, random_state=seed)
        .reset_index(drop=True)
        .replace(np.nan, '', regex=True)
    )
    balanced.fillna('', inplace=True)

    unused_class_1 = class_1_rows[~class_1_rows['id'].isin(picked_1['id'])]
    return balanced, unused_class_1
def ensure_directory(path):
    """Create ``path`` (including missing parent directories) if needed.

    Calling this on a directory that already exists is a no-op. The
    existence check and the creation happen in a single call, so another
    process or thread creating the directory in between cannot cause a
    spurious failure.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def download_image(row, image_dir):
    """Fetch the image for one dataset row and store it as ``<id>.jpg``.

    Args:
        row: An ``(index, record)`` pair as produced by ``DataFrame.iterrows()``.
            ``record`` must provide ``hasImage``, ``image_url`` and ``id``.
        image_dir: Directory the image file is written to.

    Returns:
        ``None`` when the image was saved, or when the row has no image to
        fetch (``hasImage`` is falsy or the URL is ``""`` / ``"nan"``).
        The row's ``index`` when the download or the write failed.
    """
    index, record = row[0], row[1]

    # Rows without a usable URL are skipped, not reported as failures.
    if not record["hasImage"] or record["image_url"] in ["", "nan"]:
        return None

    image_url = record["image_url"]
    path = os.path.join(image_dir, f"{record['id']}.jpg")

    # Download fully into memory BEFORE touching the destination file, so a
    # failed or timed-out request never leaves an empty/partial .jpg behind.
    # NOTE(review): the URL comes straight from the dataset file and urlopen
    # accepts every scheme it supports — confirm the input is trusted.
    try:
        with request.urlopen(image_url, timeout=5) as response:
            payload = response.read()
    except Exception:
        # Deliberately broad: bad URLs, HTTP errors, timeouts and non-string
        # URL values all mean "this row has no downloadable image".
        return index

    try:
        with open(path, "wb") as f:
            f.write(payload)
    except OSError:
        # Do not keep a truncated file around if the write itself failed.
        try:
            os.remove(path)
        except OSError:
            pass
        return index

    return None
def download_images_fast(df, image_dir, max_workers=16):
    """Download the images of every row concurrently and drop failed rows.

    One ``download_image`` task is submitted per row to a thread pool. Rows
    for which the task returns an index (a failed download) are removed.

    Args:
        df: Frame to process. It is modified in place: failed rows are
            dropped and the index is reset.
        image_dir: Directory the images are written to.
        max_workers: Size of the thread pool.

    Returns:
        The same ``df`` object after the failed rows have been removed.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(download_image, item, image_dir)
            for item in df.iterrows()
        ]
        progress = tqdm(as_completed(pending), total=len(pending), desc="Downloading images")
        outcomes = [task.result() for task in progress]

    # A non-None outcome is the index label of a row that failed.
    to_drop = [label for label in outcomes if label is not None]
    df.drop(index=to_drop, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def validate_image(row, image_dir):
    """Check that the stored image for one row can be opened and verified.

    Args:
        row: An ``(index, record)`` pair as produced by ``DataFrame.iterrows()``.
            ``record`` must provide ``id``.
        image_dir: Directory containing ``<id>.jpg``.

    Returns:
        ``None`` when the image passes verification. Otherwise the row's
        ``index``; in that case the offending file (if any) is deleted.
    """
    index, record = row[0], row[1]
    image_path = os.path.join(image_dir, f"{record['id']}.jpg")
    try:
        with Image.open(image_path) as img:
            img.verify()
    except Exception:
        # Broad on purpose (missing file, unidentified format, corrupt data
        # all count as "invalid"), but no longer a bare ``except`` that would
        # also swallow KeyboardInterrupt / SystemExit.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            # Nothing to clean up: the file was never written or is already gone.
            pass
        return index
    return None
def validate_images_fast(df, image_dir, max_workers=16):
    """Verify the images of every row concurrently and drop invalid rows.

    One ``validate_image`` task is submitted per row to a thread pool. Rows
    for which the task returns an index are removed from ``df``.

    Args:
        df: Frame to process. It is modified in place: invalid rows are
            dropped and the index is reset.
        image_dir: Directory containing the images.
        max_workers: Size of the thread pool.

    Returns:
        A ``(df, corrupted_indices)`` tuple, where ``corrupted_indices`` are
        the index labels (as they were before the reset) of the removed rows.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(validate_image, item, image_dir)
            for item in df.iterrows()
        ]
        progress = tqdm(as_completed(pending), total=len(pending), desc="Validating images")
        outcomes = [task.result() for task in progress]

    # A non-None outcome is the index label of a row whose image is invalid.
    corrupted_indices = [label for label in outcomes if label is not None]
    df.drop(index=corrupted_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df, corrupted_indices
def resize_images(df, image_dir, size=(256, 256)):
    """Convert every row's image to RGB, resize it, and overwrite it on disk.

    Args:
        df: Frame with an ``id`` column. It is modified in place: rows whose
            image could not be processed are dropped and the index is reset.
        image_dir: Directory containing ``<id>.jpg`` for each row.
        size: Target size handed to ``v2.Resize``.

    Returns:
        The same ``df`` object after the failed rows have been removed.
    """
    resize_transform = v2.Resize(size)
    failed_indices = []
    for index, row in tqdm(df.iterrows(), total=len(df), desc="Resizing images"):
        image_path = os.path.join(image_dir, f"{row['id']}.jpg")
        try:
            # Close the source file before overwriting it; convert() yields an
            # in-memory image that no longer needs the open handle.
            with Image.open(image_path) as source:
                rgb_image = source.convert("RGB")
            resized_image = resize_transform(rgb_image)
            resized_image.save(image_path)
        except Exception as e:
            print(f"Failed to resize {image_path}: {e}")
            failed_indices.append(index)
    # Drop once after the loop instead of mutating df while iterating over it.
    df.drop(index=failed_indices, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df
def augment_minority_class(df_balanced, df_remaining_class_1, image_dir, batch_size=4000):
    """Top up class 1 with spare rows until it matches the size of class 0.

    Spare class 1 rows are drawn in batches; each batch is downloaded,
    validated and resized, and the surviving rows are kept. Batches are
    processed until enough rows were gathered or the spare pool is empty.

    Args:
        df_balanced: Frame with a ``binary_label`` column whose classes may
            have become unequal. It is not modified.
        df_remaining_class_1: Pool of additional class 1 rows to draw from.
        image_dir: Directory the images are stored in.
        batch_size: Maximum number of spare rows processed per batch.

    Returns:
        A new, shuffled and re-indexed frame containing ``df_balanced`` plus
        at most ``needed`` extra class 1 rows. When nothing is needed, or
        nothing could be collected, it contains just the shuffled rows of
        ``df_balanced``.
    """
    labels = df_balanced['binary_label']
    # Clamp at zero: class 1 may already be the larger class.
    needed = max(int((labels == 0).sum()) - int((labels == 1).sum()), 0)
    collected = []
    collected_rows = 0
    print(f"Need to add {needed} more class 1 samples...")

    # Compare ROW counts (not the number of batches) against the target.
    while collected_rows < needed and len(df_remaining_class_1) > 0:
        batch = df_remaining_class_1.sample(n=min(batch_size, len(df_remaining_class_1)), random_state=42)
        df_remaining_class_1 = df_remaining_class_1.drop(batch.index)
        print(f"\n🌀 Downloading batch of {len(batch)} images...")
        batch = download_images_fast(batch.copy(), image_dir)
        print("🔎 Validating downloaded images...")
        valid_batch, _ = validate_images_fast(batch.copy(), image_dir)
        print("🎨 Resizing valid images...")
        valid_batch = resize_images(valid_batch, image_dir)
        collected.append(valid_batch)
        collected_rows += len(valid_batch)

    frames = [df_balanced]
    if collected_rows > 0:
        df_extra_class_1 = pd.concat(collected).reset_index(drop=True)
        if collected_rows < needed:
            print(f"Warning: only {collected_rows} of {needed} requested class 1 samples could be collected.")
        # Never ask for more rows than were actually gathered.
        df_extra_class_1 = df_extra_class_1.sample(
            n=min(needed, len(df_extra_class_1)), random_state=42
        ).reset_index(drop=True)
        frames.append(df_extra_class_1)

    # pd.concat rejects an empty list, so df_balanced is always the first frame.
    df_balanced_updated = pd.concat(frames, ignore_index=True)
    df_balanced_updated = df_balanced_updated.sample(frac=1, random_state=42).reset_index(drop=True)
    return df_balanced_updated
def main(args):
    """Run the preprocessing pipeline end to end.

    Steps: create the image directory, load and balance the dataset,
    download the images, validate and resize them, then top up class 1 and
    write the final CSV. Intermediate snapshots are written to ``./df.csv``,
    ``./df_balanced.csv`` and ``./df_balanced_resized.csv`` in the current
    working directory.

    Args:
        args: Parsed command-line namespace. Reads ``image_dir``,
            ``tsv_path``, ``max_samples`` and ``output_csv``.
            NOTE: ``skip_existing`` is not consulted anywhere in this function.
    """
    ensure_directory(args.image_dir)
    df = load_and_prepare_data(args.tsv_path)
    df_balanced, df_remaining_class_1 = balance_data(df, max_samples_per_class=args.max_samples)
    # Snapshot of the balanced selection before any image is fetched.
    df_balanced.to_csv("./df.csv", index=False)
    df_balanced = download_images_fast(df_balanced, args.image_dir)
    print(f"✅ Finished downloading. Remaining rows: {len(df_balanced)}")
    df_balanced.to_csv("./df_balanced.csv", index=False)
    df_balanced, _ = validate_images_fast(df_balanced, args.image_dir)
    df_balanced = resize_images(df_balanced, args.image_dir)
    df_balanced.to_csv("./df_balanced_resized.csv", index=False)
    # Failed downloads / invalid images may have unbalanced the classes again.
    df_balanced_updated = augment_minority_class(df_balanced, df_remaining_class_1, args.image_dir)
    df_balanced_updated.to_csv(args.output_csv, index=False)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    cli = argparse.ArgumentParser(description="Image Dataset Preprocessing Pipeline")
    cli.add_argument(
        "--tsv_path",
        type=str,
        default="./multimodal_train.tsv",
        help="Path to the input TSV file",
    )
    cli.add_argument(
        "--image_dir",
        type=str,
        default="./images",
        help="Directory to save images",
    )
    cli.add_argument(
        "--output_csv",
        type=str,
        default="./final_output.csv",
        help="Path to save final balanced CSV",
    )
    cli.add_argument(
        "--max_samples",
        type=int,
        default=35000,
        help="Maximum number of samples per class",
    )
    cli.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip downloading if image already exists",
    )
    main(cli.parse_args())