|
import pandas as pd |
|
from pathlib import Path |
|
import webdataset as wds |
|
|
|
|
|
|
|
def convert_path_to_key(img_path: Path) -> str: |
|
|
|
|
|
|
|
relative = img_path.relative_to("data/database") |
|
|
|
|
|
no_suffix = relative.with_suffix('') |
|
|
|
|
|
flat = no_suffix.as_posix().replace('/', '_') |
|
|
|
|
|
key = flat.replace('.', ',') |
|
|
|
return key |
|
|
|
|
|
def update_mapping_csv(original_csv, webdataset_dir, new_csv_path):
    """Augment a path-mapping CSV with WebDataset key and shard columns.

    Reads *original_csv* (must contain a ``local_path`` column), scans
    every ``*.tar`` shard under *webdataset_dir* to learn which shard
    holds which sample key, then writes a copy of the CSV to
    *new_csv_path* with two extra columns: ``key`` (derived via
    ``convert_path_to_key``) and ``shard_path`` (the shard containing
    that key).

    Raises:
        ValueError: If any key from the CSV is absent from all shards.
    """
    mapping = pd.read_csv(original_csv)
    shard_dir = Path(webdataset_dir)

    # Index every sample key by the shard file it lives in.
    # empty_check=False keeps WebDataset from raising on empty shards.
    key_to_shard = {
        sample["__key__"]: str(tar_path)
        for tar_path in shard_dir.glob("*.tar")
        for sample in wds.WebDataset(str(tar_path), empty_check=False)
    }

    mapping["key"] = mapping["local_path"].apply(lambda p: convert_path_to_key(Path(p)))
    mapping["shard_path"] = mapping["key"].map(key_to_shard)

    # Fail loudly if any CSV entry has no corresponding shard sample.
    unresolved = mapping["shard_path"].isna()
    if unresolved.any():
        missing_keys = mapping.loc[unresolved, "key"].tolist()
        raise ValueError(f"Missing shard paths for the following keys: {missing_keys[:10]}... (and possibly more)")

    mapping.to_csv(new_csv_path, index=False)
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
update_mapping_csv( |
|
original_csv="faiss_index/faiss_index_to_local_path.csv", |
|
webdataset_dir="data/webdataset_shards", |
|
new_csv_path="faiss_index/faiss_index_webdataset.csv" |
|
) |
|
|