from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv import ops
from mmcv.cnn import ConvModule, Linear
from mmengine.model import BaseModule

class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
    """Abstract parent of RoI feature extractors.

    One RoI operator is created per entry of ``featmap_strides``; concrete
    subclasses implement :meth:`forward` and decide how the operators are
    applied to the incoming feature maps.

    Args:
        roi_layer (:obj:`ConfigDict` or dict): RoI layer type and its
            arguments.
        out_channels (int): Number of channels produced by the RoI layers.
        featmap_strides (list[int]): Stride of every input feature map.
        init_cfg (:obj:`ConfigDict` or dict or list of them, optional):
            Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 roi_layer,
                 out_channels: int,
                 featmap_strides: List[int],
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        # The RoI operators are built first, before the plain attributes.
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.featmap_strides = featmap_strides
        self.out_channels = out_channels

    @property
    def num_inputs(self) -> int:
        """int: How many feature maps the extractor expects."""
        return len(self.featmap_strides)

    def build_roi_layers(self, layer_cfg,
                         featmap_strides: List[int]) -> nn.ModuleList:
        """Create one RoI operator for every feature level.

        Args:
            layer_cfg (:obj:`ConfigDict` or dict): Config of the RoI
                operator. ``type`` is either the name of an attribute of
                ``mmcv.ops`` (e.g. ``RoIAlign``) or a class/callable; every
                other key is forwarded to the operator as a keyword argument.
            featmap_strides (list[int]): Stride of each feature map relative
                to the input image. Each operator receives
                ``spatial_scale=1 / stride`` so RoIs given in image
                coordinates are mapped to feature-map coordinates.

        Returns:
            :obj:`nn.ModuleList`: The RoI operators, one per stride and in
                the same order as ``featmap_strides``.
        """
        # Work on a copy so the caller's config keeps its ``type`` key.
        params = layer_cfg.copy()
        op = params.pop('type')
        if isinstance(op, str):
            # A string is resolved against the operators shipped in mmcv.ops.
            assert hasattr(ops, op)
            op = getattr(ops, op)
        per_level = [
            op(spatial_scale=1 / stride, **params)
            for stride in featmap_strides
        ]
        return nn.ModuleList(per_level)

    def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
        """Grow or shrink every RoI around its own centre.

        Args:
            rois (Tensor): RoIs of shape (n, 5); column 0 is passed through
                unchanged and columns 1-4 are read as (x1, y1, x2, y2).
            scale_factor (float): Multiplier applied to the width and height
                of each RoI.

        Returns:
            Tensor: The rescaled RoIs, same column layout as the input.
        """
        left, top = rois[:, 1], rois[:, 2]
        right, bottom = rois[:, 3], rois[:, 4]

        center_x = (left + right) * 0.5
        center_y = (top + bottom) * 0.5
        scaled_w = (right - left) * scale_factor
        scaled_h = (bottom - top) * scale_factor

        columns = (
            rois[:, 0],
            center_x - scaled_w * 0.5,
            center_y - scaled_h * 0.5,
            center_x + scaled_w * 0.5,
            center_y + scaled_h * 0.5,
        )
        return torch.stack(columns, dim=-1)

    @abstractmethod
    def forward(self,
                feats: Tuple[Tensor],
                rois: Tensor,
                roi_scale_factor: Optional[float] = None) -> Tensor:
        """Extract RoI features; must be provided by subclasses.

        Args:
            feats (Tuple[Tensor]): Multi-scale feature maps.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
            roi_scale_factor (Optional[float]): RoI scale factor.
                Defaults to None.

        Returns:
            Tensor: RoI feature.
        """
        pass


class MLVLFuseModule(nn.Module):
    """Fuse a list of feature maps by swapping channel slices between levels.

    Each level is first concatenated with a 2-channel normalised coordinate
    map and projected to ``embed_dims`` channels by a 1x1 convolution. Then
    ``num_fuse`` rounds are applied; in every round each level keeps its
    first ``remain_chs`` channels and receives ``shuffle_channles`` channels
    from the level at index + 1 and from the level at index - 1 (indices are
    clamped at the ends), after which a shared 3x3 ``ConvModule`` is applied.

    Args:
        input_dims (int): Channels of every incoming feature map.
        embed_dims (int): Channels used inside the module and returned.
        num_levels (int): Number of feature levels.
        num_fuse (int): Number of shuffle-and-convolve rounds.
    """

    def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3, num_fuse=4):
        super(MLVLFuseModule, self).__init__()
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_fuse = num_fuse
        self.input_dims = input_dims
        # Number of channels exchanged with each neighbouring level.
        self.shuffle_channles = embed_dims // 4

        # One (target, index + 1, index - 1) tuple per level; the neighbour
        # indices are clamped so the first/last level pairs with itself.
        self.fuse_lvl_list = [
            (lvl, min(lvl + 1, num_levels - 1), max(lvl - 1, 0))
            for lvl in range(num_levels)
        ]

        # Channels a level keeps from itself in every round.
        self.remain_chs = self.embed_dims - self.shuffle_channles * 2
        self._init_layers()

    def generate_coordinate(self, featmap_sizes, device='cuda'):
        """Build a normalised (x, y) coordinate map.

        Args:
            featmap_sizes (Sequence[int]): Shape of a feature map; index 0 is
                used as the batch size and the last two entries as height
                and width.
            device (str or torch.device): Device of the returned tensor.

        Returns:
            Tensor: Shape (batch, 2, height, width). Channel 0 holds x values
                running from -1 to 1 along the width, channel 1 holds y
                values running from -1 to 1 along the height.
        """
        batch = featmap_sizes[0]
        height, width = featmap_sizes[-2], featmap_sizes[-1]
        x_range = torch.linspace(-1, 1, width, device=device)
        y_range = torch.linspace(-1, 1, height, device=device)
        # Explicit broadcasting instead of ``torch.meshgrid``: it yields the
        # same grid but does not depend on meshgrid's ``indexing`` default,
        # which newer PyTorch versions warn about.
        x = x_range.view(1, 1, 1, width).expand(batch, 1, height, width)
        y = y_range.view(1, 1, height, 1).expand(batch, 1, height, width)
        return torch.cat([x, y], 1)

    def _init_layers(self):
        """Create the per-level input projections and the fusion convs."""
        # "+ 2" accounts for the coordinate channels appended in forward().
        self.input_conv = nn.ModuleList([
            nn.Conv2d(self.input_dims + 2, self.embed_dims, 1)
            for _ in range(self.num_levels)
        ])
        self.fuse_convs = nn.ModuleList()
        for _ in range(self.num_fuse):
            self.fuse_convs.append(
                ConvModule(self.embed_dims,
                           self.embed_dims,
                           3,
                           stride=1,
                           padding=3 // 2,
                           conv_cfg=None,
                           norm_cfg=dict(type='GN',
                                         num_groups=64,
                                         requires_grad=True)
                           ))

    def init_weights(self):
        """No custom initialisation is performed."""
        pass

    @staticmethod
    def _resize_bilinear(feat, size):
        """Bilinearly resize ``feat`` to ``size``, computing in float32."""
        return F.interpolate(feat.to(torch.float32),
                             size=size,
                             mode='bilinear',
                             align_corners=True)

    def _single_shuffle(self, inputs, conv_module):
        """Run one or more shuffle-and-convolve rounds.

        Args:
            inputs (list[Tensor]): One feature map per level.
            conv_module (callable or list/``nn.ModuleList`` of callables):
                Applied, one per round, to every level after the channel
                exchange.

        Returns:
            list[Tensor]: The fused feature maps, one per level.
        """
        if not isinstance(conv_module, (nn.ModuleList, list)):
            conv_module = [conv_module]

        keep = self.remain_chs
        # Channels [keep, split) are handed to the level at index + 1's
        # "down" slot; channels [split, end) to the level at index - 1's
        # "top" slot.
        split = keep + self.shuffle_channles

        for conv in conv_module:
            mixed = []
            for tar_lvl, top_lvl, dow_lvl in self.fuse_lvl_list:
                target = inputs[tar_lvl]
                size = target.shape[-2:]
                remain = target[:, :keep]
                # Neighbour slices are resized to the target's spatial size
                # in float32 and cast back to the target's dtype.
                from_top = self._resize_bilinear(inputs[top_lvl][:, split:], size)
                from_down = self._resize_bilinear(inputs[dow_lvl][:, keep:split], size)
                mixed.append(torch.cat(
                    [remain, from_top.to(remain.dtype), from_down.to(remain.dtype)],
                    dim=1))
            inputs = [conv(item) for item in mixed]
        return inputs

    def forward(self, inputs):
        """Fuse the given feature maps.

        Args:
            inputs (Sequence[Tensor]): One 4-D feature map per level.

        Returns:
            list[Tensor]: Fused feature maps, one per level.
        """
        device = inputs[0].device
        with_coords = []
        for feat in inputs:
            coord_feat = self.generate_coordinate(feat.shape, device=device)
            with_coords.append(
                torch.cat([feat, coord_feat.to(feat.dtype)], dim=1))

        inputs = [self.input_conv[lvl](item)
                  for lvl, item in enumerate(with_coords)]

        for conv_m in self.fuse_convs:
            inputs = self._single_shuffle(inputs, [conv_m])
        return inputs


class MlvlRoIExtractor(BaseRoIExtractor):
    """RoI extractor that pools every box from every level and fuses them.

    Each box is pooled from each feature level, passed through a per-level
    3x3 convolution, summed over levels, flattened, projected to 1024
    features, offset by an embedding of the box coordinates and finally
    projected to 4096 features.

    Args:
        roi_layer (dict): RoI layer config, see ``BaseRoIExtractor``.
        out_channels (int): Channels of the pooled RoI features.
        featmap_strides (list): Stride of each feature level.
        embed_dims (int): Channels of the per-level 3x3 convolutions.
        stride (int): Accepted but not used by this class.
        norm_init (bool): Stored only; ``norm_init_weights`` is a no-op.
        fuse_level (int): Number of levels that are pooled and fused.
        finest_scale (int): Stored only; not read by this class.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 embed_dims=1024,
                 stride=1,
                 norm_init=True,
                 fuse_level=3,
                 finest_scale=56,
                 init_cfg=None):
        super().__init__(roi_layer, out_channels, featmap_strides, init_cfg)
        self.embed_dims = embed_dims
        self.finest_scale = finest_scale
        self.fuse_level = fuse_level
        self.norm_init = norm_init

        # One 3x3 convolution per fused level, applied to pooled features.
        level_convs = [
            nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1, padding=1)
            for _ in range(self.fuse_level)
        ]
        self.pconvs = nn.ModuleList(level_convs)

        # Embeds the 4 box coordinates into a 1024-d vector.
        self.pos_embedd = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.LayerNorm(1024),
        )
        self.updims = nn.Linear(1024, 4096)

        # Input width = channels * (pooled side length squared).
        pooled_side = self.roi_layers[0].output_size[0]
        self.flatten_linear = nn.Linear(
            self.embed_dims * pooled_side ** 2, 1024)

        self.norm_init_weights()

    def norm_init_weights(self):
        """Placeholder; no custom initialisation is performed."""
        pass

    def forward(self, feats, rois, roi_scale_factor=None):
        """Pool and fuse the features of every box.

        Args:
            feats (Sequence[Tensor]): One feature map per level, either 4-D
                or 3-D (batch, tokens, channels); the 3-D form is reshaped
                to a 16x16 grid.
            rois (Sequence[Tensor]): One tensor of boxes per image with 4
                columns each. The boxes are multiplied by 224 before
                pooling -- presumably they are normalised to [0, 1];
                confirm against the caller.
            roi_scale_factor: Not used by this implementation.

        Returns:
            list[Tensor]: One tensor per image holding the fused feature of
                every box of that image.
        """
        num_imgs = len(rois)
        # The box embedding is computed from the boxes as they were passed
        # in, before the rescaling below.
        box_embed = self.pos_embedd(
            torch.cat(rois, dim=0).to(feats[0].dtype))
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)

        if feats[0].dim() == 3:
            # Token sequences are folded back into a square grid.
            side = int(math.sqrt(feats[0].shape[1]))
            assert side == 16
            batch, channels = feats[0].shape[0], feats[0].shape[-1]
            feats = [
                tokens.reshape(batch, side, side, channels).permute(0, 3, 1, 2)
                for tokens in feats
            ]

        indexed = []
        for img_id, boxes in enumerate(rois):
            # rescale to original img scale
            boxes = boxes * 224
            # Prepend the image index so the RoI operator knows which batch
            # element each box belongs to.
            img_col = boxes.new_full((len(boxes), 1), img_id)
            indexed.append(torch.cat([img_col, boxes], dim=1))
        rois = torch.cat(indexed)

        roi_feats = feats[0].new_zeros(self.fuse_level,
                                       rois.size(0), self.out_channels, *out_size)

        has_rois = len(rois) > 0
        for lvl in range(num_levels):
            if has_rois:
                level_feat = feats[lvl]
                # The RoI operator is run in float32 and its result cast
                # back to the feature dtype.
                pooled = self.roi_layers[lvl](
                    level_feat.to(torch.float32), rois.to(torch.float32))
                roi_feats[lvl] = pooled.to(level_feat.dtype)
            else:
                # No boxes: add a zero-valued term that still touches every
                # parameter and feature map.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[lvl].sum() * 0.

        fused = sum(
            self.pconvs[lvl](roi_feats[lvl])
            for lvl in range(self.fuse_level))
        fused = F.relu(fused).flatten(1, -1)
        fused = self.flatten_linear(fused) + box_embed
        fused = self.updims(fused)

        # Split the flat result back into one tensor per image using the
        # image-index column.
        return [fused[rois[:, 0] == img_id] for img_id in range(num_imgs)]


