import torch.nn.functional as F
from torch import nn


def _group_count(dim):
    """Return the number of ``nn.GroupNorm`` groups for a ``dim``-channel layer.

    The target is one group per 16 channels (``dim // 16``). ``nn.GroupNorm``
    cannot use zero groups, and it requires the group count to divide ``dim``.
    So the target is clamped to at least 1 and then lowered to the nearest
    divisor of ``dim``. For every ``dim`` where ``dim // 16`` was already a
    valid group count, the result is exactly ``dim // 16``.
    """
    groups = max(1, dim // 16)
    # Terminates: every integer is divisible by 1.
    while dim % groups:
        groups -= 1
    return groups


class PreactResBlock(nn.Sequential):
    """Pre-activation residual block.

    Computes ``x + conv(gelu(norm(conv(gelu(norm(x))))))``. Both convolutions
    are 3x3 with padding 1 and ``dim -> dim`` channels, so the stack keeps the
    input's channel count and spatial size and the residual sum is well defined.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        # Previously ``dim // 16`` was used directly, which crashed for dim < 16
        # (zero groups) and for dims not divisible by ``dim // 16``.
        groups = _group_count(dim)
        super().__init__(
            nn.GroupNorm(groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Run the norm/GELU/conv stack and add the input back (residual connection)."""
        return x + super().forward(x)


class UNetBlock(nn.Module):
    """One U-Net stage: optional upsample and skip add, a conv, two residual blocks, optional downsample.

    ``scale_factor > 1`` resamples the incoming tensor *before* processing
    (decoder-style stage). ``scale_factor < 1`` resamples the processed tensor
    *after* processing (encoder-style stage). Any other value leaves the
    resolution untouched. Resampling uses ``nn.Upsample`` with its default mode.

    Args:
        input_dim: channel count of ``x`` (and of ``h`` when provided).
        output_dim: channel count produced by the block; ``None`` means ``input_dim``.
        scale_factor: spatial resampling factor handed to ``nn.Upsample``.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        out_channels = input_dim if output_dim is None else output_dim
        self.pre_conv = nn.Conv2d(input_dim, out_channels, 3, padding=1)
        self.res_block1 = PreactResBlock(out_channels)
        self.res_block2 = PreactResBlock(out_channels)
        # A single shared Identity is used for whichever side does not resample.
        passthrough = nn.Identity()
        self.downsample = nn.Upsample(scale_factor=scale_factor) if scale_factor < 1 else passthrough
        self.upsample = nn.Upsample(scale_factor=scale_factor) if scale_factor > 1 else passthrough

    def forward(self, x, h=None):
        """Process one stage.

        Args:
            x: (b c h w), output of the previous stage.
            h: (b c h w) or None, skip tensor added after upsampling; must match
                the upsampled ``x`` exactly in shape.
        Returns:
            A pair ``(o, s)``: ``o`` is the processed tensor after ``downsample``,
            and ``s`` is the same tensor before ``downsample`` (used as a skip).
        """
        y = self.upsample(x)
        if h is not None:
            assert y.shape == h.shape, f"{y.shape} != {h.shape}"
            y = y + h
        y = self.res_block2(self.res_block1(self.pre_conv(y)))
        return self.downsample(y), y


class UNet(nn.Module):
    """Encoder / middle / decoder U-Net with additive skip connections.

    The input is padded so H and W are multiples of ``2 ** num_blocks`` and
    then projected to ``hidden_dim`` channels. It then passes through:

    - ``num_blocks`` encoder stages, each doubling the channels and halving H and W;
    - ``num_middle_blocks`` stages that keep channels and resolution;
    - ``num_blocks`` decoder stages, each doubling H and W, adding the matching
      encoder skip, and halving the channels.

    A small conv head maps the result to ``output_dim`` channels, and the
    padding is cropped away.

    Args:
        input_dim: channel count of the input tensor.
        output_dim: channel count of the output tensor.
        hidden_dim: channel width after the input projection.
        num_blocks: number of encoder (and decoder) stages.
        num_middle_blocks: number of blocks at the bottleneck resolution.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        # Encoder stage i: hidden_dim * 2**i -> hidden_dim * 2**(i + 1) channels, then halve H and W.
        self.encoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
                for i in range(num_blocks)
            ]
        )
        # Bottleneck stages keep both channel count and resolution.
        self.middle_blocks = nn.ModuleList(
            [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
        )
        # Decoder stages mirror the encoder, deepest first: upsample x2, add skip, halve channels.
        self.decoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
                for i in reversed(range(num_blocks))
            ]
        )
        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )

    @property
    def scale_factor(self):
        """Total spatial reduction of the encoder, ``2 ** len(self.encoder_blocks)``."""
        return 2 ** len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """Zero-pad H and W on the bottom/right up to the next multiple of ``scale_factor``.

        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), padded input; unchanged when H and W already fit.
        """
        factor = self.scale_factor
        # ``-n % k`` is the distance from n up to the next multiple of k (0 if already a multiple).
        hpad = -x.shape[2] % factor
        wpad = -x.shape[3] % factor
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(x, (0, wpad, 0, hpad))

    def forward(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            o: (b c h w), output with ``output_dim`` channels and the input's H and W.
        """
        shape = x.shape

        x = self.pad_to_fit(x)
        x = self.input_proj(x)

        # Each encoder stage's pre-downsample output becomes the skip for its mirror decoder stage.
        skips = []
        for block in self.encoder_blocks:
            x, s = block(x)
            skips.append(s)

        for block in self.middle_blocks:
            x, _ = block(x)

        # Decoder runs deepest-first, so pair it with the skips in reverse order.
        for block, s in zip(self.decoder_blocks, reversed(skips)):
            x, _ = block(x, s)

        x = self.head(x)
        # Remove the padding added by pad_to_fit.
        return x[..., : shape[2], : shape[3]]

    def test(self, shape=None):
        """Print the model's MACs and parameter count using ``ptflops``.

        Args:
            shape: per-sample input shape ``(c, h, w)``. ptflops adds the batch
                dimension itself. The default is ``(self.input_dim, 512, 256)``,
                so the probe always has the channel count ``input_proj`` expects.
                The old fixed default of 3 channels broke models with a
                different ``input_dim``.
        """
        import ptflops  # optional profiling dependency, imported lazily

        if shape is None:
            shape = (self.input_dim, 512, 256)

        macs, params = ptflops.get_model_complexity_info(
            self,
            shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"macs: {macs}")
        print(f"params: {params}")


def main():
    """Build a 3-channel-in / 3-channel-out UNet and print its complexity report."""
    UNet(3, 3).test()


if __name__ == "__main__":
    main()