# EarthLoc2 / create_sharded_csv.py
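"""Build a CSV mapping every FAISS index entry to its WebDataset shard.

For each image path in the original index CSV, this script derives the
WebDataset __key__, scans the .tar shards to find which one contains that key,
and writes a new CSV with added `key` and `shard_path` columns.
"""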

from pathlib import Path

import pandas as pd
import webdataset as wds


def convert_path_to_key(img_path: Path) -> str:
    """Derive the WebDataset __key__ for an image under data/database."""
    # 1. Get the path relative to the database root
    relative = img_path.relative_to("data/database")
    # 2. Remove the file suffix (.jpg)
    no_suffix = relative.with_suffix('')
    # 3. Convert to a POSIX-style string and flatten directories into the name
    flat = no_suffix.as_posix().replace('/', '_')
    # 4. Replace '.' with ',' so the key itself contains no dots
    #    (WebDataset splits key and extension at the first dot)
    key = flat.replace('.', ',')
    return key
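
# Example with a hypothetical path:
#   convert_path_to_key(Path("data/database/region_a/img_001.2.jpg"))
#   -> "region_a_img_001,2"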


def update_mapping_csv(original_csv, webdataset_dir, new_csv_path):
    """Add `key` and `shard_path` columns to the index CSV and save it."""
    df = pd.read_csv(original_csv)
    webdataset_dir = Path(webdataset_dir)
    shards = list(webdataset_dir.glob("*.tar"))

    # Build the mapping key -> shard_path by iterating every sample in every shard
    key_to_shard = {}
    for shard in shards:
        dataset = wds.WebDataset(str(shard), empty_check=False)
        for sample in dataset:
            key = sample["__key__"]
            key_to_shard[key] = str(shard)

    # Derive the key for every row and look up the shard that contains it
    df["key"] = df["local_path"].apply(lambda p: convert_path_to_key(Path(p)))
    df["shard_path"] = df["key"].map(key_to_shard)

    # Fail loudly if any image could not be located in a shard
    if df["shard_path"].isna().any():
        missing_keys = df[df["shard_path"].isna()]["key"].tolist()
        raise ValueError(
            f"Missing shard paths for the following keys: {missing_keys[:10]}... "
            "(and possibly more)"
        )

    df.to_csv(new_csv_path, index=False)
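

# Illustrative sketch (not part of the original script): shows how the CSV
# written above could be used to read one sample back out of its shard.
# The `key` and `shard_path` columns come from update_mapping_csv; the function
# name and its arguments are hypothetical.
def fetch_sample_by_key(mapping_csv, wanted_key):
    """Return the raw WebDataset sample whose __key__ equals `wanted_key`."""
    df = pd.read_csv(mapping_csv)
    matches = df.loc[df["key"] == wanted_key, "shard_path"]
    if matches.empty:
        raise KeyError(f"Key {wanted_key!r} not present in {mapping_csv}")
    shard_path = matches.iloc[0]
    # Stream only the shard that contains the key and stop at the first match
    for sample in wds.WebDataset(shard_path, empty_check=False):
        if sample["__key__"] == wanted_key:
            return sample
    raise KeyError(f"Key {wanted_key!r} not found in {shard_path}")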


if __name__ == "__main__":
    update_mapping_csv(
        original_csv="faiss_index/faiss_index_to_local_path.csv",
        webdataset_dir="data/webdataset_shards",
        new_csv_path="faiss_index/faiss_index_webdataset.csv",
    )