import hashlib
import os
import sys
import shutil
from functools import wraps, partial

import pytest

# Make the 'src' directory importable.  The membership test must look for 'src'
# itself: os.path.dirname('src') evaluates to '' and so cannot tell whether 'src'
# is already on sys.path.
if 'src' not in sys.path:
    sys.path.append('src')

# Set before the project import that follows.
os.environ['HARD_ASSERTS'] = "1"

from src.utils import call_subprocess_onetask, makedirs, FakeTokenizer


def get_inf_port():
    """
    Resolve the inference server port from the environment.

    Precedence: the text after the last ':' of HOST, then GRADIO_SERVER_PORT,
    then the default 7860.
    :return: port number as int
    """
    host = os.getenv('HOST')
    if host is not None:
        return int(host.split(':')[-1])
    return int(os.getenv('GRADIO_SERVER_PORT', str(7860)))


def get_inf_server():
    """
    Resolve the inference server address from the environment.

    HOST is returned verbatim when set; otherwise a localhost URL is built from
    GRADIO_SERVER_PORT.
    :return: server address string
    :raises ValueError: if neither HOST nor GRADIO_SERVER_PORT is set
    """
    host = os.getenv('HOST')
    if host is not None:
        return host
    port = os.getenv('GRADIO_SERVER_PORT')
    if port is None:
        raise ValueError("Expect tests to set HOST or GRADIO_SERVER_PORT")
    return "http://localhost:%s" % port


def get_mods():
    """
    Read the test-sharding settings from the environment.

    :return: tuple (TESTMODULOTOTAL, TESTMODULO) as ints, defaulting to (1, 0)
    """
    env = os.environ
    return int(env.get('TESTMODULOTOTAL', '1')), int(env.get('TESTMODULO', '0'))


def do_skip_test(name):
    """
    Decide whether a test should be skipped under the TESTMODULO sharding scheme.

    The hex digest of the name is reduced modulo TESTMODULOTOTAL; the test is kept
    only when the remainder equals TESTMODULO.  Note that skipping all tests does
    not fail, doing no tests is what fails.
    :param name: test name the decision is derived from
    :return: True if the test should be skipped
    """
    total, selected = get_mods()
    bucket = int(get_sha(name), 16) % total
    return bucket != selected


def wrap_test_forked(func):
    """
    Decorate a test function so that it is executed through call_subprocess_onetask.

    Before running, the wrapper fills in GRADIO_SERVER_PORT and HOST in the
    environment if they are unset, and skips the test when do_skip_test() says so.
    :param func: test function to wrap
    :return: wrapped function
    """

    @wraps(func)
    def forked(*args, **kwargs):
        env = os.environ
        # Write the defaults back so server port/host can be controlled globally for all tests.
        # (Offsetting the port by the TESTMODULO value is not done here.)
        env['GRADIO_SERVER_PORT'] = env.get('GRADIO_SERVER_PORT', str(7860))
        port = int(env['GRADIO_SERVER_PORT'])
        env['HOST'] = env.get('HOST', "http://localhost:%s" % port)

        name = get_test_name()
        # The decision depends only on the test name, so it is deterministic
        if do_skip_test(name):
            pytest.skip("[%s] TEST SKIPPED due to TESTMODULO" % name)
        return run_test(partial(call_subprocess_onetask, func, args, kwargs))

    return forked


def run_test(func, *args, **kwargs):
    """
    Call func with the given positional and keyword arguments.

    :param func: callable to invoke
    :return: whatever func returns
    """
    outcome = func(*args, **kwargs)
    return outcome


def get_sha(value):
    """
    Hash the string form of value.

    Despite the name, the digest is MD5 (32 hex characters).
    :param value: any object; str(value) encoded as UTF-8 is hashed
    :return: hex digest string
    """
    hasher = hashlib.md5()
    hasher.update(str(value).encode('utf-8'))
    return hasher.hexdigest()


