"""This file is modified version from mmsegmentation (https://github.com/open-mmlab/mmsegmentation)""" import torch import torch.nn as nn from torch.nn import functional as F class PPM(nn.ModuleList): """Pooling Pyramid Module used in PSPNet. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict): Config of activation layers. align_corners (bool): align_corners argument of F.interpolate. """ def __init__(self, pool_scales, in_channels, channels): super(PPM, self).__init__() self.pool_scales = pool_scales self.in_channels = in_channels self.channels = channels for pool_scale in pool_scales: self.append( nn.Sequential( nn.AdaptiveAvgPool2d(pool_scale), nn.Conv2d(self.in_channels, self.channels, kernel_size=1), nn.ReLU() ) ) def forward(self, x): """Forward function.""" ppm_outs = [] for ppm in self: ppm_out = ppm(x) upsampled_ppm_out = F.interpolate( ppm_out.float(), size=x.size()[2:], mode='bilinear', align_corners=False).to(torch.bfloat16) ppm_outs.append(upsampled_ppm_out) return ppm_outs class UPerHead(nn.Module): """Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet `_. Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module applied on the last feature. Default: (1, 2, 3, 6). """ def __init__(self, in_channels = (96, 192, 384, 768), channels = 256, pool_scales=(1, 2, 3, 6),): super(UPerHead, self).__init__() # PSP Module self.in_channels = in_channels self.channels = channels self.psp_modules = PPM( pool_scales, self.in_channels[-1], self.channels ) self.bottleneck = nn.Sequential( nn.Conv2d(self.in_channels[-1] + len(pool_scales) * self.channels, self.channels, kernel_size=3, padding=1), nn.ReLU()) # FPN Module self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the top layer l_conv = nn.Sequential( nn.Conv2d(in_channels, self.channels, kernel_size=1, padding=0), nn.ReLU()) fpn_conv = nn.Sequential( nn.Conv2d(self.channels, self.channels, kernel_size=3, padding=1), nn.ReLU()) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) self.fpn_bottleneck = nn.Sequential( nn.Conv2d(len(self.in_channels) * self.channels, self.channels, kernel_size=3, padding=1), nn.ReLU()) def psp_forward(self, inputs): """Forward function of PSP module.""" x = inputs[-1] psp_outs = [x] psp_outs.extend(self.psp_modules(x)) psp_outs = torch.cat(psp_outs, dim=1) output = self.bottleneck(psp_outs) return output def forward(self, inputs): """Forward function. 
class UPerHead(nn.Module):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of
    `UPerNet <https://arxiv.org/abs/1807.10221>`_.

    Args:
        in_channels (tuple[int]): Input channels of each backbone level.
            Default: (96, 192, 384, 768).
        channels (int): Channels after modules, before conv_seg.
            Default: 256.
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self,
                 in_channels=(96, 192, 384, 768),
                 channels=256,
                 pool_scales=(1, 2, 3, 6)):
        super(UPerHead, self).__init__()
        self.in_channels = in_channels
        self.channels = channels
        # PSP Module, applied to the coarsest (last) backbone level.
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(self.in_channels[-1] + len(pool_scales) * self.channels,
                      self.channels, kernel_size=3, padding=1),
            nn.ReLU())
        # FPN Module.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_ch in self.in_channels[:-1]:  # skip the top layer
            l_conv = nn.Sequential(
                nn.Conv2d(in_ch, self.channels, kernel_size=1, padding=0),
                nn.ReLU())
            fpn_conv = nn.Sequential(
                nn.Conv2d(self.channels, self.channels, kernel_size=3,
                          padding=1),
                nn.ReLU())
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)
        self.fpn_bottleneck = nn.Sequential(
            nn.Conv2d(len(self.in_channels) * self.channels, self.channels,
                      kernel_size=3, padding=1),
            nn.ReLU())

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)
        return output

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level backbone features ordered from
                the finest to the coarsest level, e.g. with 96, 192, 384 and
                768 channels respectively.
        """
        # Lateral 1x1 convs on every level except the coarsest one.
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # The coarsest level goes through the PSP module instead.
        laterals.append(self.psp_forward(inputs))

        # Build the top-down path: upsample each level and add it to the next
        # finer one (interpolation in float32, cast back to bfloat16).
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i].float(),
                size=prev_shape,
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)

        # Build outputs with 3x3 convs on the fused levels ...
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # ... and append the PSP feature as the coarsest level.
        fpn_outs.append(laterals[-1])

        # Upsample everything to the finest resolution and fuse.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = F.interpolate(
                fpn_outs[i].float(),
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=False).to(torch.bfloat16)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        return output
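
# --- Illustrative example (added; not part of the original mmsegmentation
# code). An end-to-end sketch, assuming a 4-level backbone with channel
# widths (96, 192, 384, 768) and strides 4/8/16/32 (a Swin-T-like hierarchy),
# and the bfloat16 pipeline implied by the casts above.
if __name__ == "__main__":
    head = UPerHead(in_channels=(96, 192, 384, 768), channels=256)
    head = head.to(torch.bfloat16)
    feats = [
        torch.randn(1, 96, 56, 56, dtype=torch.bfloat16),
        torch.randn(1, 192, 28, 28, dtype=torch.bfloat16),
        torch.randn(1, 384, 14, 14, dtype=torch.bfloat16),
        torch.randn(1, 768, 7, 7, dtype=torch.bfloat16),
    ]
    out = head(feats)
    # The fused feature comes out at the finest input resolution.
    print(out.shape)  # torch.Size([1, 256, 56, 56])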