# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from __future__ import print_function

import argparse
import os
import sys
from typing import List, Tuple

import torch
import yaml
import logging

import torch.nn.functional as F
from wenet.utils.checkpoint import load_checkpoint
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.init_model import init_model
from wenet.utils.mask import make_pad_mask

try:
    import onnxruntime
except ImportError:
    print("Please install onnxruntime-gpu!")
    sys.exit(1)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


class Encoder(torch.nn.Module):
    def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc
        self.beam_size = beam_size

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        Returns:
            encoder_out: B x T x F
            encoder_out_lens: B
            ctc_log_probs: B x T x V
            beam_log_probs: B x T x beam_size
            beam_log_probs_idx: B x T x beam_size
        """
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_log_probs = self.ctc.log_softmax(encoder_out)
        encoder_out_lens = encoder_out_lens.int()
        beam_log_probs, beam_log_probs_idx = torch.topk(
            ctc_log_probs, self.beam_size, dim=2
        )
        return (
            encoder_out,
            encoder_out_lens,
            ctc_log_probs,
            beam_log_probs,
            beam_log_probs_idx,
        )
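

# The offline Encoder above runs a whole padded batch in one call; the
# streaming wrappers below run chunk by chunk and additionally thread the
# attention/cnn caches and a cache mask through every call.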


class StreamingEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size, transformer=False):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.transformer = transformer

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): valid frame lengths of the chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache; the cache mask
                indicates the effective cache for each request.
        Note:
            self.required_cache_size (int) is the cache size required for the
            next chunk computation; it must be > 0, since a non-positive value
            is not allowed in the streaming GPU encoder.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off to the
                top beam_size entries, with shape (b, chunk_size, beam)
            torch.Tensor: indices of the top beam_size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batched inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, masks, pos_emb, att_cache=att_cache[i], cnn_cache=cnn_cache[i]
            )
            #   shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # torch.div(chunk_lens, self.subsampling_rate, rounding_mode='floor')
        # is not supported in TensorRT, so use floor division instead
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )


class StreamingSqueezeformerEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = "normal"  # no recovery at the end
            else:
                self.time_reduce = "recover"  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))
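
    # A worked example under an assumed config: with reduce_idx=[2] and
    # recover_idx=[4], layer i=3 lies after one time reduction and before the
    # recovery, so reduce_exp=1, recover_exp=0 and the downsampling factor is
    # 2**1 = 2; from layer i=4 onward the recovery applies and the factor
    # returns to 1.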

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): valid frame lengths of the chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache; the cache mask
                indicates the effective cache for each request.
        Note:
            self.required_cache_size (int) is the cache size required for the
            next chunk computation; it must be > 0, since a non-positive value
            is not allowed in the streaming GPU encoder.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off to the
                top beam_size entries, with shape (b, chunk_size, beam)
            torch.Tensor: indices of the top beam_size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batched inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: List[
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append((xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = self.encoder.time_reduction_layer(
                        xs, xs_lens, att_mask, mask_pad
                    )
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, : xs.size(1) * 2 - 1, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_att_mask,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recovered_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recovered_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad

            factor = self.calculate_downsampling_factor(i)

            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :][
                    :, :, : pos_emb.size(1) - xs.size(1), :
                ]
                if elayers > 0
                else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            cached_att = new_att_cache[:, :, next_cache_start // factor :, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            cached_att = (
                cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
            )
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)

        chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # torch.div(chunk_lens, self.subsampling_rate, rounding_mode='floor')
        # is not supported in TensorRT, so use floor division instead
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )


class StreamingEfficientConformerEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder

        # Efficient Conformer
        self.stride_layer_idx = model.encoder.stride_layer_idx
        self.stride = model.encoder.stride
        self.num_blocks = model.encoder.num_blocks
        self.cnn_module_kernel = model.encoder.cnn_module_kernel

    def calculate_downsampling_factor(self, i: int) -> int:
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor
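
    # A worked example under an assumed config: with stride_layer_idx=[3, 7]
    # and stride=[2, 2], layers 0-3 keep factor 1, layers 4-7 run at factor 2,
    # and layers 8 and above run at factor 4.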

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): valid frame lengths of the chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache; the cache mask
                indicates the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off to the
                top beam_size entries, with shape (b, chunk_size, beam)
            torch.Tensor: indices of the top beam_size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)  # (b, )
        offset *= self.calculate_downsampling_factor(self.num_blocks + 1)

        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)  # (b, 1, T)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        #   Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2)
        #   Shape(cnn_cache): (elayers, b, outsize, cnn_kernel)
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batched inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = chunk_mask.to(torch.bool)
        max_att_len, max_cnn_len = 0, 0  # for repeat_interleave of new_att_cache
        for i, layer in enumerate(self.encoder.encoders):
            factor = self.calculate_downsampling_factor(i)
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
            att_cache_trunc = 0
            if xs.size(1) + att_cache.size(3) // factor > pos_emb.size(1):
                # the time steps are not divisible by the downsampling
                # multiple; truncate the cache (doubling the chunk_size is
                # suggested to avoid this case)
                att_cache_trunc = (
                    xs.size(1) + att_cache.size(3) // factor - pos_emb.size(1) + 1
                )
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i][:, :, ::factor, :][:, :, att_cache_trunc:, :],
                cnn_cache=cnn_cache[i, :, :, :] if cnn_cache.size(0) > 0 else cnn_cache,
            )

            if i in self.stride_layer_idx:
                # compute time dimension for next block
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[
                    :, :: self.stride[efficient_index], :: self.stride[efficient_index]
                ]
                mask_pad = mask_pad[
                    :, :: self.stride[efficient_index], :: self.stride[efficient_index]
                ]
                pos_emb = pos_emb[:, :: self.stride[efficient_index], :]

            # shape(new_att_cache) = [batch, head, time2, outdim]
            new_att_cache = new_att_cache[:, :, next_cache_start // factor :, :]
            # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2]
            new_cnn_cache = new_cnn_cache.unsqueeze(1)  # shape(1):layerID

            # expand new_att_cache along the time dim; unsqueeze/repeat/flatten
            # below is equivalent to:
            # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
            new_att_cache = (
                new_att_cache.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
            )
            # pad new_cnn_cache to cnn.lorder for causal convolution
            new_cnn_cache = F.pad(
                new_cnn_cache, (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0)
            )

            if i == 0:
                # record length for the first block as max length
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)

            # update real shape of att_cache and cnn_cache
            r_att_cache.append(new_att_cache[:, :, -max_att_len:, :].unsqueeze(1))
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])

        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        # shape of r_att_cache: (b, elayers, head, time2, outdim)
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        # shape of r_cnn_cache: (b, elayers, outdim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # torch.div(chunk_lens, self.subsampling_rate, rounding_mode='floor')
        # is not supported in TensorRT, so use floor division instead
        chunk_out_lens = (
            chunk_lens
            // self.subsampling_rate
            // self.calculate_downsampling_factor(self.num_blocks + 1)
        )
        chunk_out_lens += 1
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )


class Decoder(torch.nn.Module):
    def __init__(
        self,
        decoder: TransformerDecoder,
        ctc_weight: float = 0.5,
        reverse_weight: float = 0.0,
        beam_size: int = 10,
        decoder_fastertransformer: bool = False,
    ):
        super().__init__()
        self.decoder = decoder
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.beam_size = beam_size
        self.decoder_fastertransformer = decoder_fastertransformer

    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_lens: torch.Tensor,
        hyps_pad_sos_eos: torch.Tensor,
        hyps_lens_sos: torch.Tensor,
        r_hyps_pad_sos_eos: torch.Tensor,
        ctc_score: torch.Tensor,
    ):
        """Encoder
        Args:
            encoder_out: B x T x F
            encoder_lens: B
            hyps_pad_sos_eos: B x beam x (T2+1),
                        hyps with sos & eos and padded by ignore id
            hyps_lens_sos: B x beam, length for each hyp with sos
            r_hyps_pad_sos_eos: B x beam x (T2+1),
                    reversed hyps with sos & eos and padded by ignore id
            ctc_score: B x beam, ctc score for each hyp
        Returns:
            decoder_out: B x beam x T2 x V
            r_decoder_out: B x beam x T2 x V
            best_index: B
        """
        B, T, D = encoder_out.shape
        bz = self.beam_size
        B2 = B * bz
        encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, D)
        encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
        T2 = hyps_pad_sos_eos.shape[2] - 1
        hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
        hyps_lens = hyps_lens_sos.view(B2)
        hyps_pad_sos = hyps_pad[:, :-1].contiguous()
        hyps_pad_eos = hyps_pad[:, 1:].contiguous()
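        # Layout sketch (hypothetical ids, sos=1, eos=2, ignore_id=-1): a hyp
        # with tokens [7, 9] arrives as [1, 7, 9, 2, -1]; dropping the last
        # column gives the decoder input [1, 7, 9, 2], dropping the first
        # gives the gather target [7, 9, 2, -1], and the length mask built
        # from hyps_lens below zeroes the padded positions.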

        r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
        r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
        r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_pad_sos,
            hyps_lens,
            r_hyps_pad_sos,
            self.reverse_weight,
        )
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        V = decoder_out.shape[-1]
        decoder_out = decoder_out.view(B2, T2, V)
        mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
        # mask index, remove ignore id
        index = torch.unsqueeze(hyps_pad_eos * mask, 2)
        score = decoder_out.gather(2, index).squeeze(2)  # B2 X T2
        # mask padded part
        score = score * mask
        decoder_out = decoder_out.view(B, bz, T2, V)
        if self.reverse_weight > 0:
            r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
            r_decoder_out = r_decoder_out.view(B2, T2, V)
            index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
            r_score = r_decoder_out.gather(2, index).squeeze(2)
            r_score = r_score * mask
            score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score
            r_decoder_out = r_decoder_out.view(B, bz, T2, V)
        score = torch.sum(score, axis=1)  # B2
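        # fuse the attention rescoring score with the ctc score for each
        # hypothesis, then pick the best hypothesis per batch entry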
        score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)
        if self.decoder_fastertransformer:
            return decoder_out, best_index
        else:
            return best_index


def to_numpy(tensors):
    out = []
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    for tensor in tensors:
        if tensor.requires_grad:
            tensor = tensor.detach().cpu().numpy()
        else:
            tensor = tensor.cpu().numpy()
        out.append(tensor)
    return out


def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
    for a, b in zip(xlist, blist):
        try:
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        except AssertionError as error:
            if tolerate_small_mismatch:
                print(error)
            else:
                raise
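

# A minimal usage sketch of the checker above (hypothetical values):
#   test(to_numpy([torch.ones(2)]), to_numpy([torch.ones(2) + 1e-6]))
# passes within the default tolerances, while a larger mismatch only prints
# the assertion error because tolerate_small_mismatch defaults to True.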


def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
    bz = 32
    seq_len = 100
    beam_size = args.beam_size
    feature_size = configs["input_dim"]

    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
    speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32)
    encoder = Encoder(model.encoder, model.ctc, beam_size)
    encoder.eval()

    torch.onnx.export(
        encoder,
        (speech, speech_lens),
        encoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["speech", "speech_lengths"],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "ctc_log_probs",
            "beam_log_probs",
            "beam_log_probs_idx",
        ],
        dynamic_axes={
            "speech": {0: "B", 1: "T"},
            "speech_lengths": {0: "B"},
            "encoder_out": {0: "B", 1: "T_OUT"},
            "encoder_out_lens": {0: "B"},
            "ctc_log_probs": {0: "B", 1: "T_OUT"},
            "beam_log_probs": {0: "B", 1: "T_OUT"},
            "beam_log_probs_idx": {0: "B", 1: "T_OUT"},
        },
        verbose=False,
    )

    with torch.no_grad():
        o0, o1, o2, o3, o4 = encoder(speech, speech_lens)

    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(encoder_onnx_path, providers=providers)
    ort_inputs = {"speech": to_numpy(speech), "speech_lengths": to_numpy(speech_lens)}
    ort_outs = ort_session.run(None, ort_inputs)

    # check encoder output
    test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
    logger.info("export offline onnx encoder succeed!")
    onnx_config = {
        "beam_size": args.beam_size,
        "reverse_weight": args.reverse_weight,
        "ctc_weight": args.ctc_weight,
        "fp16": args.fp16,
    }
    return onnx_config


def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 32
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]
    # cnn_module_kernel here is the cnn cache length (kernel size - 1);
    # a plain transformer has no cnn module, so it is 0
    transformer = False
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    if cnn_module_kernel == 0:
        transformer = True
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
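    # e.g. the default arguments below (decoding_chunk_size=16,
    # num_decoding_left_chunks=5) give an 80-frame attention cache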
    if configs["encoder"] == "squeezeformer":
        encoder = StreamingSqueezeformerEncoder(
            model, required_cache_size, args.beam_size
        )
    elif configs["encoder"] == "efficientConformer":
        encoder = StreamingEfficientConformerEncoder(
            model, required_cache_size, args.beam_size
        )
    else:
        encoder = StreamingEncoder(
            model, required_cache_size, args.beam_size, transformer
        )
    encoder.eval()

    # begin to export encoder
    chunk_xs = torch.randn(batch_size, audio_len, feature_size, dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len

    offset = torch.arange(0, batch_size).unsqueeze(1)
    # att_cache: (b, elayers, head, cache_t1, d_k * 2)
    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(
        batch_size, num_layers, head, required_cache_size, d_k * 2, dtype=torch.float32
    )
    cnn_cache = torch.randn(
        batch_size, num_layers, output_size, cnn_module_kernel, dtype=torch.float32
    )

    cache_mask = torch.ones(batch_size, 1, required_cache_size, dtype=torch.float32)
    input_names = [
        "chunk_xs",
        "chunk_lens",
        "offset",
        "att_cache",
        "cnn_cache",
        "cache_mask",
    ]
    output_names = [
        "log_probs",
        "log_probs_idx",
        "chunk_out",
        "chunk_out_lens",
        "r_offset",
        "r_att_cache",
        "r_cnn_cache",
        "r_cache_mask",
    ]
    input_tensors = (chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)
    if transformer:
        output_names.pop(6)  # drop r_cnn_cache, which a transformer lacks

    all_names = input_names + output_names
    dynamic_axes = {}
    for name in all_names:
        # only the first (batch) dimension is dynamic;
        # all other dimensions are fixed
        dynamic_axes[name] = {0: "B"}

    torch.onnx.export(
        encoder,
        input_tensors,
        encoder_onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        verbose=False,
    )

    with torch.no_grad():
        torch_outs = encoder(
            chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask
        )
    if transformer:
        # drop the cnn cache output, which the transformer model lacks
        torch_outs = list(torch_outs)
        torch_outs.pop(6)
    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CUDAExecutionProvider"]
    )
    ort_inputs = {}

    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]
    if transformer:
        del ort_inputs["cnn_cache"]
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")
    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "fp16": args.fp16,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel,
    }
    return onnx_config


def export_rescoring_decoder(
    model, configs, args, logger, decoder_onnx_path, decoder_fastertransformer
):
    bz, seq_len = 32, 100
    beam_size = args.beam_size
    decoder = Decoder(
        model.decoder,
        model.ctc_weight,
        model.reverse_weight,
        beam_size,
        decoder_fastertransformer,
    )
    decoder.eval()

    hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))
    hyps_lens_sos = torch.randint(
        low=3, high=seq_len, size=(bz, beam_size), dtype=torch.int32
    )
    r_hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))

    output_size = configs["encoder_conf"]["output_size"]
    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
    encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,), dtype=torch.int32)
    ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)

    input_names = [
        "encoder_out",
        "encoder_out_lens",
        "hyps_pad_sos_eos",
        "hyps_lens_sos",
        "r_hyps_pad_sos_eos",
        "ctc_score",
    ]
    output_names = ["best_index"]
    if decoder_fastertransformer:
        output_names.insert(0, "decoder_out")

    torch.onnx.export(
        decoder,
        (
            encoder_out,
            encoder_out_lens,
            hyps_pad_sos_eos,
            hyps_lens_sos,
            r_hyps_pad_sos_eos,
            ctc_score,
        ),
        decoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "encoder_out": {0: "B", 1: "T"},
            "encoder_out_lens": {0: "B"},
            "hyps_pad_sos_eos": {0: "B", 2: "T2"},
            "hyps_lens_sos": {0: "B"},
            "r_hyps_pad_sos_eos": {0: "B", 2: "T2"},
            "ctc_score": {0: "B"},
            "best_index": {0: "B"},
        },
        verbose=False,
    )
    with torch.no_grad():
        o0 = decoder(
            encoder_out,
            encoder_out_lens,
            hyps_pad_sos_eos,
            hyps_lens_sos,
            r_hyps_pad_sos_eos,
            ctc_score,
        )
    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(decoder_onnx_path, providers=providers)

    input_tensors = [
        encoder_out,
        encoder_out_lens,
        hyps_pad_sos_eos,
        hyps_lens_sos,
        r_hyps_pad_sos_eos,
        ctc_score,
    ]
    ort_inputs = {}
    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]

    # if model.reverse_weight == 0, r_hyps_pad_sos_eos is removed from the
    # onnx decoder inputs since it doesn't play any role
    if model.reverse_weight == 0:
        del ort_inputs["r_hyps_pad_sos_eos"]
    ort_outs = ort_session.run(None, ort_inputs)

    # check decoder output
    if decoder_fastertransformer:
        test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
    else:
        test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx decoder succeed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="export x86_gpu model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument(
        "--cmvn_file",
        required=False,
        default="",
        type=str,
        help="global_cmvn file, default path is in config file",
    )
    parser.add_argument(
        "--reverse_weight",
        default=-1.0,
        type=float,
        required=False,
        help="reverse weight for bitransformer," + "default value is in config file",
    )
    parser.add_argument(
        "--ctc_weight",
        default=-1.0,
        type=float,
        required=False,
        help="ctc weight, default value is in config file",
    )
    parser.add_argument(
        "--beam_size",
        default=10,
        type=int,
        required=False,
        help="beam size would be ctc output size",
    )
    parser.add_argument(
        "--output_onnx_dir",
        default="onnx_model",
        help="output onnx encoder and decoder directory",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether to export fp16 model, default false",
    )
    # arguments for streaming encoder
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="whether to export streaming encoder, default false",
    )
    parser.add_argument(
        "--decoding_chunk_size",
        default=16,
        type=int,
        required=False,
        help="the decoding chunk size, <=0 is not supported",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        default=5,
        type=int,
        required=False,
        help="number of left chunks, <= 0 is not supported",
    )
    parser.add_argument(
        "--decoder_fastertransformer",
        action="store_true",
        help="return decoder_out and best_index for ft",
    )
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.set_printoptions(precision=10)

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.cmvn_file and os.path.exists(args.cmvn_file):
        configs["cmvn_file"] = args.cmvn_file
    if args.reverse_weight != -1.0 and "reverse_weight" in configs["model_conf"]:
        configs["model_conf"]["reverse_weight"] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to ", args.ctc_weight)
        configs["model_conf"]["ctc_weight"] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()

    if not os.path.exists(args.output_onnx_dir):
        os.mkdir(args.output_onnx_dir)
    encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder.onnx")
    export_enc_func = None
    if args.streaming:
        assert args.decoding_chunk_size > 0
        assert args.num_decoding_left_chunks > 0
        export_enc_func = export_online_encoder
    else:
        export_enc_func = export_offline_encoder

    onnx_config = export_enc_func(model, configs, args, logger, encoder_onnx_path)

    decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
    export_rescoring_decoder(
        model, configs, args, logger, decoder_onnx_path, args.decoder_fastertransformer
    )

    if args.fp16:
        try:
            import onnxmltools
            from onnxmltools.utils.float16_converter import convert_float_to_float16
        except ImportError:
            print("Please install onnxmltools!")
            sys.exit(1)
        encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
        encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
        encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder_fp16.onnx")
        onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
        decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
        decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
        decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder_fp16.onnx")
        onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
    # dump configurations
    config_path = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_path, "w") as out:
        yaml.dump(onnx_config, out)
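
# A typical invocation, assuming this file is saved as export_onnx_gpu.py
# (paths below are placeholders):
#   python export_onnx_gpu.py --config train.yaml --checkpoint final.pt \
#       --beam_size 10 --output_onnx_dir onnx_model --streaming --fp16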