# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections.abc
from torch import Tensor
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)
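# Illustrative example: to_2tuple broadcasts a scalar to a pair and passes
# non-string iterables through unchanged, e.g.
#   to_2tuple(0.1)    -> (0.1, 0.1)
#   to_2tuple((1, 2)) -> (1, 2)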
class ResidualBlock(nn.Module):
    """
    Residual block: two convolutions with normalization and a skip connection.
    When stride > 1, a strided 1x1 convolution downsamples the input so it can
    be added to the output.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError(f"Unknown norm_fn: {norm_fn}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
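# Usage sketch (assumed, illustrative shapes): the block outputs `planes` channels and
# divides the spatial resolution by `stride` via the 1x1 downsample branch, e.g.
#   block = ResidualBlock(in_planes=64, planes=96, norm_fn="group", stride=2)
#   block(torch.randn(1, 64, 32, 32)).shape  # -> torch.Size([1, 96, 16, 16])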
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        # Optional normalization on the hidden features (timm-style); defaults to a no-op.
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
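# Usage sketch (assumed, illustrative dimensions): the MLP acts on the last dimension,
# mapping in_features -> hidden_features -> out_features (out defaults to in), e.g.
#   mlp = Mlp(in_features=128, hidden_features=512)
#   mlp(torch.randn(4, 10, 128)).shape  # -> torch.Size([4, 10, 128])
# With use_conv=True the same block is applied to (B, C, H, W) maps via 1x1 convolutions.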
class AttnBlock(nn.Module):
    """
    Self-attention block: pre-norm multi-head self-attention followed by an MLP,
    each with a residual connection.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # NOTE: `mask` is accepted for interface compatibility but is not passed to the
        # attention layer (nn.MultiheadAttention expects a different mask format).
        # Normalize before attention; the residual below is added to the normalized input.
        x = self.norm1(x)
        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        attn_output, _ = self.attn(x, x, x)
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
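# Usage sketch (assumed, illustrative shapes): token sequences of shape
# (batch, length, hidden_size) go in and come out with the same shape, e.g.
#   blk = AttnBlock(hidden_size=256, num_heads=8)
#   blk(torch.randn(2, 100, 256)).shape  # -> torch.Size([2, 100, 256])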
class CrossAttnBlock(nn.Module):
    """
    Cross-attention block: queries come from `x`, keys and values from `context`,
    followed by an MLP, each with a residual connection.
    """

    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        # NOTE: the context is normalized and attended over with width hidden_size;
        # `context_dim` is accepted but unused, so the block assumes
        # context_dim == hidden_size.
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Normalize inputs; the residual below is added to the normalized queries.
        x = self.norm1(x)
        context = self.norm_context(context)
        # nn.MultiheadAttention returns (attn_output, attn_output_weights);
        # x provides the queries, context the keys and values.
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
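if __name__ == "__main__":
    # Minimal smoke test with assumed, illustrative shapes (not taken from any
    # training configuration): push random tokens through the self- and
    # cross-attention blocks and check the output shape.
    tokens = torch.randn(2, 16, 64)   # (batch, sequence length, hidden_size)
    context = torch.randn(2, 32, 64)  # (batch, context length, hidden_size)
    self_block = AttnBlock(hidden_size=64, num_heads=4)
    cross_block = CrossAttnBlock(hidden_size=64, context_dim=64, num_heads=4)
    out = cross_block(self_block(tokens), context)
    print(out.shape)  # torch.Size([2, 16, 64])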