|
|
|
import os |
|
|
|
from collections import Counter |
|
from datasets import load_dataset |
|
from huggingface_hub import HfApi, list_datasets |
|
|
|
|
|
# Shared Hub client. The token is read from the HF_TOKEN environment variable
# and is None when that variable is not set.
api = HfApi(token=os.getenv("HF_TOKEN"))
|
def restart_space():
    """Restart the ``OpenGenAI/parti-prompts-leaderboard`` Space.

    The request is issued through the module-level ``api`` client. Nothing is
    returned.
    """
    space_id = "OpenGenAI/parti-prompts-leaderboard"
    api.restart_space(repo_id=space_id)
|
|
|
# Starts out empty; nothing in the visible code appends to it.
parti_prompt_results = []

# Namespace string (also assigned again further down, where it is used to
# build dataset repo paths).
ORG = "diffusers-parti-prompts"

# Short model keys. The values are None placeholders at this point; only the
# keys (and their order) are used by the derived names below.
SUBMISSIONS = dict.fromkeys(("kand2", "sdxl", "wuerst", "karlo"))

# Model key -> URL of the corresponding model page.
LINKS = dict(
    kand2="https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    sdxl="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    wuerst="https://huggingface.co/warp-ai/wuerstchen",
    karlo="https://huggingface.co/kakaobrain/karlo-v1-alpha",
)

# All model keys joined with "-", in insertion order.
MODEL_KEYS = "-".join(SUBMISSIONS)

# Author name passed to list_datasets() when collecting result datasets.
SUBMISSION_ORG = f"result-{MODEL_KEYS}"

# The model keys as a plain list.
submission_names = [*SUBMISSIONS]
|
|
|
# Namespace under which the per-model datasets below are loaded.
ORG = "diffusers-parti-prompts"

# Model key -> "train" split of that model's dataset.
# Repo ids are built with an explicit "/" rather than os.path.join: a Hub repo
# id is not a filesystem path, and os.path.join would insert "\" on Windows.
SUBMISSIONS = {
    "kand2": load_dataset(f"{ORG}/kandinsky-2-2")["train"],
    "sdxl": load_dataset(f"{ORG}/sdxl-1.0-refiner")["train"],
    "wuerst": load_dataset(f"{ORG}/wuerstchen")["train"],
    "karlo": load_dataset(f"{ORG}/karlo-v1")["train"],
}

# "train" split of the prompt dataset, plus two of its columns kept as
# module-level values.
ds = load_dataset("nateraw/parti-prompts")["train"]

parti_prompt_categories = ds["Category"]
parti_prompt_challenge = ds["Challenge"]

# Namespace that main() pushes the resulting datasets to.
UPLOAD_ORG = "almost-agi-diff"
|
|
|
def load_non_solved(max_datasets=5):
    """Collect the ids whose ``result`` field is empty across result datasets.

    Lists the datasets authored by ``SUBMISSION_ORG``, loads the ``"train"``
    split of each one, and counts every ``id`` whose paired ``result`` value
    is the empty string.

    Args:
        max_datasets: Maximum number of listed datasets to inspect. The
            default of 5 matches the previous hard-coded limit; pass ``None``
            to inspect every listed dataset.

    Returns:
        A ``(ids, votes)`` tuple of two equally long lists: the distinct ids
        in first-seen order, and for each id the number of times it appeared
        with an empty ``result``.
    """
    dataset_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]
    if max_datasets is not None:
        dataset_ids = dataset_ids[:max_datasets]

    vote_counts = Counter()

    for dataset_id in dataset_ids:
        try:
            results = load_dataset(dataset_id)["train"]
        except Exception as err:
            # Best effort: one unreadable dataset must not abort the run.
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate, and report the skip instead of hiding it.
            print(f"Skipping dataset {dataset_id}: {err!r}")
            continue

        vote_counts.update(
            image_id
            for result, image_id in zip(results["result"], results["id"])
            if result == ""
        )

    # Counter preserves first-insertion order, so keys and values line up.
    return list(vote_counts.keys()), list(vote_counts.values())
|
|
|
def main():
    """Push the non-solved rows of every submission dataset to the Hub.

    For each entry in ``SUBMISSIONS`` the rows selected by the ids from
    ``load_non_solved()`` are annotated with an ``"upvotes"`` column, sorted
    by that column in descending order, and pushed to ``UPLOAD_ORG/<name>``.
    """
    non_solved_ids, upvotes = load_non_solved()

    for name, dataset in SUBMISSIONS.items():
        # NOTE(review): select() takes row indices, so this assumes the ids
        # returned by load_non_solved() are integer row positions; confirm.
        ds_to_push = dataset.select(non_solved_ids)

        # add_column returns a new Dataset rather than modifying in place;
        # the result must be kept or the sort below has no "upvotes" column.
        ds_to_push = ds_to_push.add_column("upvotes", upvotes)
        sorted_ds = ds_to_push.sort("upvotes", reverse=True)

        sorted_ds.push_to_hub(f"{UPLOAD_ORG}/{name}")
|
|