tools / update_almost_agi.py
patrickvonplaten's picture
add
56fb00e
raw
history blame
2.49 kB
#!/usr/bin/env python3
import os
from collections import Counter
from datasets import load_dataset
from huggingface_hub import HfApi, list_datasets
# Shared Hub client. The token is read from the HF_TOKEN environment variable;
# when the variable is unset, ``None`` is passed as the token.
api = HfApi(token=os.environ.get("HF_TOKEN"))
def restart_space():
    """Request a restart of the ``OpenGenAI/parti-prompts-leaderboard`` Space.

    Uses the module-level ``api`` client; returns nothing.
    """
    leaderboard_space = "OpenGenAI/parti-prompts-leaderboard"
    api.restart_space(repo_id=leaderboard_space)
# Not read anywhere in the visible code; kept so the module still exposes it.
parti_prompt_results = []

# Hub namespace the per-model submission datasets are loaded from.
ORG = "diffusers-parti-prompts"

# Short model key -> dataset name under ORG. This is the single source of
# truth for both the key order (used in SUBMISSION_ORG) and the datasets that
# get loaded into SUBMISSIONS below.
_SUBMISSION_REPOS = {
    "kand2": "kandinsky-2-2",
    "sdxl": "sdxl-1.0-refiner",
    "wuerst": "wuerstchen",
    "karlo": "karlo-v1",
}

# Model page URL for each short model key.
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}

# "kand2-sdxl-wuerst-karlo": the keys joined in declaration order.
MODEL_KEYS = "-".join(_SUBMISSION_REPOS)
# Hub author whose datasets are scanned by load_non_solved().
SUBMISSION_ORG = f"result-{MODEL_KEYS}"
submission_names = list(_SUBMISSION_REPOS)

# Short model key -> "train" split of that model's dataset.
# Repo ids are joined with "/" explicitly: os.path.join would insert the
# platform separator ("\\" on Windows), which is not a valid Hub repo id.
SUBMISSIONS = {
    key: load_dataset(f"{ORG}/{repo_name}")["train"]
    for key, repo_name in _SUBMISSION_REPOS.items()
}

# Prompt dataset and two of its columns.
ds = load_dataset("nateraw/parti-prompts")["train"]
parti_prompt_categories = ds["Category"]
parti_prompt_challenge = ds["Challenge"]

# Hub namespace that main() pushes the resulting datasets to.
UPLOAD_ORG = "almost-agi-diff"
def load_non_solved(max_datasets=5):
    """Collect the ids of rows whose ``result`` is empty, with their counts.

    Lists the datasets published by ``SUBMISSION_ORG``, loads the "train"
    split of the first ``max_datasets`` of them, and counts how often each
    ``id`` appears on a row whose ``result`` column is the empty string.
    NOTE(review): an empty ``result`` presumably means no model's image was
    accepted for that prompt -- confirm against the code that writes these
    datasets.

    Args:
        max_datasets: Upper bound on how many datasets are inspected
            (default 5, the limit that was previously hard-coded). Pass
            ``None`` to inspect every dataset.

    Returns:
        A pair of equally long lists ``(ids, votes)``: the distinct ids in
        first-seen order, and for each the number of times it was seen.
    """
    dataset_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]

    vote_counts = Counter()
    for dataset_id in dataset_ids[:max_datasets]:
        try:
            split = load_dataset(dataset_id)["train"]
        except Exception:
            # Best effort: a dataset that cannot be loaded (or has no "train"
            # split) is skipped. Unlike a bare ``except:``, this still lets
            # KeyboardInterrupt and SystemExit propagate.
            continue

        for result, image_id in zip(split["result"], split["id"]):
            if result == "":
                vote_counts[image_id] += 1

    return list(vote_counts), list(vote_counts.values())
def main():
    """Push, per model, the rows for unsolved prompts ranked by vote count.

    For every entry in ``SUBMISSIONS`` the rows selected by the ids from
    ``load_non_solved()`` are kept, an ``upvotes`` column holding the matching
    counts is attached, the rows are sorted by that column in descending
    order, and the result is pushed to ``{UPLOAD_ORG}/{name}``.

    NOTE(review): ``Dataset.select`` takes row indices, so this assumes the
    ``id`` values returned by ``load_non_solved()`` are row positions in each
    submission dataset -- confirm.
    """
    non_solved_ids, upvotes = load_non_solved()

    for name, submission in SUBMISSIONS.items():
        unsolved_rows = submission.select(non_solved_ids)
        # add_column returns a new Dataset instead of modifying in place; the
        # result has to be kept or the "upvotes" column never exists and the
        # sort below fails.
        unsolved_rows = unsolved_rows.add_column("upvotes", upvotes)
        ranked = unsolved_rows.sort("upvotes", reverse=True)
        ranked.push_to_hub(f"{UPLOAD_ORG}/{name}")