# EarthLoc2 / convert_to_web_dataset.py
# Pawel Piwowarski
# init commit
# 0a82b18
from torch.utils.data import DataLoader
import webdataset as wds
import torch
from PIL import Image
import io
import json
import os
from pathlib import Path
import sys
from tqdm import tqdm
import pandas as pd
sys.path.append(str(Path("image-matching-models")))
sys.path.append(str(Path("image-matching-models/matching/third_party")))
import util_matching
def create_webdataset_shards(image_dir, output_dir, shard_size=10000):
    """Convert a directory tree of JPEG images into WebDataset tar shards.

    Images are discovered with the glob pattern ``*/*/*.jpg`` relative to
    ``image_dir`` (i.e. exactly two directory levels deep) and written, in
    sorted path order, to ``shard-000000.tar``, ``shard-000001.tar``, ...
    inside ``output_dir``.

    Each sample holds two entries:
      * ``__key__``: the image path relative to ``image_dir`` without its
        suffix, with ``/`` replaced by ``_`` and ``.`` replaced by ``,``.
      * ``jpg``: the image re-encoded as JPEG through PIL.

    Args:
        image_dir: Root directory containing the images (str or Path).
        output_dir: Directory that receives the shard files; it is created
            (including missing parents) if it does not exist.
        shard_size: Maximum number of images per shard. Must be positive.

    Raises:
        ValueError: If ``shard_size`` is not a positive number.
    """
    if shard_size <= 0:
        raise ValueError(
            f"shard_size must be a positive integer, got {shard_size!r}"
        )

    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sort so that shard contents are reproducible between runs; glob order
    # is otherwise left to the filesystem.
    image_paths = sorted(image_dir.glob("*/*/*.jpg"))  # Adjust extension
    if not image_paths:
        # Nothing to convert: do not emit an empty shard.
        return

    # Ceiling division: avoids a trailing empty shard when the number of
    # images is an exact multiple of shard_size.
    num_shards = (len(image_paths) + shard_size - 1) // shard_size

    for shard_idx in tqdm(range(num_shards), desc='Creating Shards'):
        start = shard_idx * shard_size
        shard_images = image_paths[start:start + shard_size]
        shard_path = output_dir / f"shard-{shard_idx:06d}.tar"

        with wds.TarWriter(str(shard_path)) as sink:
            for img_path in tqdm(shard_images):
                relative = img_path.relative_to(image_dir).with_suffix('')
                # Flatten the relative path into a single key. Dots are
                # replaced as well -- presumably because WebDataset splits
                # key and extension at a '.'; TODO confirm.
                key = relative.as_posix().replace('/', '_').replace('.', ',')

                # NOTE: this decodes and re-compresses the image with PIL's
                # JPEG encoder rather than copying the original file bytes.
                with Image.open(img_path) as img:
                    img_bytes = io.BytesIO()
                    img.save(img_bytes, format="JPEG")

                sink.write({
                    "__key__": key,
                    "jpg": img_bytes.getvalue(),
                })
if __name__ == "__main__":
    # Script entry point: shard the images under data/database into
    # data/webdataset_shards, 15000 images per shard (tune as needed).
    create_webdataset_shards(
        "data/database",
        "data/webdataset_shards",
        shard_size=15000,
    )