"""
All the functions to build the relevant models and modules
from the Hydra config.
"""

import typing as tp

import omegaconf
import torch
from codeclm.utils.utils import dict_from_config
from codeclm.modules.pattern import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
)
from codeclm.modules.conditioners import (
    BaseConditioner,
    QwTokenizerConditioner,
    QwTextConditioner,
    QuantizedEmbeddingConditioner,
    ConditionerProvider,
    ConditionFuser,
)


def get_audio_tokenizer_model(checkpoint_path: tp.Optional[str], cfg: omegaconf.DictConfig):
    """Instantiate a compression model.

    Args:
        checkpoint_path: Identifier of the tokenizer to load. ``None`` or an
            empty string means "no tokenizer". A value starting with
            ``//pretrained/`` has that prefix stripped and the remainder is
            used as the pretrained name; any other value is used unchanged.
        cfg: Config exposing ``vae_config``, ``vae_model`` and ``mode``, which
            are forwarded to ``AudioTokenizer.get_pretrained``.

    Returns:
        Whatever ``AudioTokenizer.get_pretrained`` returns (it is called with
        ``'cpu'`` as its fourth positional argument), or ``None`` when
        ``checkpoint_path`` is ``None`` or empty.
    """
    # A single guard covers both the None and the "" case.
    if not checkpoint_path:
        return None

    # Imported only once a tokenizer is actually requested, so the
    # "no tokenizer" path does not pay for (or depend on) this import.
    from codeclm.tokenizer.audio_tokenizer import AudioTokenizer

    pretrained_prefix = '//pretrained/'
    if checkpoint_path.startswith(pretrained_prefix):
        name = checkpoint_path[len(pretrained_prefix):]
    else:
        name = checkpoint_path
    return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
    
def get_lm_model(cfg: omegaconf.DictConfig):  # -> LmModel
    """Build the language model described by ``cfg``.

    The ``lm`` section supplies the model keyword arguments; the conditioner
    provider, codebook pattern provider and condition fuser are built from
    their own config sections and handed to the model alongside them.

    Args:
        cfg: Config with the ``lm``, ``conditioners``, ``codebooks_pattern``,
            ``attribute_dropout``, ``classifier_free_guidance`` and ``fuser``
            sections.

    Returns:
        The constructed ``LmModel`` after ``.to('cpu')``.

    Raises:
        KeyError: If ``lm.lm_type`` is anything other than ``'Llama'``.
    """
    lm_options = dict_from_config(cfg.lm)

    # code_depth (n_q: number of RVQ) stays in lm_options and is forwarded to
    # the model; q_modeling is removed so it is not forwarded.
    depth = lm_options['code_depth']
    fallback_modeling = lm_options.pop('q_modeling', None)

    # Conditioners project to the LM dimension.
    provider = get_conditioner_provider(lm_options["dim"], cfg)

    # Codebook pattern: when none is configured, fall back to q_modeling with
    # one delay step per codebook.
    pattern_cfg = cfg.codebooks_pattern
    if pattern_cfg.modeling is None:
        assert fallback_modeling is not None, \
            "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        pattern_cfg = omegaconf.OmegaConf.create({
            'modeling': fallback_modeling,
            'delay': {'delays': [*range(depth)]},
        })
    pattern = get_codebooks_pattern_provider(depth, pattern_cfg)

    # Condition dropout and classifier-free guidance settings.
    dropout_by_attribute = dict_from_config(cfg.attribute_dropout)
    guidance = dict_from_config(cfg.classifier_free_guidance)
    train_dropout = guidance['training_dropout']
    infer_coef = guidance['inference_coef']

    condition_fuser = get_condition_fuser(cfg)

    # YCY: the LM implementation module is chosen from lm_type.
    kind = lm_options['lm_type']
    if kind != 'Llama':
        raise KeyError(f"Unexpected LM model {kind}")

    from .lm_levo import LmModel
    model = LmModel(
        pattern_provider=pattern,
        condition_provider=provider,
        fuser=condition_fuser,
        cfg_dropout=train_dropout,
        cfg_coef=infer_coef,
        attribute_dropout=dropout_by_attribute,
        cfg=cfg,
        **lm_options
    )
    return model.to('cpu')


def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
    """Build the ``ConditionerProvider`` from the ``conditioners`` config section.

    Each entry of the section (other than ``args``) names a conditioner; its
    ``model`` key selects the conditioner class and the sub-entry with that
    same name holds the keyword arguments for it.

    Args:
        output_dim: Dimension handed to every conditioner (as ``output_dim``,
            or as ``dim`` for ``qt_embedding``).
        cfg: Config whose ``conditioners`` section is read. A ``None`` section
            yields a provider with no conditioners.

    Returns:
        A ``ConditionerProvider`` wrapping the built conditioners, constructed
        with the section's ``args`` entry as extra keyword arguments.

    Raises:
        ValueError: If an entry's ``model`` is not a recognized type.
    """
    section = cfg.conditioners
    settings = dict_from_config(section) if section is not None else {}
    provider_kwargs = settings.pop('args', {})

    # model name -> (conditioner class, keyword that receives output_dim)
    builders = {
        'QwTokenizer': (QwTokenizerConditioner, 'output_dim'),
        'QwTextTokenizer': (QwTextConditioner, 'output_dim'),
        'qt_embedding': (QuantizedEmbeddingConditioner, 'dim'),
    }

    built: tp.Dict[str, BaseConditioner] = {}
    for attribute, spec in settings.items():
        model_name = spec['model']
        # The per-model arguments are fetched before the type is validated.
        model_kwargs = spec[model_name]
        if model_name not in builders:
            raise ValueError(f"Unrecognized conditioning model: {model_name}")
        conditioner_cls, dim_keyword = builders[model_name]
        built[str(attribute)] = conditioner_cls(**{dim_keyword: output_dim}, **model_kwargs)

    return ConditionerProvider(built, **provider_kwargs)


def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Build the ``ConditionFuser`` from the ``fuser`` config section.

    The ``sum`` and ``prepend`` entries are collected into the ``fuse2cond``
    mapping; every other entry of the section is passed through to
    ``ConditionFuser`` as a keyword argument.
    """
    fuser_section = cfg.fuser
    method_names = ('sum', 'prepend')

    # Both fusing methods are read unconditionally from the section.
    fuse2cond = {}
    for method in method_names:
        fuse2cond[method] = fuser_section[method]

    # Everything that is not a fusing method is a constructor option.
    extra_kwargs = {}
    for key, value in fuser_section.items():
        if key in method_names:
            continue
        extra_kwargs[key] = value

    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)


def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Build the codebooks pattern provider selected by ``cfg.modeling``.

    Args:
        code_depth: Passed as the first positional argument to the provider.
        cfg: Pattern config. ``cfg.modeling`` names the provider; if the
            config also has an entry under that name, it supplies the
            provider's keyword arguments.

    Returns:
        An instance of the selected provider class.

    Raises:
        KeyError: If ``cfg.modeling`` is not a known provider name
            (only ``'delay'`` is registered).
    """
    registry = {'delay': DelayedPatternProvider}

    modeling = cfg.modeling
    if hasattr(cfg, modeling):
        provider_kwargs = dict_from_config(cfg.get(modeling))
    else:
        provider_kwargs = {}

    provider_cls = registry[modeling]
    return provider_cls(code_depth, **provider_kwargs)