|
import glob
import os
import random
import shutil
import tempfile
from collections import defaultdict

from clip_retrieval.clip_back import load_clip_indices, KnnService, ClipOptions
from huggingface_hub import snapshot_download
|
|
|
class FeatureRetriever:
    """Rank feature ids for a text query using per-block CLIP kNN indices.

    One ``KnnService`` is loaded for every entry in ``self.names`` from the
    matching ``indices_paths.json`` file under ``./clip``. If ``./clip`` does
    not exist (or ``force_download`` is set), the index files are first
    downloaded from the Hugging Face Hub.

    Attributes:
        names: Block names accepted by :meth:`query_text`.
        paths: ``indices_paths.json`` file for each entry of ``names``.
        knn_service: Mapping from block name to its loaded ``KnnService``.
        num_images: Number of nearest neighbours requested per query.
        imgs_per_dir: Normalisation constant used when scoring a feature.
    """

    _INDEX_REPO_ID = "wendlerc/sdxl-unbox-clip-indices"
    _INDEX_ROOT = "./clip"
    # Directory whose presence identifies the root of the index tree inside
    # the downloaded snapshot.
    _MARKER_DIR = "down_10_5120"

    def __init__(self,
                 num_images: int = 50,
                 imgs_per_dir: int = 15,
                 force_download: bool = False):
        """Load the kNN indices, downloading them first when needed.

        Args:
            num_images: Number of results requested from the index per query.
            imgs_per_dir: Divisor applied to the per-feature hit count in
                :meth:`query_text`.
            force_download: Download the indices even if ``./clip`` exists.

        Raises:
            ValueError: If a download was performed but the snapshot does not
                contain a ``down_10_5120`` directory.
        """
        self.num_images = num_images
        self.imgs_per_dir = imgs_per_dir

        if force_download or not os.path.exists(self._INDEX_ROOT):
            self._download_indices()

        clip_options = ClipOptions(
            indice_folder="currently unused by knn.query()",
            clip_model="ViT-B/32",
            enable_hdf5=False,
            enable_faiss_memory_mapping=True,
            columns_to_return=["image_path", "similarity"],
            reorder_metadata_by_ivf_index=False,
            enable_mclip_option=False,
            use_jit=False,
            use_arrow=False,
            provide_safety_model=False,
            provide_violence_detector=False,
            provide_aesthetic_embeddings=False,
        )

        self.names = ["down.2.1", "mid.0", "up.0.0", "up.0.1"]
        self.paths = [
            "./clip/down_10_5120/indices_paths.json",
            "./clip/mid_10_5120/indices_paths.json",
            "./clip/up0_10_5120/indices_paths.json",
            "./clip/up_10_5120/indices_paths.json",
        ]

        self.knn_service = {}
        for name, path in zip(self.names, self.paths):
            resources = load_clip_indices(path, clip_options)
            self.knn_service[name] = KnnService(clip_resources=resources)

    @classmethod
    def _download_indices(cls):
        """Download the index dataset and copy its index tree to ``./clip``."""
        print("Downloading clip resources")
        # mkdtemp creates a uniquely named directory, so it cannot collide
        # with a leftover from an earlier run.
        tmp_dir = tempfile.mkdtemp(prefix="tmp_", dir=".")
        try:
            snapshot_download(
                repo_type="dataset",
                repo_id=cls._INDEX_REPO_ID,
                cache_dir=tmp_dir,
            )
            clip_dirs = glob.glob(f"{tmp_dir}/**/{cls._MARKER_DIR}", recursive=True)
            if not clip_dirs:
                raise ValueError("Could not find clip indices in the downloaded repo.")
            # The parent of the marker directory holds all per-block indices.
            shutil.copytree(
                os.path.dirname(clip_dirs[0]),
                cls._INDEX_ROOT,
                dirs_exist_ok=True,
            )
        finally:
            # Best-effort cleanup on success and on failure; a cleanup error
            # must not mask the exception that is already propagating.
            shutil.rmtree(tmp_dir, ignore_errors=True)

    def query_text(self, query, block):
        """Score features of ``block`` against a text query.

        The index is queried for ``self.num_images`` results. Each result is
        assigned to a feature id taken from the second-to-last ``/``-separated
        component of its ``image_path``. A feature's score is the mean
        similarity of its hits multiplied by ``hits / self.imgs_per_dir``.

        Args:
            query: Text passed to the index as ``text_input``.
            block: One of ``self.names``.

        Returns:
            A dict mapping feature id to score, ordered from highest to
            lowest score.

        Raises:
            ValueError: If ``block`` is not one of ``self.names``.
        """
        if block not in self.names:
            raise ValueError(f"Block must be one of {self.names}")

        results = self.knn_service[block].query(
            text_input=query,
            num_images=self.num_images,
            num_result_ids=self.num_images,
            deduplicate=True,
        )

        feat_sims = defaultdict(list)
        for result in results:
            feature_id = result["image_path"].split("/")[-2]
            feat_sims[feature_id].append(result["similarity"])

        feat_scores = {
            fid: (sum(sims) / len(sims)) * (len(sims) / self.imgs_per_dir)
            for fid, sims in feat_sims.items()
        }

        return dict(sorted(feat_scores.items(), key=lambda item: -item[1]))
|
|