import logging
import os
import traceback
import torch
from transformers import CLIPTokenizerFast

def model_options_long_clip(sd, tokenizer_data, model_options):
    """#### Inspect a checkpoint for a long-CLIP position embedding.

    #### Args:
        - `sd` (dict): The checkpoint state dict.
        - `tokenizer_data` (dict): The tokenizer options.
        - `model_options` (dict): The model options.

    #### Returns:
        - `tuple`: The tokenizer_data and model_options.
    """
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
    # NOTE: the position-embedding weight is looked up but not acted on here,
    # so both inputs are currently returned unchanged.
    return tokenizer_data, model_options

def parse_parentheses(string: str) -> list:
    """#### Parse a string with nested parentheses.

    #### Args:
        - `string` (str): The input string.

    #### Returns:
        - `list`: The parsed list of strings.
    """
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result
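
# A quick illustration of the parsing above (prompt string made up): top-level
# parenthesised groups are kept intact as single items, e.g.
#   parse_parentheses("a (b (c)) d")  ->  ["a ", "(b (c))", " d"]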


def token_weights(string: str, current_weight: float) -> list:
    """#### Parse a string into tokens with weights.

    #### Args:
        - `string` (str): The input string.
        - `current_weight` (float): The current weight.

    #### Returns:
        - `list`: The list of token-weight pairs.
    """
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1 :])
                    x = x[:xx]
                except ValueError:
                    # not a numeric weight; keep the implicit 1.1 multiplier
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out
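
# For instance (a minimal sketch, values made up): "(word:1.5)" overrides the
# implicit 1.1 boost with an explicit weight, while a plain "(word)" keeps it:
#   token_weights("an (orange:1.5) cat", 1.0)
#     -> [("an ", 1.0), ("orange", 1.5), (" cat", 1.0)]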


def escape_important(text: str) -> str:
    """#### Escape important characters in a string.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The escaped text.
    """
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text: str) -> str:
    """#### Unescape important characters in a string.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The unescaped text.
    """
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text
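
# The two helpers above form a round trip: escape_important() swaps the escaped
# parentheses "\(" and "\)" for the sentinel byte pairs "\0\2" and "\0\1" so the
# weight parser treats them as plain text, and unescape_important() restores the
# literal parentheses once weight parsing is done.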


def expand_directory_list(directories: list) -> list:
    """#### Expand a list of directories to include all subdirectories.

    #### Args:
        - `directories` (list): The list of directories.

    #### Returns:
        - `list`: The expanded list of directories.
    """
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, _subdirs, _files in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)


def load_embed(embedding_name: str, embedding_directory: list, embedding_size: int, embed_key: str = None) -> torch.Tensor:
    """#### Load an embedding from a directory.

    #### Args:
        - `embedding_name` (str): The name of the embedding.
        - `embedding_directory` (list): The list of directories to search.
        - `embedding_size` (int): The size of the embedding.
        - `embed_key` (str, optional): The key for the embedding. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The loaded embedding.
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except ValueError:
            # commonpath() raises ValueError when the paths are on different
            # drives, in which case this directory cannot contain the file
            continue
        if not os.path.isfile(embed_path):
            extensions = [".safetensors", ".pt", ".bin"]
            for x in extensions:
                t = embed_path + x
                if os.path.isfile(t):
                    valid_file = t
                    break
        else:
            valid_file = embed_path
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch

            embed = safetensors.torch.load_file(embed_path, device="cpu")
        else:
            if "weights_only" in torch.load.__code__.co_varnames:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            else:
                embed = torch.load(embed_path, map_location="cpu")
    except Exception:
        logging.warning(
            "{}\n\nerror loading embedding, skipping loading: {}".format(
                traceback.format_exc(), embedding_name
            )
        )
        return None

    if "string_to_param" in embed:
        values = embed["string_to_param"].values()
        embed_out = next(iter(values))
    elif isinstance(embed, list):
        out_list = []
        for x in range(len(embed)):
            for k in embed[x]:
                t = embed[x][k]
                if t.shape[-1] != embedding_size:
                    continue
                out_list.append(t.reshape(-1, t.shape[-1]))
        embed_out = torch.cat(out_list, dim=0)
    elif embed_key is not None and embed_key in embed:
        embed_out = embed[embed_key]
    else:
        values = embed.values()
        embed_out = next(iter(values))
    return embed_out
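
# Usage sketch (hedged: the file name and directory below are hypothetical):
#   emb = load_embed("my_style.safetensors", ["models/embeddings"], 768, "clip_l")
# For a typical SD1 textual-inversion file this yields a tensor whose last
# dimension is the embedding size (768), or None if no matching file was found
# or it could not be loaded.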


class SDTokenizer:
    """#### Class representing a Stable Diffusion tokenizer."""

    def __init__(
        self,
        tokenizer_path: str = None,
        max_length: int = 77,
        pad_with_end: bool = True,
        embedding_directory: str = None,
        embedding_size: int = 768,
        embedding_key: str = "clip_l",
        tokenizer_class: type = CLIPTokenizerFast,
        has_start_token: bool = True,
        pad_to_max_length: bool = True,
        min_length: int = None,
    ):
        """#### Initialize the SDTokenizer.

        #### Args:
            - `tokenizer_path` (str, optional): The path to the tokenizer. Defaults to None.
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `pad_with_end` (bool, optional): Whether to pad with the end token. Defaults to True.
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `embedding_size` (int, optional): The size of the embeddings. Defaults to 768.
            - `embedding_key` (str, optional): The key for the embeddings. Defaults to "clip_l".
            - `tokenizer_class` (type, optional): The tokenizer class. Defaults to CLIPTokenizerFast.
            - `has_start_token` (bool, optional): Whether the tokenizer has a start token. Defaults to True.
            - `pad_to_max_length` (bool, optional): Whether to pad to the maximum length. Defaults to True.
            - `min_length` (int, optional): The minimum length of the input. Defaults to None.
        """
        if tokenizer_path is None:
            tokenizer_path = "_internal/sd1_tokenizer/"
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length

        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name: str) -> tuple:
        """#### Try to get an embedding.

        #### Args:
            - `embedding_name` (str): The name of the embedding.

        #### Returns:
            - `tuple`: The embedding and any leftover text.
        """
        embed = load_embed(
            embedding_name,
            self.embedding_directory,
            self.embedding_size,
            self.embedding_key,
        )
        if embed is None:
            stripped = embedding_name.strip(",")
            if len(stripped) < len(embedding_name):
                embed = load_embed(
                    stripped,
                    self.embedding_directory,
                    self.embedding_size,
                    self.embedding_key,
                )
                return (embed, embedding_name[len(stripped) :])
        return (embed, "")

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> list:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `list`: The tokenized text with weights.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if (
                    word.startswith(self.embedding_identifier)
                    and self.embedding_directory is not None
                ):
                    embedding_name = word[len(self.embedding_identifier) :].strip("\n")
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(
                            f"embedding:{embedding_name} does not exist, ignoring"
                        )
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append(
                                [(embed[x], weight) for x in range(embed.shape[0])]
                            )
                        print("loading ", embedding_name)
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens
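
    # Shape note (assuming the defaults max_length=77 and pad_to_max_length=True):
    # each inner list holds exactly 77 (token_id, weight) tuples, framed by the
    # start/end tokens and padded at the tail; a third word-index element is
    # included when return_word_ids=True.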

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


class SD1Tokenizer:
    """#### Class representing the SD1Tokenizer."""

    def __init__(self, embedding_directory: str = None, clip_name: str = "l", tokenizer: type = SDTokenizer):
        """#### Initialize the SD1Tokenizer.

        #### Args:
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `tokenizer` (type, optional): The tokenizer class. Defaults to SDTokenizer.
        """
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `dict`: The tokenized text with weights.
        """
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(
            text, return_word_ids
        )
        return out

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return getattr(self, self.clip).untokenize(token_weight_pair)
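

# End-to-end usage sketch (a minimal example; the embedding directory is
# hypothetical and tokenizer files are assumed at "_internal/sd1_tokenizer/"):
#   tokenizer = SD1Tokenizer(embedding_directory="models/embeddings")
#   tokens = tokenizer.tokenize_with_weights("a photo of an (orange:1.2) cat")
#   # tokens["l"] is a list of 77-entry batches of (token_id, weight) pairs,
#   # ready to be consumed by the matching CLIP text encoder.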