class MLVLROIQueryModule(nn.Module):
    """Resize, fuse and RoI-pool a list of same-sized feature maps.

    The incoming maps are bilinearly resized so that level ``i`` ends up
    ``2 ** (num_levels - 1 - i)`` times larger than the first input map,
    fused by ``MLVLFuseModule`` and then pooled per box by
    ``MlvlRoIExtractor``.

    Args:
        embed_dims (int): Channels of the incoming and fused feature maps.
        out_dims (int): Accepted for interface compatibility; not used by
            this class.
        num_levels (int): Number of feature levels. Must be at least 1.
    """

    def __init__(self, embed_dims=1024, out_dims=4096,
                 num_levels=3):
        super(MLVLROIQueryModule, self).__init__()
        self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
                                        embed_dims=embed_dims,
                                        num_levels=num_levels,
                                        num_fuse=5)
        # Strides follow the resizing done in forward(); deriving them from
        # num_levels keeps the two in sync for any level count (a fixed
        # 4-entry list made every other num_levels, including the default,
        # fail at construction).
        strides = self._level_strides(num_levels)
        bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
                                                 output_size=14,
                                                 sampling_ratio=2),
                                  out_channels=embed_dims,
                                  embed_dims=embed_dims,
                                  fuse_level=num_levels,
                                  featmap_strides=strides)

        self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)

    @staticmethod
    def _level_strides(num_levels, base_stride=14):
        """Return the stride of every level, finest level first.

        The last level has stride ``base_stride`` and each earlier level
        halves it, e.g. ``[14 / 8, 14 / 4, 14 / 2, 14]`` for four levels.

        Raises:
            ValueError: If ``num_levels`` is smaller than 1.
        """
        if num_levels < 1:
            raise ValueError(
                f'num_levels must be >= 1, got {num_levels}')
        return [base_stride / 2 ** (num_levels - 1 - lvl)
                for lvl in range(num_levels)]

    def forward(self, mlvl_feats, bboxes):
        """Compute per-box query features.

        Args:
            mlvl_feats (Sequence[Tensor]): One feature map per level, either
                4-D or 3-D (batch, tokens, channels); the 3-D form must hold
                24 * 24 tokens. The sequence itself is not modified.
            bboxes: Boxes, forwarded unchanged to the RoI extractor.

        Returns:
            Whatever the RoI extractor returns for the fused features.
        """
        if mlvl_feats[0].dim() == 3:
            # Token sequences are folded back into a square grid.
            h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
            assert h == 24
            assert w == 24
            b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
            mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2) for item in mlvl_feats]

        base_h, base_w = mlvl_feats[0].shape[-2:]
        num_level = len(mlvl_feats)

        # Collect results in a fresh list so tuples are accepted and the
        # caller's container is left untouched.
        resized = []
        for level, feat in enumerate(mlvl_feats):
            factor = 2 ** (num_level - 1 - level)
            size = (base_h * factor, base_w * factor)
            # todo: temporary fix for "upsample_bilinear2d_out_frame" not
            # implemented for 'BFloat16': interpolate in float32.
            upsampled = F.interpolate(feat.to(torch.float32),
                                      size=size,
                                      mode='bilinear',
                                      align_corners=True)
            # NOTE(review): the result is always cast to bfloat16 regardless
            # of the input dtype -- confirm the module is only run in bf16.
            resized.append(upsampled.to(torch.bfloat16))

        fused = self.mlvl_fuse(resized)

        return self.roi_align(fused, bboxes)