from huggingface_hub import snapshot_download

import os
import json
import time
from collections import defaultdict

from src.submission.check_validity import is_model_on_hub, check_model_card, get_model_tags
from src.leaderboard.read_evals import EvalResult
from src.envs import (
    DYNAMIC_INFO_REPO,
    DYNAMIC_INFO_PATH,
    DYNAMIC_INFO_FILE_PATH,
    API,
    H4_TOKEN,
    ORIGINAL_HF_LEADERBOARD_RESULTS_REPO,
    ORIGINAL_HF_LEADERBOARD_EVAL_RESULTS_PATH,
    GET_ORIGINAL_HF_LEADERBOARD_EVAL_RESULTS
)
from src.display.utils import ORIGINAL_TASKS

def update_models(file_path, models, original_leaderboard_files=None):
    """
    Refresh the per-model metadata stored in the dynamic-info JSON file.

    Args:
        file_path: path to a JSON file mapping model_id -> info dict; it is
            read, updated in place, and written back.
        models: dict mapping model_id -> huggingface_hub ModelInfo for every
            model currently on the hub.
        original_leaderboard_files: optional dict mapping model_id -> list of
            result-file paths from the original HF leaderboard, used to fill
            'original_llm_scores' (per-precision average over ORIGINAL_TASKS).
    """
    # Read everything up front so the file handle is not held open while we
    # make (slow) network calls below.
    with open(file_path, "r") as f:
        model_infos = json.load(f)

    for model_id, data in model_infos.items():
        if model_id not in models:
            # Model no longer listed on the hub: zero out dynamic metadata.
            data['still_on_hub'] = False
            data['likes'] = 0
            data['downloads'] = 0
            data['created_at'] = ""
            data['original_llm_scores'] = {}
            continue

        model_cfg = models[model_id]
        data['likes'] = model_cfg.likes
        data['downloads'] = model_cfg.downloads
        data['created_at'] = str(model_cfg.created_at)
        #data['params'] = get_model_size(model_cfg, data['precision'])
        # getattr guards against card data objects with no license field.
        data['license'] = getattr(model_cfg.card_data, "license", "") if model_cfg.card_data is not None else ""
        data['original_llm_scores'] = {}

        # Is the model still on the hub? For adapters we look at the parent
        # (base) model instead of the adapter repo itself.
        model_name = model_id
        base_model = getattr(model_cfg.card_data, "base_model", None) if model_cfg.card_data is not None else None
        if isinstance(base_model, str):
            model_name = base_model
        still_on_hub, _, _ = is_model_on_hub(
            model_name=model_name, revision=data.get("revision"), trust_remote_code=True, test_tokenizer=False, token=H4_TOKEN
        )
        data['still_on_hub'] = still_on_hub

        tags = []
        if still_on_hub:
            # Only the card content is needed here; the validity status is
            # checked elsewhere at submission time.
            _status, _, _, model_card = check_model_card(model_id)
            tags = get_model_tags(model_card, model_id)

        if original_leaderboard_files is not None and model_id in original_leaderboard_files:
            data['original_llm_scores'] = _aggregate_original_scores(original_leaderboard_files[model_id])

        data["tags"] = tags

    with open(file_path, 'w') as f:
        json.dump(model_infos, f, indent=2)


def _aggregate_original_scores(result_files):
    """Merge original-leaderboard result files into {precision_name: avg_score}.

    Files belonging to the same eval (same eval_name) are merged, keeping only
    non-None task scores from later files. Evals that do not cover every task
    in ORIGINAL_TASKS are skipped; the average always divides by the full task
    count, so missing (None) scores effectively count as zero.
    """
    eval_results = {}
    for filepath in result_files:
        eval_result = EvalResult.init_from_json_file(filepath, is_original=True)
        existing = eval_results.get(eval_result.eval_name)
        if existing is not None:
            # Store results of the same eval together, ignoring empty scores.
            existing.results.update({k: v for k, v in eval_result.results.items() if v is not None})
        else:
            eval_results[eval_result.eval_name] = eval_result

    scores = {}
    for eval_result in eval_results.values():
        if len(eval_result.results) < len(ORIGINAL_TASKS):
            continue  # incomplete eval: not all original tasks present
        precision = eval_result.precision.value.name
        scores[precision] = sum(v for v in eval_result.results.values() if v is not None) / len(ORIGINAL_TASKS)
    return scores

def update_dynamic_files():
    """Refresh dynamic model metadata and push the updated file to the hub.

    This only updates metadata for models already linked in the repo; it does
    not add missing ones. Side effects: downloads two dataset snapshots,
    rewrites DYNAMIC_INFO_FILE_PATH locally, and uploads it to
    DYNAMIC_INFO_REPO.
    """
    print("update_dynamic_files running...")
    snapshot_download(
        repo_id=DYNAMIC_INFO_REPO, local_dir=DYNAMIC_INFO_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30
    )
    print("UPDATE_DYNAMIC: Loaded snapshot")

    # Fetch the current list of text-generation models on the hub.
    start = time.time()
    models = list(API.list_models(
        task="text-generation",
        full=False,
        cardData=True,
        fetch_config=True,
    ))
    id_to_model = {model.id: model for model in models}

    id_to_leaderboard_files = defaultdict(list)
    if GET_ORIGINAL_HF_LEADERBOARD_EVAL_RESULTS:
        try:
            print("UPDATE_DYNAMIC: Downloading Original HF Leaderboard results snapshot")
            snapshot_download(
                repo_id=ORIGINAL_HF_LEADERBOARD_RESULTS_REPO, local_dir=ORIGINAL_HF_LEADERBOARD_EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30
            )
            _collect_leaderboard_files(ORIGINAL_HF_LEADERBOARD_EVAL_RESULTS_PATH, id_to_leaderboard_files)
            # Sort so files merge deterministically (later files win on merge).
            for filepaths in id_to_leaderboard_files.values():
                filepaths.sort()
        except Exception as e:
            # Best effort: fall back to no original scores rather than failing
            # the whole daily update.
            print(f"UPDATE_DYNAMIC: Could not download original results from : {e}")
            id_to_leaderboard_files = None

    print(f"UPDATE_DYNAMIC: Downloaded list of models in {time.time() - start:.2f} seconds")

    start = time.time()
    update_models(DYNAMIC_INFO_FILE_PATH, id_to_model, id_to_leaderboard_files)
    print(f"UPDATE_DYNAMIC: updated in {time.time() - start:.2f} seconds")

    API.upload_file(
        path_or_fileobj=DYNAMIC_INFO_FILE_PATH,
        path_in_repo=os.path.basename(DYNAMIC_INFO_FILE_PATH),
        repo_id=DYNAMIC_INFO_REPO,
        repo_type="dataset",
        commit_message="Daily request file update.",
    )
    print("UPDATE_DYNAMIC: pushed to hub")


def _collect_leaderboard_files(root, id_to_files):
    """Walk *root* and group results_*.json paths by model id.

    Result files are laid out as <org>/<model>/results_*.json, so the model id
    is the file's directory path relative to *root*. Using os.path.relpath
    (instead of searching for a literal '/results_' substring) is portable
    across path separators and cannot mis-slice when the marker is absent.
    """
    for dirpath, _, filenames in os.walk(root):
        for fname in filenames:
            if not (fname.startswith('results_') and fname.endswith('.json')):
                continue
            # Hub repo ids always use '/' regardless of the local os.sep.
            model_id = os.path.relpath(dirpath, root).replace(os.sep, "/")
            id_to_files[model_id].append(os.path.join(dirpath, fname))