def sanitize_filename(name):
    """
    Sanitize file *base* names.  Also used to generate valid class names.

    Each character from a fixed set of punctuation and the space character is
    replaced by "_".  A result longer than 250 characters is shortened to its
    leading and trailing parts joined around the hash of the full sanitized name,
    so the shortened name is at most 250 characters long.
    :param name: base name to sanitize
    :return: sanitized, possibly shortened, name
    """
    # One translate pass instead of a chain of replace() calls.  Only single
    # characters are needed: once every backslash is replaced, multi-character
    # sequences that start with a backslash can no longer occur.
    bad_chars = '[],/\\-+"\'><=)(:^ '
    name = name.translate(str.maketrans(bad_chars, "_" * len(bad_chars)))

    file_length_limit = 250  # bit smaller than 256 for safety
    sha_length = 32
    length = len(name)
    if length > file_length_limit:
        # room left for the name once the hash and its two "_" separators are accounted for
        real_length_limit = file_length_limit - (sha_length + 2)
        half_real_length_limit = max(1, real_length_limit // 2)
        head = name[:half_real_length_limit]
        tail = name[length - half_real_length_limit:]
        name = head + "_" + get_sha(name) + "_" + tail

    return name


def get_test_name():
    """
    Derive a sanitized name for the currently running test.

    Takes the part of PYTEST_CURRENT_TEST after its last ':', drops the final
    space-separated word (the trailing stage marker such as "(call)"), joins the
    remaining words with "_" and sanitizes the result.
    :return: sanitized test name
    """
    current = os.environ['PYTEST_CURRENT_TEST']
    last_field = current.split(':')[-1]
    words = last_field.split(' ')
    return sanitize_filename("_".join(words[:-1]))


def make_user_path_test():
    """
    Recreate the 'user_path_test' and 'db_dir_UserData' directories and seed the
    former with a few sample documents.

    Each directory is resolved via makedirs(), removed if it exists, and created
    again so the test starts from a clean state.
    :return: path of the freshly populated user directory
    """
    import os
    import shutil

    # user documents directory: resolve, wipe, recreate
    user_dir = makedirs('user_path_test', use_base=True)
    if os.path.isdir(user_dir):
        shutil.rmtree(user_dir)
    user_dir = makedirs('user_path_test', use_base=True)

    # database directory: same treatment, re-using the resolved path for the second call
    db_path = makedirs("db_dir_UserData", use_base=True)
    if os.path.isdir(db_path):
        shutil.rmtree(db_path)
    makedirs(db_path, use_base=True)

    sample_files = (
        'data/pexels-evg-kowalievska-1170986_small.jpg',
        'README.md',
        'docs/FAQ.md',
    )
    for sample in sample_files:
        shutil.copy(sample, user_dir)
    return user_dir


def get_llama(llama_type=2):
    """
    Make sure a llama model file is present locally, downloading it if needed.

    The file is fetched from the 'h2oai/ggml' hub repository and copied into the
    destination directory when it is not already there.
    :param llama_type: 1 or 2, selecting which model file / prompt type to use
    :return: tuple (prompt_type, path to the model file)
    :raises ValueError: for any other llama_type
    """
    from huggingface_hub import hf_hub_download

    # FIXME: Pass into main()
    if llama_type == 1:
        filename, dest, prompt_type = 'ggml-model-q4_0_7b.bin', 'models/7B/', 'plain'
    elif llama_type == 2:
        filename, dest, prompt_type = 'WizardLM-7B-uncensored.ggmlv3.q8_0.bin', './', 'wizard2'
    else:
        raise ValueError("unknown llama_type=%s" % llama_type)

    makedirs(dest, exist_ok=True)
    full_path = os.path.join(dest, filename)
    if os.path.isfile(full_path):
        return prompt_type, full_path

    # token=True covers the case of being already logged in locally with a valid token,
    # so no key has to be set explicitly
    token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
    # the returned path points into the hub cache, e.g.
    # '/home/jon/.cache/huggingface/hub/models--h2oai--ggml/snapshots/<revision>/ggml-model-q4_0_7b.bin'
    cached_path = hf_hub_download('h2oai/ggml', filename, token=token, repo_type='model')
    shutil.copy(cached_path, dest)
    return prompt_type, full_path


def kill_weaviate(db_type):
    """
    Kill the embedded weaviate server when db_type is 'weaviate'.

    weaviate launches a detached server, which accumulates entries in the db,
    but we want to start freshly.
    :param db_type: database type; anything other than 'weaviate' is a no-op
    """
    if db_type != 'weaviate':
        return
    os.system('pkill --signal 9 -f weaviate-embedded/weaviate')


def count_tokens_llm(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', tokenizer=None):
    """
    Count the tokens in prompt with an LLM tokenizer and print the count and timing.

    :param prompt: text to tokenize
    :param base_model: model name used to load a tokenizer when none is passed
    :param tokenizer: optional callable returning a mapping with 'input_ids'
    :return: dict with the token count under key 'llm'
    """
    import time

    if tokenizer is None:
        assert base_model is not None
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    # only the tokenization itself is timed, not the tokenizer loading
    started = time.time()
    num_tokens = len(tokenizer(prompt)['input_ids'])
    print('llm: ', num_tokens, time.time() - started)
    return {'llm': num_tokens}


def count_tokens(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b'):
    """
    Count the tokens of prompt with several tokenizers, printing each count and timing.

    The count reported by FakeTokenizer is printed first.  Then these counters run
    in order: a regex word matcher, nltk's word tokenizer, the "distilgpt2"
    tokenizer (reported under the 'tiktoken' label/key), the tokenizer of
    base_model, and the 'hkunlp/instructor-large' embedding model's tokenizer.
    :param prompt: text to tokenize
    :param base_model: model name whose tokenizer gives the 'llm' count
    :return: dict with keys 'reg', 'nltk', 'tiktoken', 'llm', 'instructor'
    """
    print(FakeTokenizer().num_tokens_from_string(prompt))

    from transformers import AutoTokenizer

    gpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    llm_tokenizer = AutoTokenizer.from_pretrained(base_model)

    from InstructorEmbedding import INSTRUCTOR
    embedder = INSTRUCTOR('hkunlp/instructor-large')

    import nltk
    import re
    import time

    word_pattern = re.compile(r'\w+')

    # (print label, result key, counting function), listed in execution order
    counters = [
        ('reg: ', 'reg', lambda text: len(word_pattern.findall(text))),
        ('nltk: ', 'nltk', lambda text: len(nltk.word_tokenize(text))),
        ('tiktoken: ', 'tiktoken', lambda text: len(gpt2_tokenizer(text)['input_ids'])),
        ('llm: ', 'llm', lambda text: len(llm_tokenizer(text)['input_ids'])),
        ('instructor-large: ', 'instructor', lambda text: embedder.tokenize([text])['input_ids'].shape[1]),
    ]

    counts = {}
    for label, key, counter in counters:
        started = time.time()
        total = counter(prompt)
        print(label, total, time.time() - started)
        counts[key] = total

    return counts