# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import collections
from functools import partial
from itertools import repeat
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)


class ResidualBlock(nn.Module):
    """ResidualBlock: construct a block of two conv layers with residual connections"""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        """Self attention block"""
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # The mask argument is currently unused: PyTorch's MultiheadAttention
        # expects it in a different format (attn_mask / key_padding_mask).
        # attn_mask = mask if mask is not None else None

        # Pre-norm self-attention; nn.MultiheadAttention returns
        # (attn_output, attn_output_weights).
        # attn_output, _ = self.attn(y, y, y, attn_mask=attn_mask)
        y = self.norm1(x)
        attn_output, _ = self.attn(y, y, y)

        # Residual connections around attention and the MLP.
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """Cross attention block"""
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(context_dim)
        self.norm2 = nn.LayerNorm(hidden_size)

        # Queries stay in hidden_size; keys and values come from the context,
        # which may use its own channel dimension (context_dim).
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            kdim=context_dim,
            vdim=context_dim,
            batch_first=True,
            **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Pre-norm: normalize queries and context before cross attention.
        y = self.norm1(x)
        context = self.norm_context(context)

        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        attn_output, _ = self.cross_attn(y, context, context, attn_mask=mask)

        # Residual connections around attention and the MLP.
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
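

# Illustrative usage sketch, added for demonstration and not part of the
# original module: the __main__ guard, tensor shapes, and hyperparameters
# below are arbitrary assumptions chosen only to show how the blocks are
# wired together and what shapes they expect.
if __name__ == "__main__":
    # ResidualBlock operates on 4D feature maps (B, C, H, W).
    res = ResidualBlock(in_planes=64, planes=64, norm_fn="group")
    feats = torch.randn(2, 64, 32, 32)
    print(res(feats).shape)  # torch.Size([2, 64, 32, 32])

    # Mlp and AttnBlock operate on token sequences (B, N, C).
    tokens = torch.randn(2, 16, 128)
    mlp = Mlp(in_features=128, hidden_features=256)
    print(mlp(tokens).shape)  # torch.Size([2, 16, 128])

    attn_block = AttnBlock(hidden_size=128, num_heads=4)
    print(attn_block(tokens).shape)  # torch.Size([2, 16, 128])

    # CrossAttnBlock attends from tokens (B, N, hidden_size) to a context
    # sequence (B, M, context_dim); here context_dim equals hidden_size.
    context = torch.randn(2, 8, 128)
    cross_block = CrossAttnBlock(hidden_size=128, context_dim=128, num_heads=4)
    print(cross_block(tokens, context).shape)  # torch.Size([2, 16, 128])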