import sys
import os
import time
import json
import torch
import torchaudio
import numpy as np
from omegaconf import OmegaConf

from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM
from third_party.demucs.models.pretrained import get_model_from_yaml
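
# Inference entry point: Demucs separates a prompt song into vocal and
# accompaniment stems, then CodecLM generates new song tokens and decodes
# them back to audio.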


class Separator:
    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model
    
    def load_audio(self, f):
        """Load an audio file, resample to 48 kHz, and return exactly 10 s."""
        a, fs = torchaudio.load(f)
        if fs != 48000:
            a = torchaudio.functional.resample(a, fs, 48000)
        # Tile clips shorter than 10 s until they cover the window, then
        # truncate to exactly 10 s at 48 kHz.
        while a.shape[-1] < 48000 * 10:
            a = torch.cat([a, a], -1)
        return a[:, :48000 * 10]
    
    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        os.makedirs(output_dir, exist_ok=True)
        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
        output_paths = []

        # Reuse stems left over from a previous run: the non-vocal stems are
        # deleted below, so a single surviving file is the cached vocal track
        # and Demucs can be skipped entirely.
        for stem in self.demucs_model.sources:
            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            if os.path.exists(output_path):
                output_paths.append(output_path)
        if len(output_paths) == 1:
            vocal_path = output_paths[0]
        else:
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
            # Only the vocal stem is needed; drop the others to save disk space.
            for path in [drums_path, bass_path, other_path]:
                os.remove(path)
        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        # The accompaniment (bgm) is the residual: full mix minus vocals.
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio

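# Minimal standalone sketch of the Separator above (hypothetical input path):
#
#     sep = Separator()
#     full_wav, vocal_wav, bgm_wav = sep.run("prompts/example_song.flac")
#
# Each returned tensor is a 10 s, 48 kHz clip; bgm is the full mix minus vocals.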

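# Expected CLI, inferred from the sys.argv reads below:
#   python <this_script>.py <config.yaml> <save_dir> <input.jsonl> <sidx>
# <sidx> is appended to every output file name (its exact meaning, e.g. a
# sampling-run index, is an assumption).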
def main_sep():
    torch.backends.cudnn.enabled = False  # some Taiji cluster nodes throw odd errors with cuDNN enabled
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    cfg = OmegaConf.load(sys.argv[1])
    save_dir = sys.argv[2]
    input_jsonl = sys.argv[3]
    sidx = sys.argv[4]
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    
    # Define model or load pretrained model
    model_light = CodecLM_PL(cfg)

    model_light = model_light.eval().cuda()
    model_light.audiolm.cfg = cfg
    model = CodecLM(
        name="tmp",
        lm=model_light.audiolm,
        audiotokenizer=model_light.audio_tokenizer,
        max_duration=max_duration,
        seperate_tokenizer=model_light.seperate_tokenizer,
    )
    separator = Separator()
    
    # Sampling hyperparameters.
    cfg_coef = 1.5  # classifier-free guidance scale
    temp = 1.0      # sampling temperature
    top_k = 50
    top_p = 0.0     # 0.0 disables nucleus sampling, leaving top-k in effect
    record_tokens = True
    record_window = 50

    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    os.makedirs(save_dir + "/token", exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

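    # Each line of the input jsonl is expected to provide at least the fields
    # read from `item` below: idx, descriptions, gt_lyric, prompt_audio_path.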
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()

    new_items = []
    for line in lines:
        item = json.loads(line)
        target_name = f"{save_dir}/token/{item['idx']}_s{sidx}.npy"
        target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac"
        descriptions = item["descriptions"]
        lyric = item["gt_lyric"]
        
        start_time = time.time()
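        # Stage 1: Demucs separation of the prompt into full / vocal / bgm clips.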
        pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
        generate_inp = {
            'lyrics': [lyric.replace("  ", " ")],  # collapse doubled spaces
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }

        mid_time = time.time()
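        # Stage 2: autoregressive token generation under fp16 autocast.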
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = model.generate(**generate_inp, return_tokens=True)
        end_time = time.time()
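        # Safety cap on sequence length; 3000 tokens presumably matches the
        # model's maximum supported duration.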
        if tokens.shape[-1] > 3000:
            tokens = tokens[..., :3000]
            
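        # Stage 3: decode the tokens back to a waveform, conditioned on the
        # same prompt stems.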
        with torch.no_grad():
            wav_separate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
        torchaudio.save(target_wav_name, wav_separate[0].cpu().float(), cfg.sample_rate)
        np.save(target_name, tokens.cpu().squeeze(0).numpy())
        print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cos {end_time - mid_time}")

        item["idx"] = f"{item['idx']}_s{sidx}"
        item["tk_path"] = target_name
        new_items.append(item)
    
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}-s{sidx}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    main_sep()