|
from torch.utils.data import DataLoader |
|
import webdataset as wds |
|
import torch |
|
from PIL import Image |
|
import io |
|
import json |
|
import os |
|
from pathlib import Path |
|
import sys |
|
from tqdm import tqdm |
|
import pandas as pd |
|
# Extend the import search path with the image-matching-models checkout and
# its bundled third-party directory before importing project-local modules.
# NOTE(review): presumably `util_matching` (or something it imports) relies on
# these entries - confirm before reordering or removing them.
sys.path.extend(
    str(Path(entry))
    for entry in (
        "image-matching-models",
        "image-matching-models/matching/third_party",
    )
)

import util_matching
|
def create_webdataset_shards(image_dir, output_dir, shard_size=10000):
    """Convert a directory of images into WebDataset tar shards.

    Images are discovered with the glob pattern ``*/*/*.jpg`` relative to
    ``image_dir`` (i.e. exactly two directory levels below it), re-encoded as
    JPEG, and written in groups of at most ``shard_size`` samples to files
    named ``shard-NNNNNN.tar`` inside ``output_dir``.

    Args:
        image_dir: Root directory that contains the images.
        output_dir: Directory that receives the shard files. It is created,
            including missing parents, if it does not exist.
        shard_size: Maximum number of images per shard. Must be positive.

    Raises:
        ValueError: If ``shard_size`` is not positive.

    Note:
        No shard is written when no image matches the pattern.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")

    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # glob() yields entries in arbitrary filesystem order; sorting makes the
    # image-to-shard assignment reproducible across runs and machines.
    image_paths = sorted(image_dir.glob("*/*/*.jpg"))

    # Ceiling division. The previous `len // shard_size + 1` produced a
    # trailing empty shard whenever the image count was an exact multiple of
    # shard_size (including zero images).
    num_shards = -(-len(image_paths) // shard_size)

    for shard_idx in tqdm(range(num_shards), desc='Creating Shards'):
        start = shard_idx * shard_size
        shard_path = output_dir / f"shard-{shard_idx:06d}.tar"

        with wds.TarWriter(str(shard_path)) as sink:
            # Slicing clamps at the end of the list, so the last shard simply
            # holds the remaining images.
            for img_path in tqdm(image_paths[start:start + shard_size], leave=False):
                # Build a flat sample key from the path relative to the root:
                # directory separators become '_' and any remaining dots
                # become ',' so that the only '.' in the tar member name is
                # the one separating the key from the "jpg" extension.
                relative = img_path.relative_to(image_dir).with_suffix('')
                key = relative.as_posix().replace('/', '_').replace('.', ',')

                # NOTE(review): decoding and re-saving through PIL is a lossy
                # re-encode with PIL's default JPEG settings. If the sources
                # are already valid JPEGs, `img_path.read_bytes()` would keep
                # them bit-exact. Left as-is to preserve shard contents.
                with Image.open(img_path) as img:
                    img_bytes = io.BytesIO()
                    img.save(img_bytes, format="JPEG")

                sink.write({
                    "__key__": key,
                    "jpg": img_bytes.getvalue(),
                })
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pack the images under data/database into
    # shards of up to 15000 samples each in data/webdataset_shards.
    create_webdataset_shards(
        "data/database",
        "data/webdataset_shards",
        shard_size=15000,
    )