import logging
import math
from typing import Union

import torch
import torchaudio
from audio_denoiser.helpers.audio_helper import (
    create_spectrogram,
    reconstruct_from_spectrogram,
)
from audio_denoiser.helpers.torch_helper import batched_apply
from torch import nn

from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model

# Target trimmed standard deviation that AudioDenoiser.process_waveform rescales
# the input towards when auto_scale is enabled (AudioDenoiser._trimmed_dev notes
# ~0.23 as the value expected for the training data).
_expected_t_std = 0.23
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably the torchaudio I/O backend recommended by the upstream project —
# confirm before relying on it.
_recommended_backend = "soundfile"


# ref: https://github.com/jose-solorzano/audio-denoiser
class AudioDenoiser:
    """
    Denoises waveforms or audio files with a pretrained spectrogram model.

    The model is loaded once in the constructor via load_audio_denosier_model;
    its sample rate, scaler, FFT size and segment length are cached on the
    instance and drive all the processing in process_waveform.
    """

    def __init__(
        self,
        local_dir: str,
        device: Union[str, torch.device, None] = None,
        num_iterations: int = 100,
    ) -> None:
        """
        @param local_dir: Directory passed to load_audio_denosier_model as dir_path.
        @param device: Device the model runs on. When None, "cuda:0" is used if CUDA
            is available, otherwise the CPU (a warning is logged in that case).
        @param num_iterations: Iteration count forwarded to reconstruct_from_spectrogram
            when converting a denoised spectrogram back to a waveform.
        """
        super().__init__()
        if device is None:
            is_cuda = torch.cuda.is_available()
            if not is_cuda:
                logging.warning("CUDA not available. Will use CPU.")
            device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
        self.device = device
        self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
        self.model.eval()
        # Cache the model's preprocessing parameters so the processing code does
        # not have to reach into the model on every call.
        self.model_sample_rate = self.model.sample_rate
        self.scaler = self.model.scaler
        self.n_fft = self.model.n_fft
        self.segment_num_frames = self.model.num_frames
        self.num_iterations = num_iterations

    @staticmethod
    def _sp_log(spectrogram: torch.Tensor, eps=0.01):
        """Return log(spectrogram + eps); eps keeps the log finite at zero."""
        return torch.log(spectrogram + eps)

    @staticmethod
    def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
        """Invert _sp_log, clamping at zero so the result is never negative."""
        return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)

    @staticmethod
    def _quantile(
        values: torch.Tensor, q: float, max_native_numel: int = 2 ** 24
    ) -> float:
        """
        Return the q-quantile (linear interpolation) over all elements of values.

        torch.quantile refuses inputs above a fixed element count
        ("quantile() input tensor is too large"), which long recordings exceed.
        Inputs with at most max_native_numel elements use torch.quantile
        directly; larger ones use an equivalent sort-based computation.

        @param values: Tensor whose flattened elements are ranked.
        @param q: Quantile in [0, 1].
        @param max_native_numel: Largest element count handed to torch.quantile.
        @return: The quantile as a Python float.
        """
        flat = values.flatten()
        if flat.numel() <= max_native_numel:
            return torch.quantile(flat, q).item()
        # torch.quantile yields NaN when the input contains NaN; mirror that.
        if torch.isnan(flat).any():
            return float("nan")
        ordered, _ = torch.sort(flat)
        position = q * (flat.numel() - 1)
        lower = math.floor(position)
        upper = min(lower + 1, flat.numel() - 1)
        low_value = ordered[lower].item()
        high_value = ordered[upper].item()
        return low_value + (high_value - low_value) * (position - lower)

    @staticmethod
    def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        """
        Return the standard deviation of the samples whose magnitude is at or
        above the q-quantile of |waveform|.

        The result is 0.0 for an all-zero waveform and NaN when fewer than two
        samples survive the trimming; callers must handle both.
        """
        # Expected for training data is ~0.23
        abs_waveform = torch.abs(waveform)
        quantile_value = AudioDenoiser._quantile(abs_waveform, q)
        trimmed_values = waveform[abs_waveform >= quantile_value]
        return torch.std(trimmed_values).item()

    @classmethod
    def _rescale(cls, waveform: torch.Tensor, target_std: float) -> torch.Tensor:
        """
        Scale waveform so that its trimmed standard deviation equals target_std.

        If the trimmed deviation is zero or not finite (silent, constant or
        near-empty input) the waveform is returned unchanged, because dividing
        by it would fill the signal with NaN/inf values.
        """
        current_std = cls._trimmed_dev(waveform)
        if not math.isfinite(current_std) or current_std <= 0:
            logging.warning(
                "Cannot auto-scale waveform (trimmed std is %s). Skipping scaling.",
                current_std,
            )
            return waveform
        return waveform * target_std / current_std

    def _denoise_channel(self, c_spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Denoise the spectrogram of a single audio channel.

        @param c_spectrogram: A (fft_size, num_frames) spectrogram on self.device.
        @return: The denoised spectrogram, shape (1, fft_size, num_frames).
        """
        fft_size, num_frames = c_spectrogram.shape
        # The model consumes fixed-length segments, so zero-pad the time axis
        # up to a whole number of segments.
        num_segments = math.ceil(num_frames / self.segment_num_frames)
        adj_num_frames = num_segments * self.segment_num_frames
        if adj_num_frames > num_frames:
            c_spectrogram = nn.functional.pad(
                c_spectrogram, (0, adj_num_frames - num_frames)
            )
        c_spectrogram = c_spectrogram.view(
            fft_size, num_segments, self.segment_num_frames
        )
        # c_spectrogram: (fft_size, num_segments, segment_num_frames)
        c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
        # c_spectrogram: (num_segments, fft_size, segment_num_frames)
        log_c_spectrogram = self._sp_log(c_spectrogram)
        scaled_log_c_sp = self.scaler(log_c_spectrogram)
        pred_noise_log_sp = batched_apply(self.model, scaled_log_c_sp, detached=True)
        # The model output is treated as noise in the log domain and subtracted
        # from the (unscaled) log spectrogram.
        log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
        denoised_sp = self._sp_exp(log_denoised_sp)
        # denoised_sp: (num_segments, fft_size, segment_num_frames)
        denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
        # denoised_sp: (fft_size, num_segments, segment_num_frames)
        denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
        # Drop the frames that were only added as padding.
        return denoised_sp[:, :, :num_frames]

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        return_cpu_tensor: bool = False,
        auto_scale: bool = False,
    ) -> torch.Tensor:
        """
        Denoises a waveform.

        The waveform is resampled to the model's sample rate when needed and is
        not resampled back, so the result is at self.model_sample_rate.

        @param waveform: A waveform tensor. Use torchaudio structure.
        @param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
            Otherwise the result is moved to self.device.
        @param auto_scale: Normalize the scale of the waveform before processing.
            Recommended for low-volume audio. Scaling is skipped (with a warning)
            when the waveform's trimmed deviation is zero or not finite.
        @return: A denoised waveform, one row per input channel.
        """
        waveform = waveform.cpu()
        if auto_scale:
            waveform = self._rescale(waveform, _expected_t_std)
        if sample_rate != self.model_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.model_sample_rate
            )
            waveform = transform(waveform)
        hop_len = self.n_fft // 2
        spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
        spectrogram = spectrogram.to(self.device)
        num_a_channels = spectrogram.size(0)
        with torch.no_grad():
            results = []
            # Channels are denoised independently and re-stacked at the end.
            for c in range(num_a_channels):
                denoised_sp = self._denoise_channel(spectrogram[c])
                # Waveform reconstruction is done on the CPU.
                denoised_sp = denoised_sp.cpu()
                denoised_waveform = reconstruct_from_spectrogram(
                    denoised_sp, num_iterations=self.num_iterations
                )
                results.append(denoised_waveform)
            cpu_results = torch.cat(results)
            return cpu_results if return_cpu_tensor else cpu_results.to(self.device)

    def process_audio_file(
        self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
    ):
        """
        Denoises an audio file.

        The output file is written at the model's sample rate, which may differ
        from the sample rate of the input file.

        @param in_audio_file: An input audio file with a format supported by torchaudio.
        @param out_audio_file: An output audio file with a format supported by torchaudio.
        @param auto_scale: Whether the input waveform scale should be normalized before
            processing. Recommended for low-volume audio.
        """
        waveform, sample_rate = torchaudio.load(in_audio_file)
        denoised_waveform = self.process_waveform(
            waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
        )
        torchaudio.save(
            out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
        )