from toolkit.paths import MODELS_PATH
import requests
import os
import json
import tqdm


class ModelCache:
    """JSON-backed index of model files that were already downloaded.

    The index maps ``model id -> model version id -> {'model_path': path}``.
    Ids are always converted with ``str()`` before being used as keys, so
    callers may pass them as ints or strings. The index is persisted in
    ``.ai_toolkit_cache.json`` inside ``MODELS_PATH`` under the top-level
    ``'models'`` key.
    """

    def __init__(self):
        """Load the index from disk if the cache file exists."""
        self.raw_cache = {}
        self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json')
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'r') as f:
                all_cache = json.load(f)
            if 'models' in all_cache:
                self.raw_cache = all_cache['models']
            else:
                # A file without the 'models' key is used as the index itself.
                self.raw_cache = all_cache

    def get_model_path(self, model_id: int, model_version_id: int = None):
        """Return the cached file path for a model, or ``None`` if unavailable.

        Args:
            model_id: Id of the model to look up.
            model_version_id: Version to look up. When ``None`` the version
                with the numerically highest id is used.

        Returns:
            The stored path if the entry exists and the file is still on
            disk, otherwise ``None``. An entry whose file has disappeared is
            removed from the index and the index is saved.
        """
        model_key = str(model_id)
        versions = self.raw_cache.get(model_key)
        if not versions:
            # Unknown model, or a model whose versions were all pruned;
            # max() below would raise on an empty mapping.
            return None

        if model_version_id is None:
            # Pick the stored key itself so the lookup below cannot miss.
            version_key = max(versions, key=int)
        else:
            version_key = str(model_version_id)
            if version_key not in versions:
                return None

        model_path = versions[version_key]['model_path']
        if os.path.exists(model_path):
            return model_path

        # The file is gone: drop the stale entry, and the model entry too
        # once it has no versions left.
        del versions[version_key]
        if not versions:
            del self.raw_cache[model_key]
        self.save()
        return None

    def update_cache(self, model_id: int, model_version_id: int, model_path: str):
        """Record ``model_path`` for the given model/version and save the index."""
        versions = self.raw_cache.setdefault(str(model_id), {})
        versions[str(model_version_id)] = {'model_path': model_path}
        self.save()

    def save(self):
        """Write the index to ``self.cache_path``.

        Any other top-level keys already present in the cache file are kept;
        only the ``'models'`` key is replaced.
        """
        os.makedirs(os.path.dirname(self.cache_path), exist_ok=True)
        all_cache = {'models': {}}
        if os.path.exists(self.cache_path):
            # load it first so unrelated keys survive the rewrite
            with open(self.cache_path, 'r') as f:
                all_cache = json.load(f)

        all_cache['models'] = self.raw_cache

        with open(self.cache_path, 'w') as f:
            json.dump(all_cache, f, indent=2)


def _select_model_file(files):
    # Pick the preferred download candidate from a model version's file list.
    # Order of preference: fp16 SafeTensor, the primary file, any SafeTensor,
    # any fp16 file, then whatever comes first. Returns None for an empty list.
    # todo check pickle scans and skip if not good
    def meta(file, key):
        # 'metadata' or individual fields may be missing or null for a file.
        return (file.get('metadata') or {}).get(key)

    preferences = (
        lambda file: meta(file, 'fp') == 'fp16' and meta(file, 'format') == 'SafeTensor',
        lambda file: file.get('primary'),
        lambda file: meta(file, 'format') == 'SafeTensor',
        lambda file: meta(file, 'fp') == 'fp16',
        lambda file: True,
    )
    for is_wanted in preferences:
        for file in files:
            if is_wanted(file):
                return file
    return None


