import io
import json
import os
import sys
from pathlib import Path

from torch.utils.data import DataLoader
import webdataset as wds
import torch
from PIL import Image
from tqdm import tqdm
import pandas as pd

# The project-local module below is resolved through these extra search paths,
# so the appends must run before the import.
sys.path.append(str(Path("image-matching-models")))
sys.path.append(str(Path("image-matching-models/matching/third_party")))

import util_matching


def create_webdataset_shards(image_dir, output_dir, shard_size=10000):
    """Convert a directory of images to WebDataset tar shards.

    Images are collected with the glob pattern ``*/*/*.jpg`` relative to
    ``image_dir`` (i.e. exactly two directory levels deep), sorted so that
    shard contents are reproducible, and written in consecutive groups of at
    most ``shard_size`` to ``output_dir/shard-NNNNNN.tar``.

    Each sample has two fields: ``__key__`` (the image path relative to
    ``image_dir`` without its suffix, with ``/`` replaced by ``_`` and ``.``
    replaced by ``,``) and ``jpg`` (the image re-encoded as JPEG bytes).

    Args:
        image_dir: Root directory that contains the images.
        output_dir: Directory that receives the shards; created (including
            missing parents) if it does not exist.
        shard_size: Maximum number of images per shard. Must be positive.

    Raises:
        ValueError: If ``shard_size`` is not a positive number.

    Note:
        If no image matches the pattern, no shard file is written.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")

    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # glob() order depends on the filesystem; sort for reproducible shards.
    image_paths = sorted(image_dir.glob("*/*/*.jpg"))  # Adjust extension

    # Ceiling division: a plain `// shard_size + 1` would emit an empty
    # trailing shard whenever the image count is a multiple of shard_size.
    num_shards = (len(image_paths) + shard_size - 1) // shard_size

    for shard_idx in tqdm(range(num_shards), desc='Creating Shards'):
        start = shard_idx * shard_size
        shard_images = image_paths[start:start + shard_size]
        shard_path = output_dir / f"shard-{shard_idx:06d}.tar"

        with wds.TarWriter(str(shard_path)) as sink:
            for img_path in tqdm(shard_images):
                relative = img_path.relative_to(image_dir).with_suffix('')
                # Flatten the relative path into a single token; dots are
                # replaced too, presumably so the key cannot be mistaken for
                # "name.extension" by the reader -- TODO confirm.
                key = relative.as_posix().replace('/', '_').replace('.', ',')

                # NOTE(review): the image is decoded and re-encoded with
                # Pillow's default JPEG settings rather than copied verbatim;
                # confirm the re-encode is intended.
                with Image.open(img_path) as img:
                    img_bytes = io.BytesIO()
                    img.save(img_bytes, format="JPEG")

                sink.write({
                    "__key__": key,
                    "jpg": img_bytes.getvalue(),
                })


if __name__ == "__main__":  # Usage:
    create_webdataset_shards(
        image_dir="data/database",
        output_dir="data/webdataset_shards",
        shard_size=15000,  # Adjust based on your needs
    )