sae_flux / retrieval.py
surokpro2's picture
Upload 47 files
215c4b7 verified
from huggingface_hub import snapshot_download
from clip_retrieval.clip_back import load_clip_indices, KnnService, ClipOptions
from collections import defaultdict
import os
import glob
import shutil
import random
class FeatureRetriever:
    """Text-to-feature retrieval over per-block CLIP kNN indices.

    One ``KnnService`` is built per entry in ``self.names`` from the index
    description files listed in ``self.paths``. If the local ``./clip``
    folder is missing (or a download is forced), the index files are first
    fetched from the ``wendlerc/sdxl-unbox-clip-indices`` dataset repo.
    """

    def __init__(self,
                 num_images=50,
                 imgs_per_dir=15,
                 force_download=False):
        """Download the indices if needed and load one kNN service per block.

        Args:
            num_images: Number of results requested from the kNN service on
                every query.
            imgs_per_dir: Normaliser used when scoring a feature in
                ``query_text``.
            force_download: Re-download the indices even when ``./clip``
                already exists.

        Raises:
            ValueError: If the downloaded snapshot does not contain a
                ``down_10_5120`` folder.
        """
        if force_download or not os.path.exists("./clip"):
            print("Downloading clip resources")
            rand_num = random.randint(0, 100000)
            tmp_dir = f"./tmp_{rand_num}"
            try:
                snapshot_download(
                    repo_type="dataset",
                    repo_id="wendlerc/sdxl-unbox-clip-indices",
                    cache_dir=tmp_dir,
                )
                index_root = self._find_index_root(tmp_dir)
                shutil.copytree(index_root, "./clip", dirs_exist_ok=True)
            finally:
                # Always drop the scratch download, including on failure;
                # ignore_errors covers the case where it was never created.
                shutil.rmtree(tmp_dir, ignore_errors=True)

        # Initialize CLIP service
        clip_options = ClipOptions(
            indice_folder="currently unused by knn.query()",
            clip_model="ViT-B/32",  # "open_clip:ViT-H-14",
            enable_hdf5=False,
            enable_faiss_memory_mapping=True,
            columns_to_return=["image_path", "similarity"],
            reorder_metadata_by_ivf_index=False,
            enable_mclip_option=False,
            use_jit=False,
            use_arrow=False,
            provide_safety_model=False,
            provide_violence_detector=False,
            provide_aesthetic_embeddings=False,
        )
        # names[i] is the public block name served by the index at paths[i].
        self.names = ["down.2.1", "mid.0", "up.0.0", "up.0.1"]
        self.paths = ["./clip/down_10_5120/indices_paths.json",
                      "./clip/mid_10_5120/indices_paths.json",
                      "./clip/up0_10_5120/indices_paths.json",
                      "./clip/up_10_5120/indices_paths.json"]
        self.knn_service = {}
        for name, path in zip(self.names, self.paths):
            resources = load_clip_indices(path, clip_options)
            self.knn_service[name] = KnnService(clip_resources=resources)
        self.num_images = num_images
        self.imgs_per_dir = imgs_per_dir

    @staticmethod
    def _find_index_root(download_dir):
        """Return the folder inside *download_dir* that holds the index dirs.

        The folder is located by searching recursively for a
        ``down_10_5120`` entry and taking its parent.

        Raises:
            ValueError: If no ``down_10_5120`` entry exists under
                *download_dir*.
        """
        matches = glob.glob(f"{download_dir}/**/down_10_5120", recursive=True)
        if not matches:
            raise ValueError("Could not find clip indices in the downloaded repo.")
        return os.path.dirname(matches[0])

    def query_text(self, query, block):
        """Score features of *block* against a text query.

        The kNN service for *block* is queried with *query*; every hit is
        attributed to the feature named by the parent directory of its
        ``image_path``. A feature's score is the mean similarity of its hits
        multiplied by ``len(hits) / self.imgs_per_dir``.

        Args:
            query: Text passed to the kNN service as ``text_input``.
            block: One of ``self.names``.

        Returns:
            dict mapping feature id (str) to score, ordered from highest to
            lowest score.

        Raises:
            ValueError: If *block* is not one of ``self.names``.
        """
        if block not in self.names:
            raise ValueError(f"Block must be one of {self.names}")
        results = self.knn_service[block].query(
            text_input=query,
            num_images=self.num_images,
            num_result_ids=self.num_images,
            deduplicate=True,
        )
        feat_sims = defaultdict(list)
        for result in results:
            # Second-to-last path component, i.e. the image's directory name.
            feature_id = result["image_path"].split("/")[-2]
            feat_sims[feature_id].append(result["similarity"])
        feat_scores = {
            fid: (sum(sims) / len(sims)) * (len(sims) / self.imgs_per_dir)
            for fid, sims in feat_sims.items()
        }
        return dict(sorted(feat_scores.items(), key=lambda item: -item[1]))