def get_model_download_info(model_id: int, model_version_id: int = None):
    """Look up which file to download for a Civitai model.

    Fetches ``https://civitai.com/api/v1/models/{model_id}`` and selects a
    version and a file from the response.

    Args:
        model_id: Id of the model.
        model_version_id: Id of the wanted version (compared as a string).
            When ``None`` the first version listed in the response is used.

    Returns:
        A ``(file_info, version_id)`` tuple: the chosen file's dict from the
        response and the ``'id'`` of the version it belongs to.

    Raises:
        ValueError: If no matching version, or no file at all, is found.
        requests.RequestException: If the request fails, times out, or
            returns an error status.
    """
    version_note = f' and version id: {model_version_id}' if model_version_id is not None else ''
    print(f"Getting model info for model id: {model_id}{version_note}")
    endpoint = f"https://civitai.com/api/v1/models/{model_id}"

    # A timeout keeps a stalled connection from hanging the caller forever.
    response = requests.get(endpoint, timeout=30)
    response.raise_for_status()
    versions = response.json().get('modelVersions') or []

    if model_version_id is None:
        # no version requested: take the first one listed
        model_version = versions[0] if versions else None
    else:
        wanted_id = str(model_version_id)
        model_version = next(
            (version for version in versions if str(version.get('id')) == wanted_id),
            None,
        )

    if model_version is None:
        raise ValueError(
            f"Could not find a model version for model id: {model_id}{version_note}")

    model_file = _select_model_file(model_version.get('files') or [])
    if model_file is None:
        raise ValueError(f"Could not find a model file to download for model id: {model_id}")

    return model_file, model_version['id']


def _parse_model_url(url):
    # Split a model URL into (model id as int, requested version id or None).
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(url)
    segments = [segment for segment in parsed.path.split('/') if segment]
    candidate = segments[-1] if segments else ''
    # Page URLs may carry a name slug after the id (.../models/<id>/<slug>),
    # so prefer the numeric segment that directly follows 'models'.
    for anchor, following in zip(segments, segments[1:]):
        if anchor == 'models' and following.isdigit():
            candidate = following
            break
    if not candidate.isdigit():
        raise ValueError(f"Invalid model id: {candidate}")

    # parse_qs tolerates parameters without '=' and values containing '='.
    version_ids = parse_qs(parsed.query).get('modelVersionId')
    return int(candidate), (version_ids[-1] if version_ids else None)


def get_model_path_from_url(url: str):
    """Return the local path of the model a Civitai URL points to.

    The model id is read from the URL path and an optional
    ``modelVersionId`` query parameter selects the version, e.g.
    ``https://civitai.com/models/25694?modelVersionId=127742``. If the model
    is already recorded in the cache its path is returned; otherwise the
    file is downloaded into ``MODELS_PATH`` and added to the cache.

    Args:
        url: Model URL, optionally with a ``modelVersionId`` query parameter.

    Returns:
        Path of the model file on disk.

    Raises:
        ValueError: If no numeric model id can be found in the URL, or the
            remote file name is empty.
        requests.RequestException: If the download request fails.
    """
    model_id, requested_version_id = _parse_model_url(url)

    model_cache = ModelCache()
    model_path = model_cache.get_model_path(model_id, requested_version_id)
    if model_path is not None:
        return model_path

    # download model
    file_info, model_version_id = get_model_download_info(model_id, requested_version_id)

    download_url = file_info['downloadUrl']  # url does not work directly
    # The name comes from the remote API; keep only its last component so it
    # cannot point outside MODELS_PATH.
    filename = os.path.basename(file_info['name'])
    if not filename:
        raise ValueError(f"Invalid model file name: {file_info['name']}")
    model_path = os.path.join(MODELS_PATH, filename)
    tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}")

    print(f"Did not find model locally, downloading from model from: {download_url}")

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    block_size = 1024 * 1024  # 1 MiB chunks; model files are large

    try:
        with requests.get(download_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # Opening with 'wb' truncates any leftover from an earlier attempt.
            with open(tmp_path, 'wb') as f, tqdm.tqdm(
                    total=total_size_in_bytes or None, unit='iB', unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    f.write(data)
                    progress_bar.update(len(data))
        # Move into place only after the file is closed; os.replace also
        # overwrites an existing destination on every platform.
        os.replace(tmp_path, model_path)
    finally:
        # After a successful move the temp file no longer exists, so this
        # only cleans up after a failed or interrupted download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    model_cache.update_cache(model_id, model_version_id, model_path)
    return model_path


# Manual smoke test: resolve a sample model URL when this file is run directly.
if __name__ == '__main__':
    print(get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742"))